Skip to main content

conjure_cp_core/ast/
eval.rs

1#![allow(dead_code)]
2use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3use crate::into_matrix;
4use itertools::{Itertools as _, izip};
5use std::cmp::Ordering as CmpOrdering;
6use std::collections::HashSet;
7
8/// Simplify an expression to a constant if possible
9/// Returns:
10/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11/// `Some(Const)` if the expression can be simplified to a constant
12pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13    match expr {
14        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15            (
16                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18            ) => {
19                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21
22                if a_set.difference(&b_set).count() > 0 {
23                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24                } else {
25                    Some(Lit::Bool(false))
26                }
27            }
28            _ => None,
29        },
30        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31            (
32                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34            ) => Some(Lit::Bool(
35                a.iter()
36                    .cloned()
37                    .collect::<HashSet<Lit>>()
38                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39            )),
40            _ => None,
41        },
42        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43            (
44                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46            ) => {
47                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49
50                if b_set.difference(&a_set).count() > 0 {
51                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52                } else {
53                    Some(Lit::Bool(false))
54                }
55            }
56            _ => None,
57        },
58        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59            (
60                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62            ) => Some(Lit::Bool(
63                a.iter()
64                    .cloned()
65                    .collect::<HashSet<Lit>>()
66                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67            )),
68            _ => None,
69        },
70        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71            (
72                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74            ) => {
75                let mut res: Vec<Lit> = Vec::new();
76                for lit in a.iter() {
77                    if b.contains(lit) && !res.contains(lit) {
78                        res.push(lit.clone());
79                    }
80                }
81                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82            }
83            _ => None,
84        },
85        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86            (
87                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89            ) => {
90                let mut res: Vec<Lit> = Vec::new();
91                for lit in a.iter() {
92                    res.push(lit.clone());
93                }
94                for lit in b.iter() {
95                    if !res.contains(lit) {
96                        res.push(lit.clone());
97                    }
98                }
99                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100            }
101            _ => None,
102        },
103        Expr::In(_, a, b) => {
104            if let (
105                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107            ) = (a.as_ref(), b.as_ref())
108            {
109                for lit in d.iter() {
110                    if let Lit::Int(x) = lit
111                        && c == x
112                    {
113                        return Some(Lit::Bool(true));
114                    }
115                }
116                Some(Lit::Bool(false))
117            } else {
118                None
119            }
120        }
121        Expr::FromSolution(_, _) => None,
122        Expr::DominanceRelation(_, _) => None,
123        Expr::InDomain(_, e, domain) => {
124            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125                return None;
126            };
127
128            domain.contains(lit).ok().map(Into::into)
129        }
130        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131        Expr::Atomic(_, Atom::Reference(_)) => None,
132        Expr::AbstractLiteral(_, a) => {
133            if let AbstractLiteral::Set(s) = a {
134                let mut copy = Vec::new();
135                for expr in s.iter() {
136                    if let Expr::Atomic(_, Atom::Literal(lit)) = expr {
137                        copy.push(lit.clone());
138                    } else {
139                        return None;
140                    }
141                }
142                Some(Lit::AbstractLiteral(AbstractLiteral::Set(copy)))
143            } else {
144                None
145            }
146        }
147        Expr::Comprehension(_, _) => None,
148        Expr::AbstractComprehension(_, _) => None,
149        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
150            let subject: Lit = subject.as_ref().clone().into_literal()?;
151            let indices: Vec<Lit> = indices
152                .iter()
153                .cloned()
154                .map(|x| x.into_literal())
155                .collect::<Option<Vec<Lit>>>()?;
156
157            match subject {
158                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
159                    matrix::flatten_enumerate(subject)
160                        .find(|(i, _)| i == &indices)
161                        .map(|(_, x)| x)
162                }
163                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
164                    let AbstractLiteral::Tuple(elems) = subject else {
165                        return None;
166                    };
167
168                    assert!(indices.len() == 1, "nested tuples not supported yet");
169
170                    let Lit::Int(index) = indices[0].clone() else {
171                        return None;
172                    };
173
174                    if elems.len() < index as usize || index < 1 {
175                        return None;
176                    }
177
178                    // -1 for 0-indexing vs 1-indexing
179                    let item = elems[index as usize - 1].clone();
180
181                    Some(item)
182                }
183                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
184                    let AbstractLiteral::Record(elems) = subject else {
185                        return None;
186                    };
187
188                    assert!(indices.len() == 1, "nested record not supported yet");
189
190                    let Lit::Int(index) = indices[0].clone() else {
191                        return None;
192                    };
193
194                    if elems.len() < index as usize || index < 1 {
195                        return None;
196                    }
197
198                    // -1 for 0-indexing vs 1-indexing
199                    let item = elems[index as usize - 1].clone();
200                    Some(item.value)
201                }
202                _ => None,
203            }
204        }
205        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
206            let subject: Lit = subject.as_ref().clone().into_literal()?;
207            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
208                return None;
209            };
210
211            let hole_dim = indices
212                .iter()
213                .cloned()
214                .position(|x| x.is_none())
215                .expect("slice expression should have a hole dimension");
216
217            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
218
219            let indices: Vec<Option<Lit>> = indices
220                .iter()
221                .cloned()
222                .map(|x| {
223                    // the outer option represents success of this iterator, the inner the index
224                    // slice.
225                    match x {
226                        Some(x) => x.into_literal().map(Some),
227                        None => Some(None),
228                    }
229                })
230                .collect::<Option<Vec<Option<Lit>>>>()?;
231
232            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
233                .values()
234                .ok()?
235                .map(|i| {
236                    let mut indices = indices.clone();
237                    indices[hole_dim] = Some(i);
238                    // These unwraps will only fail if we have multiple holes.
239                    // As this is invalid, panicking is fine.
240                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
241                })
242                .collect_vec();
243
244            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
245            let elems = matrix::flatten_enumerate(subject)
246                .filter(|(i, _)| indices_in_slice.contains(i))
247                .map(|(_, elem)| elem)
248                .collect();
249
250            Some(Lit::AbstractLiteral(into_matrix![elems]))
251        }
252        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
253        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
254            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
255            .map(Lit::Bool),
256        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
257        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
258        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
259        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
260        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
261        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
262        Expr::And(_, e) => {
263            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
264        }
265        Expr::Root(_, _) => None,
266        Expr::Or(_, es) => {
267            // possibly cheating; definitely should be in partial eval instead
268            for e in (**es).clone().unwrap_list()? {
269                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
270                    return Some(Lit::Bool(true));
271                };
272            }
273
274            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
275        }
276        Expr::Imply(_, box1, box2) => {
277            let a: &Atom = (&**box1).try_into().ok()?;
278            let b: &Atom = (&**box2).try_into().ok()?;
279
280            let a: bool = a.try_into().ok()?;
281            let b: bool = b.try_into().ok()?;
282
283            if a {
284                // true -> b ~> b
285                Some(Lit::Bool(b))
286            } else {
287                // false -> b ~> true
288                Some(Lit::Bool(true))
289            }
290        }
291        Expr::Iff(_, box1, box2) => {
292            let a: &Atom = (&**box1).try_into().ok()?;
293            let b: &Atom = (&**box2).try_into().ok()?;
294
295            let a: bool = a.try_into().ok()?;
296            let b: bool = b.try_into().ok()?;
297
298            Some(Lit::Bool(a == b))
299        }
300        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
301        Expr::Product(_, exprs) => {
302            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
303        }
304        Expr::FlatIneq(_, a, b, c) => {
305            let a: i32 = a.try_into().ok()?;
306            let b: i32 = b.try_into().ok()?;
307            let c: i32 = c.try_into().ok()?;
308
309            Some(Lit::Bool(a <= b + c))
310        }
311        Expr::FlatSumGeq(_, exprs, a) => {
312            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
313                let n: i32 = atom.try_into().ok()?;
314                let acc = acc + n;
315                Some(acc)
316            })?;
317
318            Some(Lit::Bool(sum >= a.try_into().ok()?))
319        }
320        Expr::FlatSumLeq(_, exprs, a) => {
321            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
322                let n: i32 = atom.try_into().ok()?;
323                let acc = acc + n;
324                Some(acc)
325            })?;
326
327            Some(Lit::Bool(sum >= a.try_into().ok()?))
328        }
329        Expr::Min(_, e) => {
330            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
331        }
332        Expr::Max(_, e) => {
333            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
334        }
335        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
336            if unwrap_expr::<i32>(b)? == 0 {
337                return None;
338            }
339            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
340        }
341        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
342            if unwrap_expr::<i32>(b)? == 0 {
343                return None;
344            }
345            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
346                .map(Lit::Int)
347        }
348        Expr::MinionDivEqUndefZero(_, a, b, c) => {
349            // div always rounds down
350            let a: i32 = a.try_into().ok()?;
351            let b: i32 = b.try_into().ok()?;
352            let c: i32 = c.try_into().ok()?;
353
354            if b == 0 {
355                return None;
356            }
357
358            let a = a as f32;
359            let b = b as f32;
360            let div: i32 = (a / b).floor() as i32;
361            Some(Lit::Bool(div == c))
362        }
363        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
364        Expr::MinionReify(_, a, b) => {
365            let result = eval_constant(a)?;
366
367            let result: bool = result.try_into().ok()?;
368            let b: bool = b.try_into().ok()?;
369
370            Some(Lit::Bool(b == result))
371        }
372        Expr::MinionReifyImply(_, a, b) => {
373            let result = eval_constant(a)?;
374
375            let result: bool = result.try_into().ok()?;
376            let b: bool = b.try_into().ok()?;
377
378            if b {
379                Some(Lit::Bool(result))
380            } else {
381                Some(Lit::Bool(true))
382            }
383        }
384        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
385            // From Savile Row. Same semantics as division.
386            //
387            //   a - (b * floor(a/b))
388            //
389            // We don't use % as it has the same semantics as /. We don't use / as we want to round
390            // down instead, not towards zero.
391
392            let a: i32 = a.try_into().ok()?;
393            let b: i32 = b.try_into().ok()?;
394            let c: i32 = c.try_into().ok()?;
395
396            if b == 0 {
397                return None;
398            }
399
400            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
401            Some(Lit::Bool(modulo == c))
402        }
403        Expr::MinionPow(_, a, b, c) => {
404            // only available for positive a b c
405
406            let a: i32 = a.try_into().ok()?;
407            let b: i32 = b.try_into().ok()?;
408            let c: i32 = c.try_into().ok()?;
409
410            if a <= 0 {
411                return None;
412            }
413
414            if b <= 0 {
415                return None;
416            }
417
418            if c <= 0 {
419                return None;
420            }
421
422            Some(Lit::Bool(a ^ b == c))
423        }
424        Expr::MinionWInSet(_, _, _) => None,
425        Expr::MinionWInIntervalSet(_, x, intervals) => {
426            let x_lit: &Lit = x.try_into().ok()?;
427
428            let x_lit = match x_lit.clone() {
429                Lit::Int(i) => Some(i),
430                Lit::Bool(true) => Some(1),
431                Lit::Bool(false) => Some(0),
432                _ => None,
433            }?;
434
435            let mut intervals = intervals.iter();
436            loop {
437                let Some(lower) = intervals.next() else {
438                    break;
439                };
440
441                let Some(upper) = intervals.next() else {
442                    break;
443                };
444                if &x_lit >= lower && &x_lit <= upper {
445                    return Some(Lit::Bool(true));
446                }
447            }
448
449            Some(Lit::Bool(false))
450        }
451        Expr::Flatten(_, _, _) => {
452            // TODO
453            None
454        }
455        Expr::AllDiff(_, e) => {
456            let es = (**e).clone().unwrap_list()?;
457            let mut lits: HashSet<Lit> = HashSet::new();
458            for expr in es {
459                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
460                    return None;
461                };
462                match x {
463                    Lit::Int(_) | Lit::Bool(_) => {
464                        if lits.contains(&x) {
465                            return Some(Lit::Bool(false));
466                        } else {
467                            lits.insert(x.clone());
468                        }
469                    }
470                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
471                }
472            }
473            Some(Lit::Bool(true))
474        }
475        Expr::FlatAllDiff(_, es) => {
476            let mut lits: HashSet<Lit> = HashSet::new();
477            for atom in es {
478                let Atom::Literal(x) = atom else {
479                    return None;
480                };
481
482                match x {
483                    Lit::Int(_) | Lit::Bool(_) => {
484                        if lits.contains(x) {
485                            return Some(Lit::Bool(false));
486                        } else {
487                            lits.insert(x.clone());
488                        }
489                    }
490                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
491                }
492            }
493            Some(Lit::Bool(true))
494        }
495        Expr::FlatWatchedLiteral(_, _, _) => None,
496        Expr::AuxDeclaration(_, _, _) => None,
497        Expr::Neg(_, a) => {
498            let a: &Atom = a.try_into().ok()?;
499            let a: i32 = a.try_into().ok()?;
500            Some(Lit::Int(-a))
501        }
502        Expr::Minus(_, a, b) => {
503            let a: &Atom = a.try_into().ok()?;
504            let a: i32 = a.try_into().ok()?;
505
506            let b: &Atom = b.try_into().ok()?;
507            let b: i32 = b.try_into().ok()?;
508
509            Some(Lit::Int(a - b))
510        }
511        Expr::FlatMinusEq(_, a, b) => {
512            let a: i32 = a.try_into().ok()?;
513            let b: i32 = b.try_into().ok()?;
514            Some(Lit::Bool(a == -b))
515        }
516        Expr::FlatProductEq(_, a, b, c) => {
517            let a: i32 = a.try_into().ok()?;
518            let b: i32 = b.try_into().ok()?;
519            let c: i32 = c.try_into().ok()?;
520            Some(Lit::Bool(a * b == c))
521        }
522        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
523            let cs: Vec<i32> = cs
524                .iter()
525                .map(|x| TryInto::<i32>::try_into(x).ok())
526                .collect::<Option<Vec<i32>>>()?;
527            let vs: Vec<i32> = vs
528                .iter()
529                .map(|x| TryInto::<i32>::try_into(x).ok())
530                .collect::<Option<Vec<i32>>>()?;
531            let total: i32 = total.try_into().ok()?;
532
533            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
534
535            Some(Lit::Bool(sum <= total))
536        }
537        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
538            let cs: Vec<i32> = cs
539                .iter()
540                .map(|x| TryInto::<i32>::try_into(x).ok())
541                .collect::<Option<Vec<i32>>>()?;
542            let vs: Vec<i32> = vs
543                .iter()
544                .map(|x| TryInto::<i32>::try_into(x).ok())
545                .collect::<Option<Vec<i32>>>()?;
546            let total: i32 = total.try_into().ok()?;
547
548            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
549
550            Some(Lit::Bool(sum >= total))
551        }
552        Expr::FlatAbsEq(_, x, y) => {
553            let x: i32 = x.try_into().ok()?;
554            let y: i32 = y.try_into().ok()?;
555
556            Some(Lit::Bool(x == y.abs()))
557        }
558        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
559            let a: &Atom = a.try_into().ok()?;
560            let a: i32 = a.try_into().ok()?;
561
562            let b: &Atom = b.try_into().ok()?;
563            let b: i32 = b.try_into().ok()?;
564
565            if (a != 0 || b != 0) && b >= 0 {
566                Some(Lit::Int(a.pow(b as u32)))
567            } else {
568                None
569            }
570        }
571        Expr::Scope(_, _) => None,
572        Expr::Metavar(_, _) => None,
573        Expr::MinionElementOne(_, _, _, _) => None,
574        Expr::ToInt(_, expression) => {
575            let lit = (**expression).clone().into_literal()?;
576            match lit {
577                Lit::Int(_) => Some(lit),
578                Lit::Bool(true) => Some(Lit::Int(1)),
579                Lit::Bool(false) => Some(Lit::Int(0)),
580                _ => None,
581            }
582        }
583        Expr::SATInt(..) => None,
584        Expr::PairwiseSum(_, a, b) => {
585            match ((**a).clone().into_literal()?, (**b).clone().into_literal()?) {
586                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
587                _ => None,
588            }
589        }
590        Expr::PairwiseProduct(_, a, b) => {
591            match ((**a).clone().into_literal()?, (**b).clone().into_literal()?) {
592                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
593                _ => None,
594            }
595        }
596        Expr::Defined(_, _) => todo!(),
597        Expr::Range(_, _) => todo!(),
598        Expr::Image(_, _, _) => todo!(),
599        Expr::ImageSet(_, _, _) => todo!(),
600        Expr::PreImage(_, _, _) => todo!(),
601        Expr::Inverse(_, _, _) => todo!(),
602        Expr::Restrict(_, _, _) => todo!(),
603        Expr::LexLt(_, a, b) => {
604            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
605                pairs
606                    .iter()
607                    .find_map(|(a, b)| match a.cmp(b) {
608                        CmpOrdering::Less => Some(true),     // First difference is <
609                        CmpOrdering::Greater => Some(false), // First difference is >
610                        CmpOrdering::Equal => None,          // No difference
611                    })
612                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
613            })?;
614            Some(lt.into())
615        }
616        Expr::LexLeq(_, a, b) => {
617            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
618                pairs
619                    .iter()
620                    .find_map(|(a, b)| match a.cmp(b) {
621                        CmpOrdering::Less => Some(true),
622                        CmpOrdering::Greater => Some(false),
623                        CmpOrdering::Equal => None,
624                    })
625                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
626            })?;
627            Some(lt.into())
628        }
629        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
630        Expr::LexGeq(_, a, b) => {
631            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
632        }
633        Expr::FlatLexLt(_, a, b) => {
634            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
635                pairs
636                    .iter()
637                    .find_map(|(a, b)| match a.cmp(b) {
638                        CmpOrdering::Less => Some(true),
639                        CmpOrdering::Greater => Some(false),
640                        CmpOrdering::Equal => None,
641                    })
642                    .unwrap_or(a_len < b_len)
643            })?;
644            Some(lt.into())
645        }
646        Expr::FlatLexLeq(_, a, b) => {
647            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
648                pairs
649                    .iter()
650                    .find_map(|(a, b)| match a.cmp(b) {
651                        CmpOrdering::Less => Some(true),
652                        CmpOrdering::Greater => Some(false),
653                        CmpOrdering::Equal => None,
654                    })
655                    .unwrap_or(a_len <= b_len)
656            })?;
657            Some(lt.into())
658        }
659    }
660}
661
662pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
663where
664    T: TryFrom<Lit>,
665{
666    let a = unwrap_expr::<T>(a)?;
667    Some(f(a))
668}
669
670pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
671where
672    T: TryFrom<Lit>,
673{
674    let a = unwrap_expr::<T>(a)?;
675    let b = unwrap_expr::<T>(b)?;
676    Some(f(a, b))
677}
678
679#[allow(dead_code)]
680pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
681where
682    T: TryFrom<Lit>,
683{
684    let a = unwrap_expr::<T>(a)?;
685    let b = unwrap_expr::<T>(b)?;
686    let c = unwrap_expr::<T>(c)?;
687    Some(f(a, b, c))
688}
689
690pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
691where
692    T: TryFrom<Lit>,
693{
694    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
695    Some(f(a))
696}
697
698pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
699where
700    T: TryFrom<Lit>,
701{
702    // we don't care about preserving indices here, as we will be getting rid of the vector
703    // anyways!
704    let a = a.clone().unwrap_matrix_unchecked()?.0;
705    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
706    Some(f(a))
707}
708
709type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
710
711/// Calls the given function on each consecutive pair of elements in the list expressions.
712/// Also passes the length of the two lists.
713fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
714where
715    T: TryFrom<Lit>,
716{
717    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
718    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
719    let lens = (a_exprs.len(), b_exprs.len());
720
721    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
722        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
723        .collect::<Option<Vec<(T, T)>>>()?;
724    Some(f(lit_pairs, lens))
725}
726
727/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
728fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
729where
730    T: TryFrom<Atom>,
731{
732    let lit_pairs = Iterator::zip(a.iter(), b.iter())
733        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
734        .collect::<Option<Vec<(T, T)>>>()?;
735    Some(f(lit_pairs, (a.len(), b.len())))
736}
737
738pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
739where
740    T: TryFrom<Lit>,
741{
742    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
743    f(a)
744}
745
746pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
747where
748    T: TryFrom<Lit>,
749{
750    let a = a.clone().unwrap_list()?;
751    // FIXME: deal with explicit matrix domains
752    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
753    f(a)
754}
755
756#[allow(dead_code)]
757pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
758where
759    T: TryFrom<Lit>,
760{
761    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
762    let b = unwrap_expr::<T>(b)?;
763    Some(f(a, b))
764}
765
766pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
767    let c = eval_constant(expr)?;
768    TryInto::<T>::try_into(c).ok()
769}