conjure_core/rules/
partial_eval.rs

1use std::collections::HashSet;
2
3use conjure_macros::register_rule;
4use itertools::iproduct;
5
6use crate::ast::SymbolTable;
7use crate::into_matrix_expr;
8use crate::rule_engine::{ApplicationResult, Reduction};
9use crate::{
10    ast::{Atom, Expression as Expr, Literal as Lit, Literal::*},
11    metadata::Metadata,
12};
13
14#[register_rule(("Base",9000))]
15fn partial_evaluator(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
16    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
17    use Expr::*;
18
19    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
20    // rule infinitely!
21    // This is why we always check whether we found a constant or not.
22    match expr.clone() {
23        AbstractLiteral(_, _) => Err(RuleNotApplicable),
24        Comprehension(_, _) => Err(RuleNotApplicable),
25        DominanceRelation(_, _) => Err(RuleNotApplicable),
26        FromSolution(_, _) => Err(RuleNotApplicable),
27        UnsafeIndex(_, _, _) => Err(RuleNotApplicable),
28        UnsafeSlice(_, _, _) => Err(RuleNotApplicable),
29        SafeIndex(_, _, _) => Err(RuleNotApplicable),
30        SafeSlice(_, _, _) => Err(RuleNotApplicable),
31        InDomain(_, _, _) => Err(RuleNotApplicable),
32        Bubble(_, _, _) => Err(RuleNotApplicable),
33        Atomic(_, _) => Err(RuleNotApplicable),
34        Scope(_, _) => Err(RuleNotApplicable),
35        Abs(m, e) => match *e {
36            Neg(_, inner) => Ok(Reduction::pure(Abs(m, inner))),
37            _ => Err(RuleNotApplicable),
38        },
39        Sum(m, vec) => {
40            let vec = vec.unwrap_list().ok_or(RuleNotApplicable)?;
41            let mut acc = 0;
42            let mut n_consts = 0;
43            let mut new_vec: Vec<Expr> = Vec::new();
44            for expr in vec {
45                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
46                    acc += x;
47                    n_consts += 1;
48                } else {
49                    new_vec.push(expr);
50                }
51            }
52            if acc != 0 {
53                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
54            }
55
56            if n_consts <= 1 {
57                Err(RuleNotApplicable)
58            } else {
59                Ok(Reduction::pure(Sum(
60                    m,
61                    Box::new(into_matrix_expr![new_vec]),
62                )))
63            }
64        }
65
66        Product(m, vec) => {
67            let mut acc = 1;
68            let mut n_consts = 0;
69            let mut new_vec: Vec<Expr> = Vec::new();
70            for expr in vec {
71                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
72                    acc *= x;
73                    n_consts += 1;
74                } else {
75                    new_vec.push(expr);
76                }
77            }
78
79            if n_consts == 0 {
80                return Err(RuleNotApplicable);
81            }
82
83            new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
84            let new_product = Product(m, new_vec);
85
86            if acc == 0 {
87                // if safe, 0 * exprs ~> 0
88                // otherwise, just return 0* exprs
89                if new_product.is_safe() {
90                    Ok(Reduction::pure(Expr::Atomic(
91                        Default::default(),
92                        Atom::Literal(Int(0)),
93                    )))
94                } else {
95                    Ok(Reduction::pure(new_product))
96                }
97            } else if n_consts == 1 {
98                // acc !=0, only one constant
99                Err(RuleNotApplicable)
100            } else {
101                // acc !=0, multiple constants found
102                Ok(Reduction::pure(new_product))
103            }
104        }
105
106        Min(m, e) => {
107            let Some(vec) = e.unwrap_list() else {
108                return Err(RuleNotApplicable);
109            };
110            let mut acc: Option<i32> = None;
111            let mut n_consts = 0;
112            let mut new_vec: Vec<Expr> = Vec::new();
113            for expr in vec {
114                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
115                    n_consts += 1;
116                    acc = match acc {
117                        Some(i) => {
118                            if i > x {
119                                Some(x)
120                            } else {
121                                Some(i)
122                            }
123                        }
124                        None => Some(x),
125                    };
126                } else {
127                    new_vec.push(expr);
128                }
129            }
130
131            if let Some(i) = acc {
132                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
133            }
134
135            if n_consts <= 1 {
136                Err(RuleNotApplicable)
137            } else {
138                Ok(Reduction::pure(Min(
139                    m,
140                    Box::new(into_matrix_expr![new_vec]),
141                )))
142            }
143        }
144
145        Max(m, e) => {
146            let Some(vec) = e.unwrap_list() else {
147                return Err(RuleNotApplicable);
148            };
149
150            let mut acc: Option<i32> = None;
151            let mut n_consts = 0;
152            let mut new_vec: Vec<Expr> = Vec::new();
153            for expr in vec {
154                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
155                    n_consts += 1;
156                    acc = match acc {
157                        Some(i) => {
158                            if i < x {
159                                Some(x)
160                            } else {
161                                Some(i)
162                            }
163                        }
164                        None => Some(x),
165                    };
166                } else {
167                    new_vec.push(expr);
168                }
169            }
170
171            if let Some(i) = acc {
172                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
173            }
174
175            if n_consts <= 1 {
176                Err(RuleNotApplicable)
177            } else {
178                Ok(Reduction::pure(Max(
179                    m,
180                    Box::new(into_matrix_expr![new_vec]),
181                )))
182            }
183        }
184        Not(_, _) => Err(RuleNotApplicable),
185        Or(m, e) => {
186            let Some(terms) = e.unwrap_list() else {
187                return Err(RuleNotApplicable);
188            };
189
190            let mut has_changed = false;
191
192            // 2. boolean literals
193            let mut new_terms = vec![];
194            for expr in terms {
195                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
196                    has_changed = true;
197
198                    // true ~~> entire or is true
199                    // false ~~> remove false from the or
200                    if x {
201                        return Ok(Reduction::pure(true.into()));
202                    }
203                } else {
204                    new_terms.push(expr);
205                }
206            }
207
208            // 2. check pairwise tautologies.
209            if check_pairwise_or_tautologies(&new_terms) {
210                return Ok(Reduction::pure(true.into()));
211            }
212
213            // 3. empty or ~~> false
214            if new_terms.is_empty() {
215                return Ok(Reduction::pure(false.into()));
216            }
217
218            if !has_changed {
219                return Err(RuleNotApplicable);
220            }
221
222            Ok(Reduction::pure(Or(
223                m,
224                Box::new(into_matrix_expr![new_terms]),
225            )))
226        }
227        And(_, e) => {
228            let Some(vec) = e.unwrap_list() else {
229                return Err(RuleNotApplicable);
230            };
231            let mut new_vec: Vec<Expr> = Vec::new();
232            let mut has_const: bool = false;
233            for expr in vec {
234                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
235                    has_const = true;
236                    if !x {
237                        return Ok(Reduction::pure(Atomic(
238                            Default::default(),
239                            Atom::Literal(Bool(false)),
240                        )));
241                    }
242                } else {
243                    new_vec.push(expr);
244                }
245            }
246
247            if !has_const {
248                Err(RuleNotApplicable)
249            } else {
250                Ok(Reduction::pure(Expr::And(
251                    Metadata::new(),
252                    Box::new(into_matrix_expr![new_vec]),
253                )))
254            }
255        }
256
257        // similar to And, but booleans are returned wrapped in Root.
258        Root(_, es) => {
259            match es.as_slice() {
260                [] => Err(RuleNotApplicable),
261                // want to unwrap nested ands
262                [Expr::And(_, _)] => Ok(()),
263                // root([true]) / root([false]) are already evaluated
264                [_] => Err(RuleNotApplicable),
265                [_, _, ..] => Ok(()),
266            }?;
267
268            let mut new_vec: Vec<Expr> = Vec::new();
269            let mut has_changed: bool = false;
270            for expr in es {
271                match expr {
272                    Expr::Atomic(_, Atom::Literal(Bool(x))) => {
273                        has_changed = true;
274                        if !x {
275                            // false
276                            return Ok(Reduction::pure(Root(
277                                Metadata::new(),
278                                vec![Atomic(Default::default(), Atom::Literal(Bool(false)))],
279                            )));
280                        }
281                        // remove trues
282                    }
283
284                    // flatten ands in root
285                    Expr::And(_, ref vecs) => match vecs.clone().unwrap_list() {
286                        Some(mut list) => {
287                            has_changed = true;
288                            new_vec.append(&mut list);
289                        }
290                        None => new_vec.push(expr),
291                    },
292                    _ => new_vec.push(expr),
293                }
294            }
295
296            if !has_changed {
297                Err(RuleNotApplicable)
298            } else {
299                if new_vec.is_empty() {
300                    new_vec.push(true.into());
301                }
302                Ok(Reduction::pure(Expr::Root(Metadata::new(), new_vec)))
303            }
304        }
305        Imply(_m, x, y) => {
306            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = *x {
307                if x {
308                    // (true) -> y ~~> y
309                    return Ok(Reduction::pure(*y));
310                } else {
311                    // (false) -> y ~~> true
312                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
313                }
314            };
315
316            // reflexivity: p -> p ~> true
317
318            // instead of checking syntactic equivalence of a possibly deep expression,
319            // let identical-CSE turn them into identical variables first. Then, check if they are
320            // identical variables.
321
322            if x.identical_atom_to(y.as_ref()) {
323                return Ok(Reduction::pure(true.into()));
324            }
325
326            Err(RuleNotApplicable)
327        }
328        Eq(_, _, _) => Err(RuleNotApplicable),
329        Neq(_, _, _) => Err(RuleNotApplicable),
330        Geq(_, _, _) => Err(RuleNotApplicable),
331        Leq(_, _, _) => Err(RuleNotApplicable),
332        Gt(_, _, _) => Err(RuleNotApplicable),
333        Lt(_, _, _) => Err(RuleNotApplicable),
334        SafeDiv(_, _, _) => Err(RuleNotApplicable),
335        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
336        AllDiff(m, e) => {
337            let Some(vec) = e.unwrap_list() else {
338                return Err(RuleNotApplicable);
339            };
340
341            let mut consts: HashSet<i32> = HashSet::new();
342
343            // check for duplicate constant values which would fail the constraint
344            for expr in vec {
345                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
346                    if !consts.insert(x) {
347                        return Ok(Reduction::pure(Expr::Atomic(m, Atom::Literal(Bool(false)))));
348                    }
349                }
350            }
351
352            // nothing has changed
353            Err(RuleNotApplicable)
354        }
355        Neg(_, _) => Err(RuleNotApplicable),
356        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
357        UnsafeMod(_, _, _) => Err(RuleNotApplicable),
358        SafeMod(_, _, _) => Err(RuleNotApplicable),
359        UnsafePow(_, _, _) => Err(RuleNotApplicable),
360        SafePow(_, _, _) => Err(RuleNotApplicable),
361        Minus(_, _, _) => Err(RuleNotApplicable),
362
363        // As these are in a low level solver form, I'm assuming that these have already been
364        // simplified and partially evaluated.
365        FlatAllDiff(_, _) => Err(RuleNotApplicable),
366        FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
367        FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
368        FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
369        FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
370        FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
371        FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
372        FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
373        FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
374        FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
375        MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
376        MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
377        MinionPow(_, _, _, _) => Err(RuleNotApplicable),
378        MinionReify(_, _, _) => Err(RuleNotApplicable),
379        MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
380        MinionWInIntervalSet(_, _, _) => Err(RuleNotApplicable),
381        MinionElementOne(_, _, _, _) => Err(RuleNotApplicable),
382    }
383}
384
385/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
386///
387/// This applies the following rules:
388///
389/// ```text
390/// (p->q) \/ (q->p) ~> true    [totality of implication]
391/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
392/// ```
393///
394fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
395    // Collect terms that are structurally identical to the rule input.
396    // Then, try the rules on these terms, also checking the other conditions of the rules.
397
398    // stores (p,q) in p -> q
399    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
400
401    // stores (p,q) in p -> !q
402    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
403
404    for term in or_terms.iter() {
405        if let Expr::Imply(_, p, q) = term {
406            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
407            //
408            // in general however, p -> !q would be in p_implies_q as (p,!q)
409            if let Expr::Not(_, q_1) = q.as_ref() {
410                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
411            } else {
412                p_implies_q.push((p.as_ref(), q.as_ref()));
413            }
414        }
415    }
416
417    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
418    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
419        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
420            return true;
421        }
422    }
423
424    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
425    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
426        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
427            return true;
428        }
429    }
430
431    false
432}