Skip to main content

conjure_cp_core/ast/
partial_eval.rs

1use std::collections::HashSet;
2
3use crate::ast::Typeable;
4use crate::{
5    ast::{Atom, Expression as Expr, Literal as Lit, Metadata, Moo, ReturnType},
6    into_matrix_expr,
7    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
8};
9use itertools::iproduct;
10
11pub fn run_partial_evaluator(expr: &Expr) -> ApplicationResult {
12    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
13    // rule infinitely!
14    // This is why we always check whether we found a constant or not.
15    match expr {
16        Expr::Union(_, _, _) => Err(RuleNotApplicable),
17        Expr::In(_, _, _) => Err(RuleNotApplicable),
18        Expr::Intersect(_, _, _) => Err(RuleNotApplicable),
19        Expr::Supset(_, _, _) => Err(RuleNotApplicable),
20        Expr::SupsetEq(_, _, _) => Err(RuleNotApplicable),
21        Expr::Subset(_, _, _) => Err(RuleNotApplicable),
22        Expr::SubsetEq(_, _, _) => Err(RuleNotApplicable),
23        Expr::AbstractLiteral(_, _) => Err(RuleNotApplicable),
24        Expr::Comprehension(_, _) => Err(RuleNotApplicable),
25        Expr::AbstractComprehension(_, _) => Err(RuleNotApplicable),
26        Expr::DominanceRelation(_, _) => Err(RuleNotApplicable),
27        Expr::FromSolution(_, _) => Err(RuleNotApplicable),
28        Expr::Metavar(_, _) => Err(RuleNotApplicable),
29        Expr::UnsafeIndex(_, _, _) => Err(RuleNotApplicable),
30        Expr::UnsafeSlice(_, _, _) => Err(RuleNotApplicable),
31        Expr::Table(_, _, _) => Err(RuleNotApplicable),
32        Expr::NegativeTable(_, _, _) => Err(RuleNotApplicable),
33        Expr::SafeIndex(_, subject, indices) => {
34            // partially evaluate matrix literals indexed by a constant.
35
36            // subject must be a matrix literal
37            let (es, index_domain) = Moo::unwrap_or_clone(subject.clone())
38                .unwrap_matrix_unchecked()
39                .ok_or(RuleNotApplicable)?;
40
41            // must be indexing a 1d matrix.
42            //
43            // for n-d matrices, wait for the `remove_dimension_from_matrix_indexing` rule to run
44            // first. This reduces n-d indexing operations to 1d.
45            if indices.len() != 1 {
46                return Err(RuleNotApplicable);
47            }
48
49            // the index must be a number
50            let index: i32 = (&indices[0]).try_into().map_err(|_| RuleNotApplicable)?;
51
52            // index domain must be a single integer range with a lower bound
53            if let Some(ranges) = index_domain.as_int_ground()
54                && ranges.len() == 1
55                && let Some(from) = ranges[0].low()
56            {
57                let zero_indexed_index = index - from;
58                Ok(Reduction::pure(es[zero_indexed_index as usize].clone()))
59            } else {
60                Err(RuleNotApplicable)
61            }
62        }
63        Expr::SafeSlice(_, _, _) => Err(RuleNotApplicable),
64        Expr::InDomain(_, x, domain) => {
65            if let Expr::Atomic(_, Atom::Reference(decl)) = x.as_ref() {
66                let decl_domain = decl
67                    .domain()
68                    .ok_or(RuleNotApplicable)?
69                    .resolve()
70                    .ok_or(RuleNotApplicable)?;
71                let domain = domain.resolve().ok_or(RuleNotApplicable)?;
72
73                let intersection = decl_domain
74                    .intersect(&domain)
75                    .map_err(|_| RuleNotApplicable)?;
76
77                // if the declaration's domain is a subset of domain, expr is always true.
78                if &intersection == decl_domain.as_ref() {
79                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
80                }
81                // if no elements of declaration's domain are in the domain (i.e. they have no
82                // intersection), expr is always false.
83                //
84                // Only check this when the intersection is a finite integer domain, as we
85                // currently don't have a way to check whether other domain kinds are empty or not.
86                //
87                // we should expand this to cover more domain types in the future.
88                else if let Ok(values_in_domain) = intersection.values_i32()
89                    && values_in_domain.is_empty()
90                {
91                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
92                } else {
93                    Err(RuleNotApplicable)
94                }
95            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref() {
96                if domain
97                    .resolve()
98                    .ok_or(RuleNotApplicable)?
99                    .contains(lit)
100                    .ok()
101                    .ok_or(RuleNotApplicable)?
102                {
103                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
104                } else {
105                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
106                }
107            } else {
108                Err(RuleNotApplicable)
109            }
110        }
111        Expr::Bubble(_, expr, cond) => {
112            // definition of bubble is "expr is valid as long as cond is true"
113            //
114            // check if cond is true and pop the bubble!
115            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = cond.as_ref() {
116                Ok(Reduction::pure(Moo::unwrap_or_clone(expr.clone())))
117            } else {
118                Err(RuleNotApplicable)
119            }
120        }
121        Expr::Atomic(_, _) => Err(RuleNotApplicable),
122        Expr::ToInt(_, expression) => {
123            if expression.return_type() == ReturnType::Int {
124                Ok(Reduction::pure(Moo::unwrap_or_clone(expression.clone())))
125            } else {
126                Err(RuleNotApplicable)
127            }
128        }
129        Expr::Abs(m, e) => match e.as_ref() {
130            Expr::Neg(_, inner) => Ok(Reduction::pure(Expr::Abs(m.clone(), inner.clone()))),
131            _ => Err(RuleNotApplicable),
132        },
133        Expr::Sum(m, vec) => {
134            let vec = Moo::unwrap_or_clone(vec.clone())
135                .unwrap_list()
136                .ok_or(RuleNotApplicable)?;
137            let mut acc = 0;
138            let mut n_consts = 0;
139            let mut new_vec: Vec<Expr> = Vec::new();
140            for expr in vec {
141                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
142                    acc += x;
143                    n_consts += 1;
144                } else {
145                    new_vec.push(expr);
146                }
147            }
148            if acc != 0 {
149                new_vec.push(Expr::Atomic(
150                    Default::default(),
151                    Atom::Literal(Lit::Int(acc)),
152                ));
153            }
154
155            if n_consts <= 1 {
156                Err(RuleNotApplicable)
157            } else {
158                Ok(Reduction::pure(Expr::Sum(
159                    m.clone(),
160                    Moo::new(into_matrix_expr![new_vec]),
161                )))
162            }
163        }
164
165        Expr::Product(m, vec) => {
166            let mut acc = 1;
167            let mut n_consts = 0;
168            let mut new_vec: Vec<Expr> = Vec::new();
169            let vec = Moo::unwrap_or_clone(vec.clone())
170                .unwrap_list()
171                .ok_or(RuleNotApplicable)?;
172            for expr in vec {
173                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
174                    acc *= x;
175                    n_consts += 1;
176                } else {
177                    new_vec.push(expr);
178                }
179            }
180
181            if n_consts == 0 {
182                return Err(RuleNotApplicable);
183            }
184
185            new_vec.push(Expr::Atomic(
186                Default::default(),
187                Atom::Literal(Lit::Int(acc)),
188            ));
189            let new_product = Expr::Product(m.clone(), Moo::new(into_matrix_expr![new_vec]));
190
191            if acc == 0 {
192                // if safe, 0 * exprs ~> 0
193                // otherwise, just return 0* exprs
194                if new_product.is_safe() {
195                    Ok(Reduction::pure(Expr::Atomic(
196                        Default::default(),
197                        Atom::Literal(Lit::Int(0)),
198                    )))
199                } else {
200                    Ok(Reduction::pure(new_product))
201                }
202            } else if n_consts == 1 {
203                // acc !=0, only one constant
204                Err(RuleNotApplicable)
205            } else {
206                // acc !=0, multiple constants found
207                Ok(Reduction::pure(new_product))
208            }
209        }
210
211        Expr::Min(m, e) => {
212            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
213                return Err(RuleNotApplicable);
214            };
215            let mut acc: Option<i32> = None;
216            let mut n_consts = 0;
217            let mut new_vec: Vec<Expr> = Vec::new();
218            for expr in vec {
219                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
220                    n_consts += 1;
221                    acc = match acc {
222                        Some(i) => {
223                            if i > x {
224                                Some(x)
225                            } else {
226                                Some(i)
227                            }
228                        }
229                        None => Some(x),
230                    };
231                } else {
232                    new_vec.push(expr);
233                }
234            }
235
236            if let Some(i) = acc {
237                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
238            }
239
240            if n_consts <= 1 {
241                Err(RuleNotApplicable)
242            } else {
243                Ok(Reduction::pure(Expr::Min(
244                    m.clone(),
245                    Moo::new(into_matrix_expr![new_vec]),
246                )))
247            }
248        }
249
250        Expr::Max(m, e) => {
251            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
252                return Err(RuleNotApplicable);
253            };
254
255            let mut acc: Option<i32> = None;
256            let mut n_consts = 0;
257            let mut new_vec: Vec<Expr> = Vec::new();
258            for expr in vec {
259                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
260                    n_consts += 1;
261                    acc = match acc {
262                        Some(i) => {
263                            if i < x {
264                                Some(x)
265                            } else {
266                                Some(i)
267                            }
268                        }
269                        None => Some(x),
270                    };
271                } else {
272                    new_vec.push(expr);
273                }
274            }
275
276            if let Some(i) = acc {
277                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
278            }
279
280            if n_consts <= 1 {
281                Err(RuleNotApplicable)
282            } else {
283                Ok(Reduction::pure(Expr::Max(
284                    m.clone(),
285                    Moo::new(into_matrix_expr![new_vec]),
286                )))
287            }
288        }
289        Expr::Not(_, _) => Err(RuleNotApplicable),
290        Expr::Or(m, e) => {
291            let Some(terms) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
292                return Err(RuleNotApplicable);
293            };
294
295            let mut has_changed = false;
296
297            // 2. boolean literals
298            let mut new_terms = vec![];
299            for expr in terms {
300                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
301                    has_changed = true;
302
303                    // true ~~> entire or is true
304                    // false ~~> remove false from the or
305                    if x {
306                        return Ok(Reduction::pure(true.into()));
307                    }
308                } else {
309                    new_terms.push(expr);
310                }
311            }
312
313            // 2. check pairwise tautologies.
314            if check_pairwise_or_tautologies(&new_terms) {
315                return Ok(Reduction::pure(true.into()));
316            }
317
318            // 3. empty or ~~> false
319            if new_terms.is_empty() {
320                return Ok(Reduction::pure(false.into()));
321            }
322
323            if !has_changed {
324                return Err(RuleNotApplicable);
325            }
326
327            Ok(Reduction::pure(Expr::Or(
328                m.clone(),
329                Moo::new(into_matrix_expr![new_terms]),
330            )))
331        }
332        Expr::And(_, e) => {
333            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
334                return Err(RuleNotApplicable);
335            };
336            let mut new_vec: Vec<Expr> = Vec::new();
337            let mut has_const: bool = false;
338            for expr in vec {
339                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
340                    has_const = true;
341                    if !x {
342                        return Ok(Reduction::pure(Expr::Atomic(
343                            Default::default(),
344                            Atom::Literal(Lit::Bool(false)),
345                        )));
346                    }
347                } else {
348                    new_vec.push(expr);
349                }
350            }
351
352            if !has_const {
353                Err(RuleNotApplicable)
354            } else {
355                Ok(Reduction::pure(Expr::And(
356                    Metadata::new(),
357                    Moo::new(into_matrix_expr![new_vec]),
358                )))
359            }
360        }
361
362        // similar to And, but booleans are returned wrapped in Root.
363        Expr::Root(_, es) => {
364            match es.as_slice() {
365                [] => Err(RuleNotApplicable),
366                // want to unwrap nested ands
367                [Expr::And(_, _)] => Ok(()),
368                // root([true]) / root([false]) are already evaluated
369                [_] => Err(RuleNotApplicable),
370                [_, _, ..] => Ok(()),
371            }?;
372
373            let mut new_vec: Vec<Expr> = Vec::new();
374            let mut has_changed: bool = false;
375            for expr in es {
376                match expr {
377                    Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) => {
378                        has_changed = true;
379                        if !x {
380                            // false
381                            return Ok(Reduction::pure(Expr::Root(
382                                Metadata::new(),
383                                vec![Expr::Atomic(
384                                    Default::default(),
385                                    Atom::Literal(Lit::Bool(false)),
386                                )],
387                            )));
388                        }
389                        // remove trues
390                    }
391
392                    // flatten ands in root
393                    Expr::And(_, vecs) => match Moo::unwrap_or_clone(vecs.clone()).unwrap_list() {
394                        Some(mut list) => {
395                            has_changed = true;
396                            new_vec.append(&mut list);
397                        }
398                        None => new_vec.push(expr.clone()),
399                    },
400                    _ => new_vec.push(expr.clone()),
401                }
402            }
403
404            if !has_changed {
405                Err(RuleNotApplicable)
406            } else {
407                if new_vec.is_empty() {
408                    new_vec.push(true.into());
409                }
410                Ok(Reduction::pure(Expr::Root(Metadata::new(), new_vec)))
411            }
412        }
413        Expr::Imply(_m, x, y) => {
414            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
415                if *x {
416                    // (true) -> y ~~> y
417                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
418                } else {
419                    // (false) -> y ~~> true
420                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
421                }
422            };
423
424            // reflexivity: p -> p ~> true
425
426            // instead of checking syntactic equivalence of a possibly deep expression,
427            // let identical-CSE turn them into identical variables first. Then, check if they are
428            // identical variables.
429
430            if x.identical_atom_to(y.as_ref()) {
431                return Ok(Reduction::pure(true.into()));
432            }
433
434            Err(RuleNotApplicable)
435        }
436        Expr::Iff(_m, x, y) => {
437            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
438                if *x {
439                    // (true) <-> y ~~> y
440                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
441                } else {
442                    // (false) <-> y ~~> !y
443                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), y.clone())));
444                }
445            };
446            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(y))) = y.as_ref() {
447                if *y {
448                    // x <-> (true) ~~> x
449                    return Ok(Reduction::pure(Moo::unwrap_or_clone(x.clone())));
450                } else {
451                    // x <-> (false) ~~> !x
452                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), x.clone())));
453                }
454            };
455
456            // reflexivity: p <-> p ~> true
457
458            // instead of checking syntactic equivalence of a possibly deep expression,
459            // let identical-CSE turn them into identical variables first. Then, check if they are
460            // identical variables.
461
462            if x.identical_atom_to(y.as_ref()) {
463                return Ok(Reduction::pure(true.into()));
464            }
465
466            Err(RuleNotApplicable)
467        }
468        Expr::Eq(_, _, _) => Err(RuleNotApplicable),
469        Expr::Neq(_, _, _) => Err(RuleNotApplicable),
470        Expr::Geq(_, _, _) => Err(RuleNotApplicable),
471        Expr::Leq(_, _, _) => Err(RuleNotApplicable),
472        Expr::Gt(_, _, _) => Err(RuleNotApplicable),
473        Expr::Lt(_, _, _) => Err(RuleNotApplicable),
474        Expr::SafeDiv(_, _, _) => Err(RuleNotApplicable),
475        Expr::UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
476        Expr::Flatten(_, _, _) => Err(RuleNotApplicable), // TODO: check if anything can be done here
477        Expr::AllDiff(m, e) => {
478            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
479                return Err(RuleNotApplicable);
480            };
481
482            let mut consts: HashSet<i32> = HashSet::new();
483
484            // check for duplicate constant values which would fail the constraint
485            for expr in vec {
486                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr
487                    && !consts.insert(x)
488                {
489                    return Ok(Reduction::pure(Expr::Atomic(
490                        m.clone(),
491                        Atom::Literal(Lit::Bool(false)),
492                    )));
493                }
494            }
495
496            // nothing has changed
497            Err(RuleNotApplicable)
498        }
499        Expr::Neg(_, _) => Err(RuleNotApplicable),
500        Expr::AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
501        Expr::UnsafeMod(_, _, _) => Err(RuleNotApplicable),
502        Expr::SafeMod(_, _, _) => Err(RuleNotApplicable),
503        Expr::UnsafePow(_, _, _) => Err(RuleNotApplicable),
504        Expr::SafePow(_, _, _) => Err(RuleNotApplicable),
505        Expr::Minus(_, _, _) => Err(RuleNotApplicable),
506
507        // As these are in a low level solver form, I'm assuming that these have already been
508        // simplified and partially evaluated.
509        Expr::FlatAllDiff(_, _) => Err(RuleNotApplicable),
510        Expr::FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
511        Expr::FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
512        Expr::FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
513        Expr::FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
514        Expr::FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
515        Expr::FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
516        Expr::FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
517        Expr::FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
518        Expr::FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
519        Expr::MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
520        Expr::MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
521        Expr::MinionPow(_, _, _, _) => Err(RuleNotApplicable),
522        Expr::MinionReify(_, _, _) => Err(RuleNotApplicable),
523        Expr::MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
524        Expr::MinionWInIntervalSet(_, _, _) => Err(RuleNotApplicable),
525        Expr::MinionWInSet(_, _, _) => Err(RuleNotApplicable),
526        Expr::MinionElementOne(_, _, _, _) => Err(RuleNotApplicable),
527        Expr::SATInt(_, _, _, _) => Err(RuleNotApplicable),
528        Expr::PairwiseSum(_, _, _) => Err(RuleNotApplicable),
529        Expr::PairwiseProduct(_, _, _) => Err(RuleNotApplicable),
530        Expr::Defined(_, _) => todo!(),
531        Expr::Range(_, _) => todo!(),
532        Expr::Image(_, _, _) => todo!(),
533        Expr::ImageSet(_, _, _) => todo!(),
534        Expr::PreImage(_, _, _) => todo!(),
535        Expr::Inverse(_, _, _) => todo!(),
536        Expr::Restrict(_, _, _) => todo!(),
537        Expr::LexLt(_, _, _) => Err(RuleNotApplicable),
538        Expr::LexLeq(_, _, _) => Err(RuleNotApplicable),
539        Expr::LexGt(_, _, _) => Err(RuleNotApplicable),
540        Expr::LexGeq(_, _, _) => Err(RuleNotApplicable),
541        Expr::FlatLexLt(_, _, _) => Err(RuleNotApplicable),
542        Expr::FlatLexLeq(_, _, _) => Err(RuleNotApplicable),
543    }
544}
545
546/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
547///
548/// This applies the following rules:
549///
550/// ```text
551/// (p->q) \/ (q->p) ~> true    [totality of implication]
552/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
553/// ```
554///
555fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
556    // Collect terms that are structurally identical to the rule input.
557    // Then, try the rules on these terms, also checking the other conditions of the rules.
558
559    // stores (p,q) in p -> q
560    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
561
562    // stores (p,q) in p -> !q
563    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
564
565    for term in or_terms.iter() {
566        if let Expr::Imply(_, p, q) = term {
567            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
568            //
569            // in general however, p -> !q would be in p_implies_q as (p,!q)
570            if let Expr::Not(_, q_1) = q.as_ref() {
571                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
572            } else {
573                p_implies_q.push((p.as_ref(), q.as_ref()));
574            }
575        }
576    }
577
578    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
579    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
580        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
581            return true;
582        }
583    }
584
585    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
586    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
587        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
588            return true;
589        }
590    }
591
592    false
593}