conjure_cp_core/ast/
partial_eval.rs

1use std::collections::HashSet;
2
3use crate::ast::Typeable;
4use crate::{
5    ast::{Atom, Expression as Expr, Literal as Lit, Metadata, Moo, ReturnType},
6    into_matrix_expr,
7    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
8};
9use itertools::iproduct;
10
11pub fn run_partial_evaluator(expr: &Expr) -> ApplicationResult {
12    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
13    // rule infinitely!
14    // This is why we always check whether we found a constant or not.
15    match expr {
16        Expr::Union(_, _, _) => Err(RuleNotApplicable),
17        Expr::In(_, _, _) => Err(RuleNotApplicable),
18        Expr::Intersect(_, _, _) => Err(RuleNotApplicable),
19        Expr::Supset(_, _, _) => Err(RuleNotApplicable),
20        Expr::SupsetEq(_, _, _) => Err(RuleNotApplicable),
21        Expr::Subset(_, _, _) => Err(RuleNotApplicable),
22        Expr::SubsetEq(_, _, _) => Err(RuleNotApplicable),
23        Expr::AbstractLiteral(_, _) => Err(RuleNotApplicable),
24        Expr::Comprehension(_, _) => Err(RuleNotApplicable),
25        Expr::DominanceRelation(_, _) => Err(RuleNotApplicable),
26        Expr::FromSolution(_, _) => Err(RuleNotApplicable),
27        Expr::Metavar(_, _) => Err(RuleNotApplicable),
28        Expr::UnsafeIndex(_, _, _) => Err(RuleNotApplicable),
29        Expr::UnsafeSlice(_, _, _) => Err(RuleNotApplicable),
30        Expr::SafeIndex(_, subject, indices) => {
31            // partially evaluate matrix literals indexed by a constant.
32
33            // subject must be a matrix literal
34            let (es, index_domain) = Moo::unwrap_or_clone(subject.clone())
35                .unwrap_matrix_unchecked()
36                .ok_or(RuleNotApplicable)?;
37
38            // must be indexing a 1d matrix.
39            //
40            // for n-d matrices, wait for the `remove_dimension_from_matrix_indexing` rule to run
41            // first. This reduces n-d indexing operations to 1d.
42            if indices.len() != 1 {
43                return Err(RuleNotApplicable);
44            }
45
46            // the index must be a number
47            let index: i32 = (&indices[0]).try_into().map_err(|_| RuleNotApplicable)?;
48
49            // index domain must be a single integer range with a lower bound
50            if let Some(ranges) = index_domain.as_int_ground()
51                && ranges.len() == 1
52                && let Some(from) = ranges[0].low()
53            {
54                let zero_indexed_index = index - from;
55                Ok(Reduction::pure(es[zero_indexed_index as usize].clone()))
56            } else {
57                Err(RuleNotApplicable)
58            }
59        }
60        Expr::SafeSlice(_, _, _) => Err(RuleNotApplicable),
61        Expr::InDomain(_, x, domain) => {
62            if let Expr::Atomic(_, Atom::Reference(decl)) = x.as_ref() {
63                let decl_domain = decl
64                    .domain()
65                    .ok_or(RuleNotApplicable)?
66                    .resolve()
67                    .ok_or(RuleNotApplicable)?;
68                let domain = domain.resolve().ok_or(RuleNotApplicable)?;
69
70                let intersection = decl_domain
71                    .intersect(&domain)
72                    .map_err(|_| RuleNotApplicable)?;
73
74                // if the declaration's domain is a subset of domain, expr is always true.
75                if &intersection == decl_domain.as_ref() {
76                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
77                }
78                // if no elements of declaration's domain are in the domain (i.e. they have no
79                // intersection), expr is always false.
80                //
81                // Only check this when the intersection is a finite integer domain, as we
82                // currently don't have a way to check whether other domain kinds are empty or not.
83                //
84                // we should expand this to cover more domain types in the future.
85                else if let Ok(values_in_domain) = intersection.values_i32()
86                    && values_in_domain.is_empty()
87                {
88                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
89                } else {
90                    Err(RuleNotApplicable)
91                }
92            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref() {
93                if domain
94                    .resolve()
95                    .ok_or(RuleNotApplicable)?
96                    .contains(lit)
97                    .ok()
98                    .ok_or(RuleNotApplicable)?
99                {
100                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
101                } else {
102                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
103                }
104            } else {
105                Err(RuleNotApplicable)
106            }
107        }
108        Expr::Bubble(_, expr, cond) => {
109            // definition of bubble is "expr is valid as long as cond is true"
110            //
111            // check if cond is true and pop the bubble!
112            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = cond.as_ref() {
113                Ok(Reduction::pure(Moo::unwrap_or_clone(expr.clone())))
114            } else {
115                Err(RuleNotApplicable)
116            }
117        }
118        Expr::Atomic(_, _) => Err(RuleNotApplicable),
119        Expr::Scope(_, _) => Err(RuleNotApplicable),
120        Expr::ToInt(_, expression) => {
121            if expression.return_type() == ReturnType::Int {
122                Ok(Reduction::pure(Moo::unwrap_or_clone(expression.clone())))
123            } else {
124                Err(RuleNotApplicable)
125            }
126        }
127        Expr::Abs(m, e) => match e.as_ref() {
128            Expr::Neg(_, inner) => Ok(Reduction::pure(Expr::Abs(m.clone(), inner.clone()))),
129            _ => Err(RuleNotApplicable),
130        },
131        Expr::Sum(m, vec) => {
132            let vec = Moo::unwrap_or_clone(vec.clone())
133                .unwrap_list()
134                .ok_or(RuleNotApplicable)?;
135            let mut acc = 0;
136            let mut n_consts = 0;
137            let mut new_vec: Vec<Expr> = Vec::new();
138            for expr in vec {
139                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
140                    acc += x;
141                    n_consts += 1;
142                } else {
143                    new_vec.push(expr);
144                }
145            }
146            if acc != 0 {
147                new_vec.push(Expr::Atomic(
148                    Default::default(),
149                    Atom::Literal(Lit::Int(acc)),
150                ));
151            }
152
153            if n_consts <= 1 {
154                Err(RuleNotApplicable)
155            } else {
156                Ok(Reduction::pure(Expr::Sum(
157                    m.clone(),
158                    Moo::new(into_matrix_expr![new_vec]),
159                )))
160            }
161        }
162
163        Expr::Product(m, vec) => {
164            let mut acc = 1;
165            let mut n_consts = 0;
166            let mut new_vec: Vec<Expr> = Vec::new();
167            let vec = Moo::unwrap_or_clone(vec.clone())
168                .unwrap_list()
169                .ok_or(RuleNotApplicable)?;
170            for expr in vec {
171                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
172                    acc *= x;
173                    n_consts += 1;
174                } else {
175                    new_vec.push(expr);
176                }
177            }
178
179            if n_consts == 0 {
180                return Err(RuleNotApplicable);
181            }
182
183            new_vec.push(Expr::Atomic(
184                Default::default(),
185                Atom::Literal(Lit::Int(acc)),
186            ));
187            let new_product = Expr::Product(m.clone(), Moo::new(into_matrix_expr![new_vec]));
188
189            if acc == 0 {
190                // if safe, 0 * exprs ~> 0
191                // otherwise, just return 0* exprs
192                if new_product.is_safe() {
193                    Ok(Reduction::pure(Expr::Atomic(
194                        Default::default(),
195                        Atom::Literal(Lit::Int(0)),
196                    )))
197                } else {
198                    Ok(Reduction::pure(new_product))
199                }
200            } else if n_consts == 1 {
201                // acc !=0, only one constant
202                Err(RuleNotApplicable)
203            } else {
204                // acc !=0, multiple constants found
205                Ok(Reduction::pure(new_product))
206            }
207        }
208
209        Expr::Min(m, e) => {
210            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
211                return Err(RuleNotApplicable);
212            };
213            let mut acc: Option<i32> = None;
214            let mut n_consts = 0;
215            let mut new_vec: Vec<Expr> = Vec::new();
216            for expr in vec {
217                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
218                    n_consts += 1;
219                    acc = match acc {
220                        Some(i) => {
221                            if i > x {
222                                Some(x)
223                            } else {
224                                Some(i)
225                            }
226                        }
227                        None => Some(x),
228                    };
229                } else {
230                    new_vec.push(expr);
231                }
232            }
233
234            if let Some(i) = acc {
235                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
236            }
237
238            if n_consts <= 1 {
239                Err(RuleNotApplicable)
240            } else {
241                Ok(Reduction::pure(Expr::Min(
242                    m.clone(),
243                    Moo::new(into_matrix_expr![new_vec]),
244                )))
245            }
246        }
247
248        Expr::Max(m, e) => {
249            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
250                return Err(RuleNotApplicable);
251            };
252
253            let mut acc: Option<i32> = None;
254            let mut n_consts = 0;
255            let mut new_vec: Vec<Expr> = Vec::new();
256            for expr in vec {
257                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
258                    n_consts += 1;
259                    acc = match acc {
260                        Some(i) => {
261                            if i < x {
262                                Some(x)
263                            } else {
264                                Some(i)
265                            }
266                        }
267                        None => Some(x),
268                    };
269                } else {
270                    new_vec.push(expr);
271                }
272            }
273
274            if let Some(i) = acc {
275                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
276            }
277
278            if n_consts <= 1 {
279                Err(RuleNotApplicable)
280            } else {
281                Ok(Reduction::pure(Expr::Max(
282                    m.clone(),
283                    Moo::new(into_matrix_expr![new_vec]),
284                )))
285            }
286        }
287        Expr::Not(_, _) => Err(RuleNotApplicable),
288        Expr::Or(m, e) => {
289            let Some(terms) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
290                return Err(RuleNotApplicable);
291            };
292
293            let mut has_changed = false;
294
295            // 2. boolean literals
296            let mut new_terms = vec![];
297            for expr in terms {
298                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
299                    has_changed = true;
300
301                    // true ~~> entire or is true
302                    // false ~~> remove false from the or
303                    if x {
304                        return Ok(Reduction::pure(true.into()));
305                    }
306                } else {
307                    new_terms.push(expr);
308                }
309            }
310
311            // 2. check pairwise tautologies.
312            if check_pairwise_or_tautologies(&new_terms) {
313                return Ok(Reduction::pure(true.into()));
314            }
315
316            // 3. empty or ~~> false
317            if new_terms.is_empty() {
318                return Ok(Reduction::pure(false.into()));
319            }
320
321            if !has_changed {
322                return Err(RuleNotApplicable);
323            }
324
325            Ok(Reduction::pure(Expr::Or(
326                m.clone(),
327                Moo::new(into_matrix_expr![new_terms]),
328            )))
329        }
330        Expr::And(_, e) => {
331            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
332                return Err(RuleNotApplicable);
333            };
334            let mut new_vec: Vec<Expr> = Vec::new();
335            let mut has_const: bool = false;
336            for expr in vec {
337                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
338                    has_const = true;
339                    if !x {
340                        return Ok(Reduction::pure(Expr::Atomic(
341                            Default::default(),
342                            Atom::Literal(Lit::Bool(false)),
343                        )));
344                    }
345                } else {
346                    new_vec.push(expr);
347                }
348            }
349
350            if !has_const {
351                Err(RuleNotApplicable)
352            } else {
353                Ok(Reduction::pure(Expr::And(
354                    Metadata::new(),
355                    Moo::new(into_matrix_expr![new_vec]),
356                )))
357            }
358        }
359
360        // similar to And, but booleans are returned wrapped in Root.
361        Expr::Root(_, es) => {
362            match es.as_slice() {
363                [] => Err(RuleNotApplicable),
364                // want to unwrap nested ands
365                [Expr::And(_, _)] => Ok(()),
366                // root([true]) / root([false]) are already evaluated
367                [_] => Err(RuleNotApplicable),
368                [_, _, ..] => Ok(()),
369            }?;
370
371            let mut new_vec: Vec<Expr> = Vec::new();
372            let mut has_changed: bool = false;
373            for expr in es {
374                match expr {
375                    Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) => {
376                        has_changed = true;
377                        if !x {
378                            // false
379                            return Ok(Reduction::pure(Expr::Root(
380                                Metadata::new(),
381                                vec![Expr::Atomic(
382                                    Default::default(),
383                                    Atom::Literal(Lit::Bool(false)),
384                                )],
385                            )));
386                        }
387                        // remove trues
388                    }
389
390                    // flatten ands in root
391                    Expr::And(_, vecs) => match Moo::unwrap_or_clone(vecs.clone()).unwrap_list() {
392                        Some(mut list) => {
393                            has_changed = true;
394                            new_vec.append(&mut list);
395                        }
396                        None => new_vec.push(expr.clone()),
397                    },
398                    _ => new_vec.push(expr.clone()),
399                }
400            }
401
402            if !has_changed {
403                Err(RuleNotApplicable)
404            } else {
405                if new_vec.is_empty() {
406                    new_vec.push(true.into());
407                }
408                Ok(Reduction::pure(Expr::Root(Metadata::new(), new_vec)))
409            }
410        }
411        Expr::Imply(_m, x, y) => {
412            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
413                if *x {
414                    // (true) -> y ~~> y
415                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
416                } else {
417                    // (false) -> y ~~> true
418                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
419                }
420            };
421
422            // reflexivity: p -> p ~> true
423
424            // instead of checking syntactic equivalence of a possibly deep expression,
425            // let identical-CSE turn them into identical variables first. Then, check if they are
426            // identical variables.
427
428            if x.identical_atom_to(y.as_ref()) {
429                return Ok(Reduction::pure(true.into()));
430            }
431
432            Err(RuleNotApplicable)
433        }
434        Expr::Iff(_m, x, y) => {
435            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
436                if *x {
437                    // (true) <-> y ~~> y
438                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
439                } else {
440                    // (false) <-> y ~~> !y
441                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), y.clone())));
442                }
443            };
444            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(y))) = y.as_ref() {
445                if *y {
446                    // x <-> (true) ~~> x
447                    return Ok(Reduction::pure(Moo::unwrap_or_clone(x.clone())));
448                } else {
449                    // x <-> (false) ~~> !x
450                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), x.clone())));
451                }
452            };
453
454            // reflexivity: p <-> p ~> true
455
456            // instead of checking syntactic equivalence of a possibly deep expression,
457            // let identical-CSE turn them into identical variables first. Then, check if they are
458            // identical variables.
459
460            if x.identical_atom_to(y.as_ref()) {
461                return Ok(Reduction::pure(true.into()));
462            }
463
464            Err(RuleNotApplicable)
465        }
466        Expr::Eq(_, _, _) => Err(RuleNotApplicable),
467        Expr::Neq(_, _, _) => Err(RuleNotApplicable),
468        Expr::Geq(_, _, _) => Err(RuleNotApplicable),
469        Expr::Leq(_, _, _) => Err(RuleNotApplicable),
470        Expr::Gt(_, _, _) => Err(RuleNotApplicable),
471        Expr::Lt(_, _, _) => Err(RuleNotApplicable),
472        Expr::SafeDiv(_, _, _) => Err(RuleNotApplicable),
473        Expr::UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
474        Expr::Flatten(_, _, _) => Err(RuleNotApplicable), // TODO: check if anything can be done here
475        Expr::AllDiff(m, e) => {
476            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
477                return Err(RuleNotApplicable);
478            };
479
480            let mut consts: HashSet<i32> = HashSet::new();
481
482            // check for duplicate constant values which would fail the constraint
483            for expr in vec {
484                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr
485                    && !consts.insert(x)
486                {
487                    return Ok(Reduction::pure(Expr::Atomic(
488                        m.clone(),
489                        Atom::Literal(Lit::Bool(false)),
490                    )));
491                }
492            }
493
494            // nothing has changed
495            Err(RuleNotApplicable)
496        }
497        Expr::Neg(_, _) => Err(RuleNotApplicable),
498        Expr::AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
499        Expr::UnsafeMod(_, _, _) => Err(RuleNotApplicable),
500        Expr::SafeMod(_, _, _) => Err(RuleNotApplicable),
501        Expr::UnsafePow(_, _, _) => Err(RuleNotApplicable),
502        Expr::SafePow(_, _, _) => Err(RuleNotApplicable),
503        Expr::Minus(_, _, _) => Err(RuleNotApplicable),
504
505        // As these are in a low level solver form, I'm assuming that these have already been
506        // simplified and partially evaluated.
507        Expr::FlatAllDiff(_, _) => Err(RuleNotApplicable),
508        Expr::FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
509        Expr::FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
510        Expr::FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
511        Expr::FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
512        Expr::FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
513        Expr::FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
514        Expr::FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
515        Expr::FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
516        Expr::FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
517        Expr::MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
518        Expr::MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
519        Expr::MinionPow(_, _, _, _) => Err(RuleNotApplicable),
520        Expr::MinionReify(_, _, _) => Err(RuleNotApplicable),
521        Expr::MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
522        Expr::MinionWInIntervalSet(_, _, _) => Err(RuleNotApplicable),
523        Expr::MinionWInSet(_, _, _) => Err(RuleNotApplicable),
524        Expr::MinionElementOne(_, _, _, _) => Err(RuleNotApplicable),
525        Expr::SATInt(_, _) => Err(RuleNotApplicable),
526        Expr::PairwiseSum(_, _, _) => Err(RuleNotApplicable),
527        Expr::PairwiseProduct(_, _, _) => Err(RuleNotApplicable),
528        Expr::Defined(_, _) => todo!(),
529        Expr::Range(_, _) => todo!(),
530        Expr::Image(_, _, _) => todo!(),
531        Expr::ImageSet(_, _, _) => todo!(),
532        Expr::PreImage(_, _, _) => todo!(),
533        Expr::Inverse(_, _, _) => todo!(),
534        Expr::Restrict(_, _, _) => todo!(),
535        Expr::LexLt(_, _, _) => Err(RuleNotApplicable),
536        Expr::LexLeq(_, _, _) => Err(RuleNotApplicable),
537        Expr::LexGt(_, _, _) => Err(RuleNotApplicable),
538        Expr::LexGeq(_, _, _) => Err(RuleNotApplicable),
539        Expr::FlatLexLt(_, _, _) => Err(RuleNotApplicable),
540        Expr::FlatLexLeq(_, _, _) => Err(RuleNotApplicable),
541    }
542}
543
544/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
545///
546/// This applies the following rules:
547///
548/// ```text
549/// (p->q) \/ (q->p) ~> true    [totality of implication]
550/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
551/// ```
552///
553fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
554    // Collect terms that are structurally identical to the rule input.
555    // Then, try the rules on these terms, also checking the other conditions of the rules.
556
557    // stores (p,q) in p -> q
558    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
559
560    // stores (p,q) in p -> !q
561    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
562
563    for term in or_terms.iter() {
564        if let Expr::Imply(_, p, q) = term {
565            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
566            //
567            // in general however, p -> !q would be in p_implies_q as (p,!q)
568            if let Expr::Not(_, q_1) = q.as_ref() {
569                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
570            } else {
571                p_implies_q.push((p.as_ref(), q.as_ref()));
572            }
573        }
574    }
575
576    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
577    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
578        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
579            return true;
580        }
581    }
582
583    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
584    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
585        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
586            return true;
587        }
588    }
589
590    false
591}