Skip to main content

conjure_cp_core/ast/
partial_eval.rs

1use std::collections::HashSet;
2
3use crate::ast::Typeable;
4use crate::{
5    ast::{Atom, Expression as Expr, Literal as Lit, Metadata, Moo, ReturnType},
6    into_matrix_expr,
7    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
8};
9use itertools::iproduct;
10
11pub fn run_partial_evaluator(expr: &Expr) -> ApplicationResult {
12    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
13    // rule infinitely!
14    // This is why we always check whether we found a constant or not.
15    match expr {
16        Expr::Union(_, _, _) => Err(RuleNotApplicable),
17        Expr::In(_, _, _) => Err(RuleNotApplicable),
18        Expr::Intersect(_, _, _) => Err(RuleNotApplicable),
19        Expr::Supset(_, _, _) => Err(RuleNotApplicable),
20        Expr::SupsetEq(_, _, _) => Err(RuleNotApplicable),
21        Expr::Subset(_, _, _) => Err(RuleNotApplicable),
22        Expr::SubsetEq(_, _, _) => Err(RuleNotApplicable),
23        Expr::AbstractLiteral(_, _) => Err(RuleNotApplicable),
24        Expr::Comprehension(_, _) => Err(RuleNotApplicable),
25        Expr::AbstractComprehension(_, _) => Err(RuleNotApplicable),
26        Expr::DominanceRelation(_, _) => Err(RuleNotApplicable),
27        Expr::FromSolution(_, _) => Err(RuleNotApplicable),
28        Expr::Metavar(_, _) => Err(RuleNotApplicable),
29        Expr::UnsafeIndex(_, _, _) => Err(RuleNotApplicable),
30        Expr::UnsafeSlice(_, _, _) => Err(RuleNotApplicable),
31        Expr::SafeIndex(_, subject, indices) => {
32            // partially evaluate matrix literals indexed by a constant.
33
34            // subject must be a matrix literal
35            let (es, index_domain) = Moo::unwrap_or_clone(subject.clone())
36                .unwrap_matrix_unchecked()
37                .ok_or(RuleNotApplicable)?;
38
39            // must be indexing a 1d matrix.
40            //
41            // for n-d matrices, wait for the `remove_dimension_from_matrix_indexing` rule to run
42            // first. This reduces n-d indexing operations to 1d.
43            if indices.len() != 1 {
44                return Err(RuleNotApplicable);
45            }
46
47            // the index must be a number
48            let index: i32 = (&indices[0]).try_into().map_err(|_| RuleNotApplicable)?;
49
50            // index domain must be a single integer range with a lower bound
51            if let Some(ranges) = index_domain.as_int_ground()
52                && ranges.len() == 1
53                && let Some(from) = ranges[0].low()
54            {
55                let zero_indexed_index = index - from;
56                Ok(Reduction::pure(es[zero_indexed_index as usize].clone()))
57            } else {
58                Err(RuleNotApplicable)
59            }
60        }
61        Expr::SafeSlice(_, _, _) => Err(RuleNotApplicable),
62        Expr::InDomain(_, x, domain) => {
63            if let Expr::Atomic(_, Atom::Reference(decl)) = x.as_ref() {
64                let decl_domain = decl
65                    .domain()
66                    .ok_or(RuleNotApplicable)?
67                    .resolve()
68                    .ok_or(RuleNotApplicable)?;
69                let domain = domain.resolve().ok_or(RuleNotApplicable)?;
70
71                let intersection = decl_domain
72                    .intersect(&domain)
73                    .map_err(|_| RuleNotApplicable)?;
74
75                // if the declaration's domain is a subset of domain, expr is always true.
76                if &intersection == decl_domain.as_ref() {
77                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
78                }
79                // if no elements of declaration's domain are in the domain (i.e. they have no
80                // intersection), expr is always false.
81                //
82                // Only check this when the intersection is a finite integer domain, as we
83                // currently don't have a way to check whether other domain kinds are empty or not.
84                //
85                // we should expand this to cover more domain types in the future.
86                else if let Ok(values_in_domain) = intersection.values_i32()
87                    && values_in_domain.is_empty()
88                {
89                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
90                } else {
91                    Err(RuleNotApplicable)
92                }
93            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref() {
94                if domain
95                    .resolve()
96                    .ok_or(RuleNotApplicable)?
97                    .contains(lit)
98                    .ok()
99                    .ok_or(RuleNotApplicable)?
100                {
101                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
102                } else {
103                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
104                }
105            } else {
106                Err(RuleNotApplicable)
107            }
108        }
109        Expr::Bubble(_, expr, cond) => {
110            // definition of bubble is "expr is valid as long as cond is true"
111            //
112            // check if cond is true and pop the bubble!
113            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = cond.as_ref() {
114                Ok(Reduction::pure(Moo::unwrap_or_clone(expr.clone())))
115            } else {
116                Err(RuleNotApplicable)
117            }
118        }
119        Expr::Atomic(_, _) => Err(RuleNotApplicable),
120        Expr::Scope(_, _) => Err(RuleNotApplicable),
121        Expr::ToInt(_, expression) => {
122            if expression.return_type() == ReturnType::Int {
123                Ok(Reduction::pure(Moo::unwrap_or_clone(expression.clone())))
124            } else {
125                Err(RuleNotApplicable)
126            }
127        }
128        Expr::Abs(m, e) => match e.as_ref() {
129            Expr::Neg(_, inner) => Ok(Reduction::pure(Expr::Abs(m.clone(), inner.clone()))),
130            _ => Err(RuleNotApplicable),
131        },
132        Expr::Sum(m, vec) => {
133            let vec = Moo::unwrap_or_clone(vec.clone())
134                .unwrap_list()
135                .ok_or(RuleNotApplicable)?;
136            let mut acc = 0;
137            let mut n_consts = 0;
138            let mut new_vec: Vec<Expr> = Vec::new();
139            for expr in vec {
140                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
141                    acc += x;
142                    n_consts += 1;
143                } else {
144                    new_vec.push(expr);
145                }
146            }
147            if acc != 0 {
148                new_vec.push(Expr::Atomic(
149                    Default::default(),
150                    Atom::Literal(Lit::Int(acc)),
151                ));
152            }
153
154            if n_consts <= 1 {
155                Err(RuleNotApplicable)
156            } else {
157                Ok(Reduction::pure(Expr::Sum(
158                    m.clone(),
159                    Moo::new(into_matrix_expr![new_vec]),
160                )))
161            }
162        }
163
164        Expr::Product(m, vec) => {
165            let mut acc = 1;
166            let mut n_consts = 0;
167            let mut new_vec: Vec<Expr> = Vec::new();
168            let vec = Moo::unwrap_or_clone(vec.clone())
169                .unwrap_list()
170                .ok_or(RuleNotApplicable)?;
171            for expr in vec {
172                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
173                    acc *= x;
174                    n_consts += 1;
175                } else {
176                    new_vec.push(expr);
177                }
178            }
179
180            if n_consts == 0 {
181                return Err(RuleNotApplicable);
182            }
183
184            new_vec.push(Expr::Atomic(
185                Default::default(),
186                Atom::Literal(Lit::Int(acc)),
187            ));
188            let new_product = Expr::Product(m.clone(), Moo::new(into_matrix_expr![new_vec]));
189
190            if acc == 0 {
191                // if safe, 0 * exprs ~> 0
192                // otherwise, just return 0* exprs
193                if new_product.is_safe() {
194                    Ok(Reduction::pure(Expr::Atomic(
195                        Default::default(),
196                        Atom::Literal(Lit::Int(0)),
197                    )))
198                } else {
199                    Ok(Reduction::pure(new_product))
200                }
201            } else if n_consts == 1 {
202                // acc !=0, only one constant
203                Err(RuleNotApplicable)
204            } else {
205                // acc !=0, multiple constants found
206                Ok(Reduction::pure(new_product))
207            }
208        }
209
210        Expr::Min(m, e) => {
211            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
212                return Err(RuleNotApplicable);
213            };
214            let mut acc: Option<i32> = None;
215            let mut n_consts = 0;
216            let mut new_vec: Vec<Expr> = Vec::new();
217            for expr in vec {
218                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
219                    n_consts += 1;
220                    acc = match acc {
221                        Some(i) => {
222                            if i > x {
223                                Some(x)
224                            } else {
225                                Some(i)
226                            }
227                        }
228                        None => Some(x),
229                    };
230                } else {
231                    new_vec.push(expr);
232                }
233            }
234
235            if let Some(i) = acc {
236                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
237            }
238
239            if n_consts <= 1 {
240                Err(RuleNotApplicable)
241            } else {
242                Ok(Reduction::pure(Expr::Min(
243                    m.clone(),
244                    Moo::new(into_matrix_expr![new_vec]),
245                )))
246            }
247        }
248
249        Expr::Max(m, e) => {
250            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
251                return Err(RuleNotApplicable);
252            };
253
254            let mut acc: Option<i32> = None;
255            let mut n_consts = 0;
256            let mut new_vec: Vec<Expr> = Vec::new();
257            for expr in vec {
258                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
259                    n_consts += 1;
260                    acc = match acc {
261                        Some(i) => {
262                            if i < x {
263                                Some(x)
264                            } else {
265                                Some(i)
266                            }
267                        }
268                        None => Some(x),
269                    };
270                } else {
271                    new_vec.push(expr);
272                }
273            }
274
275            if let Some(i) = acc {
276                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
277            }
278
279            if n_consts <= 1 {
280                Err(RuleNotApplicable)
281            } else {
282                Ok(Reduction::pure(Expr::Max(
283                    m.clone(),
284                    Moo::new(into_matrix_expr![new_vec]),
285                )))
286            }
287        }
288        Expr::Not(_, _) => Err(RuleNotApplicable),
289        Expr::Or(m, e) => {
290            let Some(terms) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
291                return Err(RuleNotApplicable);
292            };
293
294            let mut has_changed = false;
295
296            // 2. boolean literals
297            let mut new_terms = vec![];
298            for expr in terms {
299                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
300                    has_changed = true;
301
302                    // true ~~> entire or is true
303                    // false ~~> remove false from the or
304                    if x {
305                        return Ok(Reduction::pure(true.into()));
306                    }
307                } else {
308                    new_terms.push(expr);
309                }
310            }
311
312            // 2. check pairwise tautologies.
313            if check_pairwise_or_tautologies(&new_terms) {
314                return Ok(Reduction::pure(true.into()));
315            }
316
317            // 3. empty or ~~> false
318            if new_terms.is_empty() {
319                return Ok(Reduction::pure(false.into()));
320            }
321
322            if !has_changed {
323                return Err(RuleNotApplicable);
324            }
325
326            Ok(Reduction::pure(Expr::Or(
327                m.clone(),
328                Moo::new(into_matrix_expr![new_terms]),
329            )))
330        }
331        Expr::And(_, e) => {
332            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
333                return Err(RuleNotApplicable);
334            };
335            let mut new_vec: Vec<Expr> = Vec::new();
336            let mut has_const: bool = false;
337            for expr in vec {
338                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
339                    has_const = true;
340                    if !x {
341                        return Ok(Reduction::pure(Expr::Atomic(
342                            Default::default(),
343                            Atom::Literal(Lit::Bool(false)),
344                        )));
345                    }
346                } else {
347                    new_vec.push(expr);
348                }
349            }
350
351            if !has_const {
352                Err(RuleNotApplicable)
353            } else {
354                Ok(Reduction::pure(Expr::And(
355                    Metadata::new(),
356                    Moo::new(into_matrix_expr![new_vec]),
357                )))
358            }
359        }
360
361        // similar to And, but booleans are returned wrapped in Root.
362        Expr::Root(_, es) => {
363            match es.as_slice() {
364                [] => Err(RuleNotApplicable),
365                // want to unwrap nested ands
366                [Expr::And(_, _)] => Ok(()),
367                // root([true]) / root([false]) are already evaluated
368                [_] => Err(RuleNotApplicable),
369                [_, _, ..] => Ok(()),
370            }?;
371
372            let mut new_vec: Vec<Expr> = Vec::new();
373            let mut has_changed: bool = false;
374            for expr in es {
375                match expr {
376                    Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) => {
377                        has_changed = true;
378                        if !x {
379                            // false
380                            return Ok(Reduction::pure(Expr::Root(
381                                Metadata::new(),
382                                vec![Expr::Atomic(
383                                    Default::default(),
384                                    Atom::Literal(Lit::Bool(false)),
385                                )],
386                            )));
387                        }
388                        // remove trues
389                    }
390
391                    // flatten ands in root
392                    Expr::And(_, vecs) => match Moo::unwrap_or_clone(vecs.clone()).unwrap_list() {
393                        Some(mut list) => {
394                            has_changed = true;
395                            new_vec.append(&mut list);
396                        }
397                        None => new_vec.push(expr.clone()),
398                    },
399                    _ => new_vec.push(expr.clone()),
400                }
401            }
402
403            if !has_changed {
404                Err(RuleNotApplicable)
405            } else {
406                if new_vec.is_empty() {
407                    new_vec.push(true.into());
408                }
409                Ok(Reduction::pure(Expr::Root(Metadata::new(), new_vec)))
410            }
411        }
412        Expr::Imply(_m, x, y) => {
413            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
414                if *x {
415                    // (true) -> y ~~> y
416                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
417                } else {
418                    // (false) -> y ~~> true
419                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
420                }
421            };
422
423            // reflexivity: p -> p ~> true
424
425            // instead of checking syntactic equivalence of a possibly deep expression,
426            // let identical-CSE turn them into identical variables first. Then, check if they are
427            // identical variables.
428
429            if x.identical_atom_to(y.as_ref()) {
430                return Ok(Reduction::pure(true.into()));
431            }
432
433            Err(RuleNotApplicable)
434        }
435        Expr::Iff(_m, x, y) => {
436            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
437                if *x {
438                    // (true) <-> y ~~> y
439                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
440                } else {
441                    // (false) <-> y ~~> !y
442                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), y.clone())));
443                }
444            };
445            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(y))) = y.as_ref() {
446                if *y {
447                    // x <-> (true) ~~> x
448                    return Ok(Reduction::pure(Moo::unwrap_or_clone(x.clone())));
449                } else {
450                    // x <-> (false) ~~> !x
451                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), x.clone())));
452                }
453            };
454
455            // reflexivity: p <-> p ~> true
456
457            // instead of checking syntactic equivalence of a possibly deep expression,
458            // let identical-CSE turn them into identical variables first. Then, check if they are
459            // identical variables.
460
461            if x.identical_atom_to(y.as_ref()) {
462                return Ok(Reduction::pure(true.into()));
463            }
464
465            Err(RuleNotApplicable)
466        }
467        Expr::Eq(_, _, _) => Err(RuleNotApplicable),
468        Expr::Neq(_, _, _) => Err(RuleNotApplicable),
469        Expr::Geq(_, _, _) => Err(RuleNotApplicable),
470        Expr::Leq(_, _, _) => Err(RuleNotApplicable),
471        Expr::Gt(_, _, _) => Err(RuleNotApplicable),
472        Expr::Lt(_, _, _) => Err(RuleNotApplicable),
473        Expr::SafeDiv(_, _, _) => Err(RuleNotApplicable),
474        Expr::UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
475        Expr::Flatten(_, _, _) => Err(RuleNotApplicable), // TODO: check if anything can be done here
476        Expr::AllDiff(m, e) => {
477            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
478                return Err(RuleNotApplicable);
479            };
480
481            let mut consts: HashSet<i32> = HashSet::new();
482
483            // check for duplicate constant values which would fail the constraint
484            for expr in vec {
485                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr
486                    && !consts.insert(x)
487                {
488                    return Ok(Reduction::pure(Expr::Atomic(
489                        m.clone(),
490                        Atom::Literal(Lit::Bool(false)),
491                    )));
492                }
493            }
494
495            // nothing has changed
496            Err(RuleNotApplicable)
497        }
498        Expr::Neg(_, _) => Err(RuleNotApplicable),
499        Expr::AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
500        Expr::UnsafeMod(_, _, _) => Err(RuleNotApplicable),
501        Expr::SafeMod(_, _, _) => Err(RuleNotApplicable),
502        Expr::UnsafePow(_, _, _) => Err(RuleNotApplicable),
503        Expr::SafePow(_, _, _) => Err(RuleNotApplicable),
504        Expr::Minus(_, _, _) => Err(RuleNotApplicable),
505
506        // As these are in a low level solver form, I'm assuming that these have already been
507        // simplified and partially evaluated.
508        Expr::FlatAllDiff(_, _) => Err(RuleNotApplicable),
509        Expr::FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
510        Expr::FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
511        Expr::FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
512        Expr::FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
513        Expr::FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
514        Expr::FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
515        Expr::FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
516        Expr::FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
517        Expr::FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
518        Expr::MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
519        Expr::MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
520        Expr::MinionPow(_, _, _, _) => Err(RuleNotApplicable),
521        Expr::MinionReify(_, _, _) => Err(RuleNotApplicable),
522        Expr::MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
523        Expr::MinionWInIntervalSet(_, _, _) => Err(RuleNotApplicable),
524        Expr::MinionWInSet(_, _, _) => Err(RuleNotApplicable),
525        Expr::MinionElementOne(_, _, _, _) => Err(RuleNotApplicable),
526        Expr::SATInt(_, _, _, _) => Err(RuleNotApplicable),
527        Expr::PairwiseSum(_, _, _) => Err(RuleNotApplicable),
528        Expr::PairwiseProduct(_, _, _) => Err(RuleNotApplicable),
529        Expr::Defined(_, _) => todo!(),
530        Expr::Range(_, _) => todo!(),
531        Expr::Image(_, _, _) => todo!(),
532        Expr::ImageSet(_, _, _) => todo!(),
533        Expr::PreImage(_, _, _) => todo!(),
534        Expr::Inverse(_, _, _) => todo!(),
535        Expr::Restrict(_, _, _) => todo!(),
536        Expr::LexLt(_, _, _) => Err(RuleNotApplicable),
537        Expr::LexLeq(_, _, _) => Err(RuleNotApplicable),
538        Expr::LexGt(_, _, _) => Err(RuleNotApplicable),
539        Expr::LexGeq(_, _, _) => Err(RuleNotApplicable),
540        Expr::FlatLexLt(_, _, _) => Err(RuleNotApplicable),
541        Expr::FlatLexLeq(_, _, _) => Err(RuleNotApplicable),
542    }
543}
544
545/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
546///
547/// This applies the following rules:
548///
549/// ```text
550/// (p->q) \/ (q->p) ~> true    [totality of implication]
551/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
552/// ```
553///
554fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
555    // Collect terms that are structurally identical to the rule input.
556    // Then, try the rules on these terms, also checking the other conditions of the rules.
557
558    // stores (p,q) in p -> q
559    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
560
561    // stores (p,q) in p -> !q
562    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
563
564    for term in or_terms.iter() {
565        if let Expr::Imply(_, p, q) = term {
566            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
567            //
568            // in general however, p -> !q would be in p_implies_q as (p,!q)
569            if let Expr::Not(_, q_1) = q.as_ref() {
570                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
571            } else {
572                p_implies_q.push((p.as_ref(), q.as_ref()));
573            }
574        }
575    }
576
577    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
578    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
579        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
580            return true;
581        }
582    }
583
584    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
585    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
586        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
587            return true;
588        }
589    }
590
591    false
592}