Skip to main content

conjure_cp_core/ast/
eval.rs

1#![allow(dead_code)]
2use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3use crate::into_matrix;
4use itertools::{Itertools as _, izip};
5use std::cmp::Ordering as CmpOrdering;
6use std::collections::HashSet;
7
8/// Simplify an expression to a constant if possible
9/// Returns:
10/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11/// `Some(Const)` if the expression can be simplified to a constant
12pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13    match expr {
14        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15            (
16                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18            ) => {
19                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21
22                if a_set.difference(&b_set).count() > 0 {
23                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24                } else {
25                    Some(Lit::Bool(false))
26                }
27            }
28            _ => None,
29        },
30        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31            (
32                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34            ) => Some(Lit::Bool(
35                a.iter()
36                    .cloned()
37                    .collect::<HashSet<Lit>>()
38                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39            )),
40            _ => None,
41        },
42        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43            (
44                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46            ) => {
47                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49
50                if b_set.difference(&a_set).count() > 0 {
51                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52                } else {
53                    Some(Lit::Bool(false))
54                }
55            }
56            _ => None,
57        },
58        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59            (
60                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62            ) => Some(Lit::Bool(
63                a.iter()
64                    .cloned()
65                    .collect::<HashSet<Lit>>()
66                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67            )),
68            _ => None,
69        },
70        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71            (
72                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74            ) => {
75                let mut res: Vec<Lit> = Vec::new();
76                for lit in a.iter() {
77                    if b.contains(lit) && !res.contains(lit) {
78                        res.push(lit.clone());
79                    }
80                }
81                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82            }
83            _ => None,
84        },
85        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86            (
87                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89            ) => {
90                let mut res: Vec<Lit> = Vec::new();
91                for lit in a.iter() {
92                    res.push(lit.clone());
93                }
94                for lit in b.iter() {
95                    if !res.contains(lit) {
96                        res.push(lit.clone());
97                    }
98                }
99                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100            }
101            _ => None,
102        },
103        Expr::In(_, a, b) => {
104            if let (
105                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107            ) = (a.as_ref(), b.as_ref())
108            {
109                for lit in d.iter() {
110                    if let Lit::Int(x) = lit
111                        && c == x
112                    {
113                        return Some(Lit::Bool(true));
114                    }
115                }
116                Some(Lit::Bool(false))
117            } else {
118                None
119            }
120        }
121        Expr::FromSolution(_, _) => None,
122        Expr::DominanceRelation(_, _) => None,
123        Expr::InDomain(_, e, domain) => {
124            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125                return None;
126            };
127
128            domain.contains(lit).ok().map(Into::into)
129        }
130        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
132        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
133        Expr::Comprehension(_, _) => None,
134        Expr::AbstractComprehension(_, _) => None,
135        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
136            let subject: Lit = eval_constant(subject.as_ref())?;
137            let indices: Vec<Lit> = indices
138                .iter()
139                .map(eval_constant)
140                .collect::<Option<Vec<Lit>>>()?;
141
142            match subject {
143                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
144                    matrix::flatten_enumerate(subject)
145                        .find(|(i, _)| i == &indices)
146                        .map(|(_, x)| x)
147                }
148                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
149                    let AbstractLiteral::Tuple(elems) = subject else {
150                        return None;
151                    };
152
153                    assert!(indices.len() == 1, "nested tuples not supported yet");
154
155                    let Lit::Int(index) = indices[0].clone() else {
156                        return None;
157                    };
158
159                    if elems.len() < index as usize || index < 1 {
160                        return None;
161                    }
162
163                    // -1 for 0-indexing vs 1-indexing
164                    let item = elems[index as usize - 1].clone();
165
166                    Some(item)
167                }
168                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
169                    let AbstractLiteral::Record(elems) = subject else {
170                        return None;
171                    };
172
173                    assert!(indices.len() == 1, "nested record not supported yet");
174
175                    let Lit::Int(index) = indices[0].clone() else {
176                        return None;
177                    };
178
179                    if elems.len() < index as usize || index < 1 {
180                        return None;
181                    }
182
183                    // -1 for 0-indexing vs 1-indexing
184                    let item = elems[index as usize - 1].clone();
185                    Some(item.value)
186                }
187                _ => None,
188            }
189        }
190        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
191            let subject: Lit = eval_constant(subject.as_ref())?;
192            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
193                return None;
194            };
195
196            let hole_dim = indices
197                .iter()
198                .cloned()
199                .position(|x| x.is_none())
200                .expect("slice expression should have a hole dimension");
201
202            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
203
204            let indices: Vec<Option<Lit>> = indices
205                .iter()
206                .cloned()
207                .map(|x| {
208                    // the outer option represents success of this iterator, the inner the index
209                    // slice.
210                    match x {
211                        Some(x) => eval_constant(&x).map(Some),
212                        None => Some(None),
213                    }
214                })
215                .collect::<Option<Vec<Option<Lit>>>>()?;
216
217            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
218                .values()
219                .ok()?
220                .map(|i| {
221                    let mut indices = indices.clone();
222                    indices[hole_dim] = Some(i);
223                    // These unwraps will only fail if we have multiple holes.
224                    // As this is invalid, panicking is fine.
225                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
226                })
227                .collect_vec();
228
229            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
230            let elems = matrix::flatten_enumerate(subject)
231                .filter(|(i, _)| indices_in_slice.contains(i))
232                .map(|(_, elem)| elem)
233                .collect();
234
235            Some(Lit::AbstractLiteral(into_matrix![elems]))
236        }
237        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
238        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
239            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
240            .map(Lit::Bool),
241        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
242        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
243        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
244        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
245        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
246        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
247        Expr::And(_, e) => {
248            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
249        }
250        Expr::Table(_, _, _) => None,
251        Expr::NegativeTable(_, _, _) => None,
252        Expr::Root(_, _) => None,
253        Expr::Or(_, es) => {
254            // possibly cheating; definitely should be in partial eval instead
255            for e in (**es).clone().unwrap_list()? {
256                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
257                    return Some(Lit::Bool(true));
258                };
259            }
260
261            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
262        }
263        Expr::Imply(_, box1, box2) => {
264            let a: &Atom = (&**box1).try_into().ok()?;
265            let b: &Atom = (&**box2).try_into().ok()?;
266
267            let a: bool = a.try_into().ok()?;
268            let b: bool = b.try_into().ok()?;
269
270            if a {
271                // true -> b ~> b
272                Some(Lit::Bool(b))
273            } else {
274                // false -> b ~> true
275                Some(Lit::Bool(true))
276            }
277        }
278        Expr::Iff(_, box1, box2) => {
279            let a: &Atom = (&**box1).try_into().ok()?;
280            let b: &Atom = (&**box2).try_into().ok()?;
281
282            let a: bool = a.try_into().ok()?;
283            let b: bool = b.try_into().ok()?;
284
285            Some(Lit::Bool(a == b))
286        }
287        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
288        Expr::Product(_, exprs) => {
289            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
290        }
291        Expr::FlatIneq(_, a, b, c) => {
292            let a: i32 = a.try_into().ok()?;
293            let b: i32 = b.try_into().ok()?;
294            let c: i32 = c.try_into().ok()?;
295
296            Some(Lit::Bool(a <= b + c))
297        }
298        Expr::FlatSumGeq(_, exprs, a) => {
299            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
300                let n: i32 = atom.try_into().ok()?;
301                let acc = acc + n;
302                Some(acc)
303            })?;
304
305            Some(Lit::Bool(sum >= a.try_into().ok()?))
306        }
307        Expr::FlatSumLeq(_, exprs, a) => {
308            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
309                let n: i32 = atom.try_into().ok()?;
310                let acc = acc + n;
311                Some(acc)
312            })?;
313
314            Some(Lit::Bool(sum >= a.try_into().ok()?))
315        }
316        Expr::Min(_, e) => {
317            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
318        }
319        Expr::Max(_, e) => {
320            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
321        }
322        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
323            if unwrap_expr::<i32>(b)? == 0 {
324                return None;
325            }
326            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
327        }
328        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
329            if unwrap_expr::<i32>(b)? == 0 {
330                return None;
331            }
332            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
333                .map(Lit::Int)
334        }
335        Expr::MinionDivEqUndefZero(_, a, b, c) => {
336            // div always rounds down
337            let a: i32 = a.try_into().ok()?;
338            let b: i32 = b.try_into().ok()?;
339            let c: i32 = c.try_into().ok()?;
340
341            if b == 0 {
342                return None;
343            }
344
345            let a = a as f32;
346            let b = b as f32;
347            let div: i32 = (a / b).floor() as i32;
348            Some(Lit::Bool(div == c))
349        }
350        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
351        Expr::MinionReify(_, a, b) => {
352            let result = eval_constant(a)?;
353
354            let result: bool = result.try_into().ok()?;
355            let b: bool = b.try_into().ok()?;
356
357            Some(Lit::Bool(b == result))
358        }
359        Expr::MinionReifyImply(_, a, b) => {
360            let result = eval_constant(a)?;
361
362            let result: bool = result.try_into().ok()?;
363            let b: bool = b.try_into().ok()?;
364
365            if b {
366                Some(Lit::Bool(result))
367            } else {
368                Some(Lit::Bool(true))
369            }
370        }
371        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
372            // From Savile Row. Same semantics as division.
373            //
374            //   a - (b * floor(a/b))
375            //
376            // We don't use % as it has the same semantics as /. We don't use / as we want to round
377            // down instead, not towards zero.
378
379            let a: i32 = a.try_into().ok()?;
380            let b: i32 = b.try_into().ok()?;
381            let c: i32 = c.try_into().ok()?;
382
383            if b == 0 {
384                return None;
385            }
386
387            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
388            Some(Lit::Bool(modulo == c))
389        }
390        Expr::MinionPow(_, a, b, c) => {
391            // only available for positive a b c
392
393            let a: i32 = a.try_into().ok()?;
394            let b: i32 = b.try_into().ok()?;
395            let c: i32 = c.try_into().ok()?;
396
397            if a <= 0 {
398                return None;
399            }
400
401            if b <= 0 {
402                return None;
403            }
404
405            if c <= 0 {
406                return None;
407            }
408
409            Some(Lit::Bool(a ^ b == c))
410        }
411        Expr::MinionWInSet(_, _, _) => None,
412        Expr::MinionWInIntervalSet(_, x, intervals) => {
413            let x_lit: &Lit = x.try_into().ok()?;
414
415            let x_lit = match x_lit.clone() {
416                Lit::Int(i) => Some(i),
417                Lit::Bool(true) => Some(1),
418                Lit::Bool(false) => Some(0),
419                _ => None,
420            }?;
421
422            let mut intervals = intervals.iter();
423            loop {
424                let Some(lower) = intervals.next() else {
425                    break;
426                };
427
428                let Some(upper) = intervals.next() else {
429                    break;
430                };
431                if &x_lit >= lower && &x_lit <= upper {
432                    return Some(Lit::Bool(true));
433                }
434            }
435
436            Some(Lit::Bool(false))
437        }
438        Expr::Flatten(_, _, _) => {
439            // TODO
440            None
441        }
442        Expr::AllDiff(_, e) => {
443            let es = (**e).clone().unwrap_list()?;
444            let mut lits: HashSet<Lit> = HashSet::new();
445            for expr in es {
446                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
447                    return None;
448                };
449                match x {
450                    Lit::Int(_) | Lit::Bool(_) => {
451                        if lits.contains(&x) {
452                            return Some(Lit::Bool(false));
453                        } else {
454                            lits.insert(x.clone());
455                        }
456                    }
457                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
458                }
459            }
460            Some(Lit::Bool(true))
461        }
462        Expr::FlatAllDiff(_, es) => {
463            let mut lits: HashSet<Lit> = HashSet::new();
464            for atom in es {
465                let Atom::Literal(x) = atom else {
466                    return None;
467                };
468
469                match x {
470                    Lit::Int(_) | Lit::Bool(_) => {
471                        if lits.contains(x) {
472                            return Some(Lit::Bool(false));
473                        } else {
474                            lits.insert(x.clone());
475                        }
476                    }
477                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
478                }
479            }
480            Some(Lit::Bool(true))
481        }
482        Expr::FlatWatchedLiteral(_, _, _) => None,
483        Expr::AuxDeclaration(_, _, _) => None,
484        Expr::Neg(_, a) => {
485            let a: &Atom = a.try_into().ok()?;
486            let a: i32 = a.try_into().ok()?;
487            Some(Lit::Int(-a))
488        }
489        Expr::Minus(_, a, b) => {
490            let a: &Atom = a.try_into().ok()?;
491            let a: i32 = a.try_into().ok()?;
492
493            let b: &Atom = b.try_into().ok()?;
494            let b: i32 = b.try_into().ok()?;
495
496            Some(Lit::Int(a - b))
497        }
498        Expr::FlatMinusEq(_, a, b) => {
499            let a: i32 = a.try_into().ok()?;
500            let b: i32 = b.try_into().ok()?;
501            Some(Lit::Bool(a == -b))
502        }
503        Expr::FlatProductEq(_, a, b, c) => {
504            let a: i32 = a.try_into().ok()?;
505            let b: i32 = b.try_into().ok()?;
506            let c: i32 = c.try_into().ok()?;
507            Some(Lit::Bool(a * b == c))
508        }
509        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
510            let cs: Vec<i32> = cs
511                .iter()
512                .map(|x| TryInto::<i32>::try_into(x).ok())
513                .collect::<Option<Vec<i32>>>()?;
514            let vs: Vec<i32> = vs
515                .iter()
516                .map(|x| TryInto::<i32>::try_into(x).ok())
517                .collect::<Option<Vec<i32>>>()?;
518            let total: i32 = total.try_into().ok()?;
519
520            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
521
522            Some(Lit::Bool(sum <= total))
523        }
524        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
525            let cs: Vec<i32> = cs
526                .iter()
527                .map(|x| TryInto::<i32>::try_into(x).ok())
528                .collect::<Option<Vec<i32>>>()?;
529            let vs: Vec<i32> = vs
530                .iter()
531                .map(|x| TryInto::<i32>::try_into(x).ok())
532                .collect::<Option<Vec<i32>>>()?;
533            let total: i32 = total.try_into().ok()?;
534
535            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
536
537            Some(Lit::Bool(sum >= total))
538        }
539        Expr::FlatAbsEq(_, x, y) => {
540            let x: i32 = x.try_into().ok()?;
541            let y: i32 = y.try_into().ok()?;
542
543            Some(Lit::Bool(x == y.abs()))
544        }
545        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
546            let a: &Atom = a.try_into().ok()?;
547            let a: i32 = a.try_into().ok()?;
548
549            let b: &Atom = b.try_into().ok()?;
550            let b: i32 = b.try_into().ok()?;
551
552            if (a != 0 || b != 0) && b >= 0 {
553                Some(Lit::Int(a.pow(b as u32)))
554            } else {
555                None
556            }
557        }
558        Expr::Metavar(_, _) => None,
559        Expr::MinionElementOne(_, _, _, _) => None,
560        Expr::ToInt(_, expression) => {
561            let lit = eval_constant(expression.as_ref())?;
562            match lit {
563                Lit::Int(_) => Some(lit),
564                Lit::Bool(true) => Some(Lit::Int(1)),
565                Lit::Bool(false) => Some(Lit::Int(0)),
566                _ => None,
567            }
568        }
569        Expr::SATInt(_, _, _, _) => {
570            // TODO: If this SATInt is composed of literals, we should evaluate it back to an
571            // integer literal.
572            //
573            // This is important because `is_all_constant` currently returns true for SATInts
574            // containing no references. If we don't evaluate them here, bubble rules will skip
575            // them (thinking they'll be constant-folded later), but they'll actually reach
576            // the solver adaptors as un-encoded unsafe operations, causing panics.
577            None
578        }
579        Expr::PairwiseSum(_, a, b) => {
580            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
581                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
582                _ => None,
583            }
584        }
585        Expr::PairwiseProduct(_, a, b) => {
586            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
587                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
588                _ => None,
589            }
590        }
591        Expr::Defined(_, _) => todo!(),
592        Expr::Range(_, _) => todo!(),
593        Expr::Image(_, _, _) => todo!(),
594        Expr::ImageSet(_, _, _) => todo!(),
595        Expr::PreImage(_, _, _) => todo!(),
596        Expr::Inverse(_, _, _) => todo!(),
597        Expr::Restrict(_, _, _) => todo!(),
598        Expr::LexLt(_, a, b) => {
599            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
600                pairs
601                    .iter()
602                    .find_map(|(a, b)| match a.cmp(b) {
603                        CmpOrdering::Less => Some(true),     // First difference is <
604                        CmpOrdering::Greater => Some(false), // First difference is >
605                        CmpOrdering::Equal => None,          // No difference
606                    })
607                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
608            })?;
609            Some(lt.into())
610        }
611        Expr::LexLeq(_, a, b) => {
612            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
613                pairs
614                    .iter()
615                    .find_map(|(a, b)| match a.cmp(b) {
616                        CmpOrdering::Less => Some(true),
617                        CmpOrdering::Greater => Some(false),
618                        CmpOrdering::Equal => None,
619                    })
620                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
621            })?;
622            Some(lt.into())
623        }
624        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
625        Expr::LexGeq(_, a, b) => {
626            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
627        }
628        Expr::FlatLexLt(_, a, b) => {
629            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
630                pairs
631                    .iter()
632                    .find_map(|(a, b)| match a.cmp(b) {
633                        CmpOrdering::Less => Some(true),
634                        CmpOrdering::Greater => Some(false),
635                        CmpOrdering::Equal => None,
636                    })
637                    .unwrap_or(a_len < b_len)
638            })?;
639            Some(lt.into())
640        }
641        Expr::FlatLexLeq(_, a, b) => {
642            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
643                pairs
644                    .iter()
645                    .find_map(|(a, b)| match a.cmp(b) {
646                        CmpOrdering::Less => Some(true),
647                        CmpOrdering::Greater => Some(false),
648                        CmpOrdering::Equal => None,
649                    })
650                    .unwrap_or(a_len <= b_len)
651            })?;
652            Some(lt.into())
653        }
654    }
655}
656
657pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
658where
659    T: TryFrom<Lit>,
660{
661    let a = unwrap_expr::<T>(a)?;
662    Some(f(a))
663}
664
665pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
666where
667    T: TryFrom<Lit>,
668{
669    let a = unwrap_expr::<T>(a)?;
670    let b = unwrap_expr::<T>(b)?;
671    Some(f(a, b))
672}
673
674#[allow(dead_code)]
675pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
676where
677    T: TryFrom<Lit>,
678{
679    let a = unwrap_expr::<T>(a)?;
680    let b = unwrap_expr::<T>(b)?;
681    let c = unwrap_expr::<T>(c)?;
682    Some(f(a, b, c))
683}
684
685pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
686where
687    T: TryFrom<Lit>,
688{
689    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
690    Some(f(a))
691}
692
693pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
694where
695    T: TryFrom<Lit>,
696{
697    // we don't care about preserving indices here, as we will be getting rid of the vector
698    // anyways!
699    let a = a.clone().unwrap_matrix_unchecked()?.0;
700    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
701    Some(f(a))
702}
703
/// Callback used by [`vec_expr_pairs_op`] and [`atoms_pairs_op`]: receives the positionally
/// paired elements of two lists, plus the full lengths `(a_len, b_len)` of those lists.
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
705
706/// Calls the given function on each consecutive pair of elements in the list expressions.
707/// Also passes the length of the two lists.
708fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
709where
710    T: TryFrom<Lit>,
711{
712    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
713    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
714    let lens = (a_exprs.len(), b_exprs.len());
715
716    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
717        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
718        .collect::<Option<Vec<(T, T)>>>()?;
719    Some(f(lit_pairs, lens))
720}
721
722/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
723fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
724where
725    T: TryFrom<Atom>,
726{
727    let lit_pairs = Iterator::zip(a.iter(), b.iter())
728        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
729        .collect::<Option<Vec<(T, T)>>>()?;
730    Some(f(lit_pairs, (a.len(), b.len())))
731}
732
733pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
734where
735    T: TryFrom<Lit>,
736{
737    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
738    f(a)
739}
740
741pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
742where
743    T: TryFrom<Lit>,
744{
745    let a = a.clone().unwrap_list()?;
746    // FIXME: deal with explicit matrix domains
747    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
748    f(a)
749}
750
751#[allow(dead_code)]
752pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
753where
754    T: TryFrom<Lit>,
755{
756    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
757    let b = unwrap_expr::<T>(b)?;
758    Some(f(a, b))
759}
760
761pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
762    let c = eval_constant(expr)?;
763    TryInto::<T>::try_into(c).ok()
764}