conjure_core/rules/constant_eval.rs

1#![allow(dead_code)]
2use std::collections::HashSet;
3
4use conjure_core::ast::{Atom, Expression as Expr, Literal as Lit};
5use conjure_core::metadata::Metadata;
6use conjure_core::rule_engine::{
7    register_rule, register_rule_set, ApplicationError, ApplicationError::RuleNotApplicable,
8    ApplicationResult, Reduction,
9};
10use itertools::izip;
11
12use crate::ast::SymbolTable;
13
// Declare the "Constant" rule set. Rules in this file join it through
// `#[register_rule(("Constant", <priority>))]`.
// NOTE(review): the `()` argument presumably lists the rule sets this one depends on (none) —
// confirm against the `register_rule_set!` definition.
register_rule_set!("Constant", ());
15
16#[register_rule(("Constant", 9001))]
17fn apply_eval_constant(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
18    if let Expr::Atomic(_, Atom::Literal(_)) = expr {
19        return Err(ApplicationError::RuleNotApplicable);
20    }
21    eval_constant(expr)
22        .map(|c| Reduction::pure(Expr::Atomic(Metadata::new(), Atom::Literal(c))))
23        .ok_or(ApplicationError::RuleNotApplicable)
24}
25
26/// Simplify an expression to a constant if possible
27/// Returns:
28/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
29/// `Some(Const)` if the expression can be simplified to a constant
30pub fn eval_constant(expr: &Expr) -> Option<Lit> {
31    match expr {
32        Expr::AbstractLiteral(_, _) => None,
33        // `fromSolution()` pulls a literal value from last found solution
34        Expr::FromSolution(_, _) => None,
35        // Same as Expr::Root, we should not replace the dominance relation with a constant
36        Expr::DominanceRelation(_, _) => None,
37        Expr::UnsafeIndex(_, _, _) => None,
38        // handled elsewhere
39        Expr::SafeIndex(_, _, _) => None,
40        Expr::UnsafeSlice(_, _, _) => None,
41        // handled elsewhere
42        Expr::SafeSlice(_, _, _) => None,
43        Expr::InDomain(_, e, domain) => {
44            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
45                return None;
46            };
47
48            domain.contains(lit).map(Into::into)
49        }
50        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
51        Expr::Atomic(_, Atom::Reference(_c)) => None,
52        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
53        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
54            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
55            .map(Lit::Bool),
56        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
57        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
58        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
59        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
60        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
61
62        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
63
64        Expr::And(_, e) => {
65            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
66        }
67        // this is done elsewhere instead - root should return a new root with a literal inside it,
68        // not a literal
69        Expr::Root(_, _) => None,
70        Expr::Or(_, e) => {
71            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), e.as_ref()).map(Lit::Bool)
72        }
73        Expr::Imply(_, box1, box2) => {
74            let a: &Atom = (&**box1).try_into().ok()?;
75            let b: &Atom = (&**box2).try_into().ok()?;
76
77            let a: bool = a.try_into().ok()?;
78            let b: bool = b.try_into().ok()?;
79
80            if a {
81                // true -> b ~> b
82                Some(Lit::Bool(b))
83            } else {
84                // false -> b ~> true
85                Some(Lit::Bool(true))
86            }
87        }
88
89        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
90        Expr::Product(_, exprs) => vec_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int),
91
92        Expr::FlatIneq(_, a, b, c) => {
93            let a: i32 = a.try_into().ok()?;
94            let b: i32 = b.try_into().ok()?;
95            let c: i32 = c.try_into().ok()?;
96
97            Some(Lit::Bool(a <= b + c))
98        }
99
100        Expr::FlatSumGeq(_, exprs, a) => {
101            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
102                let n: i32 = atom.try_into().ok()?;
103                let acc = acc + n;
104                Some(acc)
105            })?;
106
107            Some(Lit::Bool(sum >= a.try_into().ok()?))
108        }
109        Expr::FlatSumLeq(_, exprs, a) => {
110            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
111                let n: i32 = atom.try_into().ok()?;
112                let acc = acc + n;
113                Some(acc)
114            })?;
115
116            Some(Lit::Bool(sum >= a.try_into().ok()?))
117        }
118        Expr::Min(_, e) => {
119            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
120        }
121        Expr::Max(_, e) => {
122            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
123        }
124        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
125            if unwrap_expr::<i32>(b)? == 0 {
126                return None;
127            }
128            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
129        }
130        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
131            if unwrap_expr::<i32>(b)? == 0 {
132                return None;
133            }
134            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
135                .map(Lit::Int)
136        }
137        Expr::MinionDivEqUndefZero(_, a, b, c) => {
138            // div always rounds down
139            let a: i32 = a.try_into().ok()?;
140            let b: i32 = b.try_into().ok()?;
141            let c: i32 = c.try_into().ok()?;
142
143            if b == 0 {
144                return None;
145            }
146
147            let a = a as f32;
148            let b = b as f32;
149            let div: i32 = (a / b).floor() as i32;
150            Some(Lit::Bool(div == c))
151        }
152        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
153
154        Expr::MinionReify(_, a, b) => {
155            let result = eval_constant(a)?;
156
157            let result: bool = result.try_into().ok()?;
158            let b: bool = b.try_into().ok()?;
159
160            Some(Lit::Bool(b == result))
161        }
162
163        Expr::MinionReifyImply(_, a, b) => {
164            let result = eval_constant(a)?;
165
166            let result: bool = result.try_into().ok()?;
167            let b: bool = b.try_into().ok()?;
168
169            if b {
170                Some(Lit::Bool(result))
171            } else {
172                Some(Lit::Bool(true))
173            }
174        }
175        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
176            // From Savile Row. Same semantics as division.
177            //
178            //   a - (b * floor(a/b))
179            //
180            // We don't use % as it has the same semantics as /. We don't use / as we want to round
181            // down instead, not towards zero.
182
183            let a: i32 = a.try_into().ok()?;
184            let b: i32 = b.try_into().ok()?;
185            let c: i32 = c.try_into().ok()?;
186
187            if b == 0 {
188                return None;
189            }
190
191            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
192            Some(Lit::Bool(modulo == c))
193        }
194
195        Expr::MinionPow(_, a, b, c) => {
196            // only available for positive a b c
197
198            let a: i32 = a.try_into().ok()?;
199            let b: i32 = b.try_into().ok()?;
200            let c: i32 = c.try_into().ok()?;
201
202            if a <= 0 {
203                return None;
204            }
205
206            if b <= 0 {
207                return None;
208            }
209
210            if c <= 0 {
211                return None;
212            }
213
214            Some(Lit::Bool(a ^ b == c))
215        }
216
217        Expr::AllDiff(_, e) => {
218            let es = e.clone().unwrap_list()?;
219            let mut lits: HashSet<Lit> = HashSet::new();
220            for expr in es {
221                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
222                    return None;
223                };
224                match x {
225                    Lit::Int(_) | Lit::Bool(_) => {
226                        if lits.contains(&x) {
227                            return Some(Lit::Bool(false));
228                        } else {
229                            lits.insert(x.clone());
230                        }
231                    }
232                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
233                }
234            }
235            Some(Lit::Bool(true))
236        }
237        Expr::FlatAllDiff(_, es) => {
238            let mut lits: HashSet<Lit> = HashSet::new();
239            for atom in es {
240                let Atom::Literal(x) = atom else {
241                    return None;
242                };
243
244                match x {
245                    Lit::Int(_) | Lit::Bool(_) => {
246                        if lits.contains(x) {
247                            return Some(Lit::Bool(false));
248                        } else {
249                            lits.insert(x.clone());
250                        }
251                    }
252                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
253                }
254            }
255            Some(Lit::Bool(true))
256        }
257        Expr::FlatWatchedLiteral(_, _, _) => None,
258        Expr::AuxDeclaration(_, _, _) => None,
259        Expr::Neg(_, a) => {
260            let a: &Atom = a.try_into().ok()?;
261            let a: i32 = a.try_into().ok()?;
262            Some(Lit::Int(-a))
263        }
264        Expr::Minus(_, a, b) => {
265            let a: &Atom = a.try_into().ok()?;
266            let a: i32 = a.try_into().ok()?;
267
268            let b: &Atom = b.try_into().ok()?;
269            let b: i32 = b.try_into().ok()?;
270
271            Some(Lit::Int(a - b))
272        }
273        Expr::FlatMinusEq(_, a, b) => {
274            let a: i32 = a.try_into().ok()?;
275            let b: i32 = b.try_into().ok()?;
276            Some(Lit::Bool(a == -b))
277        }
278        Expr::FlatProductEq(_, a, b, c) => {
279            let a: i32 = a.try_into().ok()?;
280            let b: i32 = b.try_into().ok()?;
281            let c: i32 = c.try_into().ok()?;
282            Some(Lit::Bool(a * b == c))
283        }
284        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
285            let cs: Vec<i32> = cs
286                .iter()
287                .map(|x| TryInto::<i32>::try_into(x).ok())
288                .collect::<Option<Vec<i32>>>()?;
289            let vs: Vec<i32> = vs
290                .iter()
291                .map(|x| TryInto::<i32>::try_into(x).ok())
292                .collect::<Option<Vec<i32>>>()?;
293            let total: i32 = total.try_into().ok()?;
294
295            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
296
297            Some(Lit::Bool(sum <= total))
298        }
299
300        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
301            let cs: Vec<i32> = cs
302                .iter()
303                .map(|x| TryInto::<i32>::try_into(x).ok())
304                .collect::<Option<Vec<i32>>>()?;
305            let vs: Vec<i32> = vs
306                .iter()
307                .map(|x| TryInto::<i32>::try_into(x).ok())
308                .collect::<Option<Vec<i32>>>()?;
309            let total: i32 = total.try_into().ok()?;
310
311            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
312
313            Some(Lit::Bool(sum >= total))
314        }
315        Expr::FlatAbsEq(_, x, y) => {
316            let x: i32 = x.try_into().ok()?;
317            let y: i32 = y.try_into().ok()?;
318
319            Some(Lit::Bool(x == y.abs()))
320        }
321
322        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
323            let a: &Atom = a.try_into().ok()?;
324            let a: i32 = a.try_into().ok()?;
325
326            let b: &Atom = b.try_into().ok()?;
327            let b: i32 = b.try_into().ok()?;
328
329            if (a != 0 || b != 0) && b >= 0 {
330                Some(Lit::Int(a ^ b))
331            } else {
332                None
333            }
334        }
335        Expr::Scope(_, _) => None,
336    }
337}
338
339/// Evaluate the root expression.
340///
341/// This returns either Expr::Root([true]) or Expr::Root([false]).
342#[register_rule(("Constant", 9001))]
343fn eval_root(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
344    // this is its own rule not part of apply_eval_constant, because root should return a new root
345    // with a literal inside it, not just a literal
346
347    let Expr::Root(_, exprs) = expr else {
348        return Err(RuleNotApplicable);
349    };
350
351    match exprs.len() {
352        0 => Ok(Reduction::pure(Expr::Root(
353            Metadata::new(),
354            vec![true.into()],
355        ))),
356        1 => Err(RuleNotApplicable),
357        _ => {
358            let lit =
359                vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).ok_or(RuleNotApplicable)?;
360
361            Ok(Reduction::pure(Expr::Root(
362                Metadata::new(),
363                vec![lit.into()],
364            )))
365        }
366    }
367}
368
369fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
370where
371    T: TryFrom<Lit>,
372{
373    let a = unwrap_expr::<T>(a)?;
374    Some(f(a))
375}
376
377fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
378where
379    T: TryFrom<Lit>,
380{
381    let a = unwrap_expr::<T>(a)?;
382    let b = unwrap_expr::<T>(b)?;
383    Some(f(a, b))
384}
385
386#[allow(dead_code)]
387fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
388where
389    T: TryFrom<Lit>,
390{
391    let a = unwrap_expr::<T>(a)?;
392    let b = unwrap_expr::<T>(b)?;
393    let c = unwrap_expr::<T>(c)?;
394    Some(f(a, b, c))
395}
396
397fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
398where
399    T: TryFrom<Lit>,
400{
401    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
402    Some(f(a))
403}
404
405fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
406where
407    T: TryFrom<Lit>,
408{
409    let a = a.clone().unwrap_list()?;
410    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
411    Some(f(a))
412}
413
414fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
415where
416    T: TryFrom<Lit>,
417{
418    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
419    f(a)
420}
421
422fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
423where
424    T: TryFrom<Lit>,
425{
426    let a = a.clone().unwrap_list()?;
427    // FIXME: deal with explicit matrix domains
428    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
429    f(a)
430}
431
432#[allow(dead_code)]
433fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
434where
435    T: TryFrom<Lit>,
436{
437    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
438    let b = unwrap_expr::<T>(b)?;
439    Some(f(a, b))
440}
441
442fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
443    let c = eval_constant(expr)?;
444    TryInto::<T>::try_into(c).ok()
445}
446
#[cfg(test)]
mod tests {
    use crate::rules::eval_constant;
    use conjure_core::ast::{Atom, Expression, Literal};

    /// Boxed integer-literal expression with default metadata.
    fn int(value: i32) -> Box<Expression> {
        Box::new(Expression::Atomic(
            Default::default(),
            Atom::Literal(Literal::Int(value)),
        ))
    }

    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }

    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }
}
483}