conjure_cp_core/ast/
eval.rs

1#![allow(dead_code)]
2use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3use crate::into_matrix;
4use itertools::{Itertools as _, izip};
5use std::cmp::Ordering as CmpOrdering;
6use std::collections::HashSet;
7
8/// Simplify an expression to a constant if possible
9/// Returns:
10/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11/// `Some(Const)` if the expression can be simplified to a constant
12pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13    match expr {
14        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15            (
16                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18            ) => {
19                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21
22                if a_set.difference(&b_set).count() > 0 {
23                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24                } else {
25                    Some(Lit::Bool(false))
26                }
27            }
28            _ => None,
29        },
30        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31            (
32                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34            ) => Some(Lit::Bool(
35                a.iter()
36                    .cloned()
37                    .collect::<HashSet<Lit>>()
38                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39            )),
40            _ => None,
41        },
42        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43            (
44                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46            ) => {
47                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49
50                if b_set.difference(&a_set).count() > 0 {
51                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52                } else {
53                    Some(Lit::Bool(false))
54                }
55            }
56            _ => None,
57        },
58        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59            (
60                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62            ) => Some(Lit::Bool(
63                a.iter()
64                    .cloned()
65                    .collect::<HashSet<Lit>>()
66                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67            )),
68            _ => None,
69        },
70        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71            (
72                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74            ) => {
75                let mut res: Vec<Lit> = Vec::new();
76                for lit in a.iter() {
77                    if b.contains(lit) && !res.contains(lit) {
78                        res.push(lit.clone());
79                    }
80                }
81                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82            }
83            _ => None,
84        },
85        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86            (
87                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89            ) => {
90                let mut res: Vec<Lit> = Vec::new();
91                for lit in a.iter() {
92                    res.push(lit.clone());
93                }
94                for lit in b.iter() {
95                    if !res.contains(lit) {
96                        res.push(lit.clone());
97                    }
98                }
99                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100            }
101            _ => None,
102        },
103        Expr::In(_, a, b) => {
104            if let (
105                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107            ) = (a.as_ref(), b.as_ref())
108            {
109                for lit in d.iter() {
110                    if let Lit::Int(x) = lit
111                        && c == x
112                    {
113                        return Some(Lit::Bool(true));
114                    }
115                }
116                Some(Lit::Bool(false))
117            } else {
118                None
119            }
120        }
121        Expr::FromSolution(_, _) => None,
122        Expr::DominanceRelation(_, _) => None,
123        Expr::InDomain(_, e, domain) => {
124            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125                return None;
126            };
127
128            domain.contains(lit).ok().map(Into::into)
129        }
130        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131        Expr::Atomic(_, Atom::Reference(_)) => None,
132        Expr::AbstractLiteral(_, a) => {
133            if let AbstractLiteral::Set(s) = a {
134                let mut copy = Vec::new();
135                for expr in s.iter() {
136                    if let Expr::Atomic(_, Atom::Literal(lit)) = expr {
137                        copy.push(lit.clone());
138                    } else {
139                        return None;
140                    }
141                }
142                Some(Lit::AbstractLiteral(AbstractLiteral::Set(copy)))
143            } else {
144                None
145            }
146        }
147        Expr::Comprehension(_, _) => None,
148        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
149            let subject: Lit = subject.as_ref().clone().into_literal()?;
150            let indices: Vec<Lit> = indices
151                .iter()
152                .cloned()
153                .map(|x| x.into_literal())
154                .collect::<Option<Vec<Lit>>>()?;
155
156            match subject {
157                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
158                    matrix::flatten_enumerate(subject)
159                        .find(|(i, _)| i == &indices)
160                        .map(|(_, x)| x)
161                }
162                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
163                    let AbstractLiteral::Tuple(elems) = subject else {
164                        return None;
165                    };
166
167                    assert!(indices.len() == 1, "nested tuples not supported yet");
168
169                    let Lit::Int(index) = indices[0].clone() else {
170                        return None;
171                    };
172
173                    if elems.len() < index as usize || index < 1 {
174                        return None;
175                    }
176
177                    // -1 for 0-indexing vs 1-indexing
178                    let item = elems[index as usize - 1].clone();
179
180                    Some(item)
181                }
182                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
183                    let AbstractLiteral::Record(elems) = subject else {
184                        return None;
185                    };
186
187                    assert!(indices.len() == 1, "nested record not supported yet");
188
189                    let Lit::Int(index) = indices[0].clone() else {
190                        return None;
191                    };
192
193                    if elems.len() < index as usize || index < 1 {
194                        return None;
195                    }
196
197                    // -1 for 0-indexing vs 1-indexing
198                    let item = elems[index as usize - 1].clone();
199                    Some(item.value)
200                }
201                _ => None,
202            }
203        }
204        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
205            let subject: Lit = subject.as_ref().clone().into_literal()?;
206            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
207                return None;
208            };
209
210            let hole_dim = indices
211                .iter()
212                .cloned()
213                .position(|x| x.is_none())
214                .expect("slice expression should have a hole dimension");
215
216            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
217
218            let indices: Vec<Option<Lit>> = indices
219                .iter()
220                .cloned()
221                .map(|x| {
222                    // the outer option represents success of this iterator, the inner the index
223                    // slice.
224                    match x {
225                        Some(x) => x.into_literal().map(Some),
226                        None => Some(None),
227                    }
228                })
229                .collect::<Option<Vec<Option<Lit>>>>()?;
230
231            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
232                .values()
233                .ok()?
234                .map(|i| {
235                    let mut indices = indices.clone();
236                    indices[hole_dim] = Some(i);
237                    // These unwraps will only fail if we have multiple holes.
238                    // As this is invalid, panicking is fine.
239                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
240                })
241                .collect_vec();
242
243            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
244            let elems = matrix::flatten_enumerate(subject)
245                .filter(|(i, _)| indices_in_slice.contains(i))
246                .map(|(_, elem)| elem)
247                .collect();
248
249            Some(Lit::AbstractLiteral(into_matrix![elems]))
250        }
251        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
252        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
253            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
254            .map(Lit::Bool),
255        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
256        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
257        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
258        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
259        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
260        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
261        Expr::And(_, e) => {
262            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
263        }
264        Expr::Root(_, _) => None,
265        Expr::Or(_, es) => {
266            // possibly cheating; definitely should be in partial eval instead
267            for e in (**es).clone().unwrap_list()? {
268                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
269                    return Some(Lit::Bool(true));
270                };
271            }
272
273            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
274        }
275        Expr::Imply(_, box1, box2) => {
276            let a: &Atom = (&**box1).try_into().ok()?;
277            let b: &Atom = (&**box2).try_into().ok()?;
278
279            let a: bool = a.try_into().ok()?;
280            let b: bool = b.try_into().ok()?;
281
282            if a {
283                // true -> b ~> b
284                Some(Lit::Bool(b))
285            } else {
286                // false -> b ~> true
287                Some(Lit::Bool(true))
288            }
289        }
290        Expr::Iff(_, box1, box2) => {
291            let a: &Atom = (&**box1).try_into().ok()?;
292            let b: &Atom = (&**box2).try_into().ok()?;
293
294            let a: bool = a.try_into().ok()?;
295            let b: bool = b.try_into().ok()?;
296
297            Some(Lit::Bool(a == b))
298        }
299        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
300        Expr::Product(_, exprs) => {
301            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
302        }
303        Expr::FlatIneq(_, a, b, c) => {
304            let a: i32 = a.try_into().ok()?;
305            let b: i32 = b.try_into().ok()?;
306            let c: i32 = c.try_into().ok()?;
307
308            Some(Lit::Bool(a <= b + c))
309        }
310        Expr::FlatSumGeq(_, exprs, a) => {
311            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
312                let n: i32 = atom.try_into().ok()?;
313                let acc = acc + n;
314                Some(acc)
315            })?;
316
317            Some(Lit::Bool(sum >= a.try_into().ok()?))
318        }
319        Expr::FlatSumLeq(_, exprs, a) => {
320            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
321                let n: i32 = atom.try_into().ok()?;
322                let acc = acc + n;
323                Some(acc)
324            })?;
325
326            Some(Lit::Bool(sum >= a.try_into().ok()?))
327        }
328        Expr::Min(_, e) => {
329            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
330        }
331        Expr::Max(_, e) => {
332            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
333        }
334        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
335            if unwrap_expr::<i32>(b)? == 0 {
336                return None;
337            }
338            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
339        }
340        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
341            if unwrap_expr::<i32>(b)? == 0 {
342                return None;
343            }
344            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
345                .map(Lit::Int)
346        }
347        Expr::MinionDivEqUndefZero(_, a, b, c) => {
348            // div always rounds down
349            let a: i32 = a.try_into().ok()?;
350            let b: i32 = b.try_into().ok()?;
351            let c: i32 = c.try_into().ok()?;
352
353            if b == 0 {
354                return None;
355            }
356
357            let a = a as f32;
358            let b = b as f32;
359            let div: i32 = (a / b).floor() as i32;
360            Some(Lit::Bool(div == c))
361        }
362        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
363        Expr::MinionReify(_, a, b) => {
364            let result = eval_constant(a)?;
365
366            let result: bool = result.try_into().ok()?;
367            let b: bool = b.try_into().ok()?;
368
369            Some(Lit::Bool(b == result))
370        }
371        Expr::MinionReifyImply(_, a, b) => {
372            let result = eval_constant(a)?;
373
374            let result: bool = result.try_into().ok()?;
375            let b: bool = b.try_into().ok()?;
376
377            if b {
378                Some(Lit::Bool(result))
379            } else {
380                Some(Lit::Bool(true))
381            }
382        }
383        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
384            // From Savile Row. Same semantics as division.
385            //
386            //   a - (b * floor(a/b))
387            //
388            // We don't use % as it has the same semantics as /. We don't use / as we want to round
389            // down instead, not towards zero.
390
391            let a: i32 = a.try_into().ok()?;
392            let b: i32 = b.try_into().ok()?;
393            let c: i32 = c.try_into().ok()?;
394
395            if b == 0 {
396                return None;
397            }
398
399            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
400            Some(Lit::Bool(modulo == c))
401        }
402        Expr::MinionPow(_, a, b, c) => {
403            // only available for positive a b c
404
405            let a: i32 = a.try_into().ok()?;
406            let b: i32 = b.try_into().ok()?;
407            let c: i32 = c.try_into().ok()?;
408
409            if a <= 0 {
410                return None;
411            }
412
413            if b <= 0 {
414                return None;
415            }
416
417            if c <= 0 {
418                return None;
419            }
420
421            Some(Lit::Bool(a ^ b == c))
422        }
423        Expr::MinionWInSet(_, _, _) => None,
424        Expr::MinionWInIntervalSet(_, x, intervals) => {
425            let x_lit: &Lit = x.try_into().ok()?;
426
427            let x_lit = match x_lit.clone() {
428                Lit::Int(i) => Some(i),
429                Lit::Bool(true) => Some(1),
430                Lit::Bool(false) => Some(0),
431                _ => None,
432            }?;
433
434            let mut intervals = intervals.iter();
435            loop {
436                let Some(lower) = intervals.next() else {
437                    break;
438                };
439
440                let Some(upper) = intervals.next() else {
441                    break;
442                };
443                if &x_lit >= lower && &x_lit <= upper {
444                    return Some(Lit::Bool(true));
445                }
446            }
447
448            Some(Lit::Bool(false))
449        }
450        Expr::Flatten(_, _, _) => {
451            // TODO
452            None
453        }
454        Expr::AllDiff(_, e) => {
455            let es = (**e).clone().unwrap_list()?;
456            let mut lits: HashSet<Lit> = HashSet::new();
457            for expr in es {
458                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
459                    return None;
460                };
461                match x {
462                    Lit::Int(_) | Lit::Bool(_) => {
463                        if lits.contains(&x) {
464                            return Some(Lit::Bool(false));
465                        } else {
466                            lits.insert(x.clone());
467                        }
468                    }
469                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
470                }
471            }
472            Some(Lit::Bool(true))
473        }
474        Expr::FlatAllDiff(_, es) => {
475            let mut lits: HashSet<Lit> = HashSet::new();
476            for atom in es {
477                let Atom::Literal(x) = atom else {
478                    return None;
479                };
480
481                match x {
482                    Lit::Int(_) | Lit::Bool(_) => {
483                        if lits.contains(x) {
484                            return Some(Lit::Bool(false));
485                        } else {
486                            lits.insert(x.clone());
487                        }
488                    }
489                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
490                }
491            }
492            Some(Lit::Bool(true))
493        }
494        Expr::FlatWatchedLiteral(_, _, _) => None,
495        Expr::AuxDeclaration(_, _, _) => None,
496        Expr::Neg(_, a) => {
497            let a: &Atom = a.try_into().ok()?;
498            let a: i32 = a.try_into().ok()?;
499            Some(Lit::Int(-a))
500        }
501        Expr::Minus(_, a, b) => {
502            let a: &Atom = a.try_into().ok()?;
503            let a: i32 = a.try_into().ok()?;
504
505            let b: &Atom = b.try_into().ok()?;
506            let b: i32 = b.try_into().ok()?;
507
508            Some(Lit::Int(a - b))
509        }
510        Expr::FlatMinusEq(_, a, b) => {
511            let a: i32 = a.try_into().ok()?;
512            let b: i32 = b.try_into().ok()?;
513            Some(Lit::Bool(a == -b))
514        }
515        Expr::FlatProductEq(_, a, b, c) => {
516            let a: i32 = a.try_into().ok()?;
517            let b: i32 = b.try_into().ok()?;
518            let c: i32 = c.try_into().ok()?;
519            Some(Lit::Bool(a * b == c))
520        }
521        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
522            let cs: Vec<i32> = cs
523                .iter()
524                .map(|x| TryInto::<i32>::try_into(x).ok())
525                .collect::<Option<Vec<i32>>>()?;
526            let vs: Vec<i32> = vs
527                .iter()
528                .map(|x| TryInto::<i32>::try_into(x).ok())
529                .collect::<Option<Vec<i32>>>()?;
530            let total: i32 = total.try_into().ok()?;
531
532            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
533
534            Some(Lit::Bool(sum <= total))
535        }
536        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
537            let cs: Vec<i32> = cs
538                .iter()
539                .map(|x| TryInto::<i32>::try_into(x).ok())
540                .collect::<Option<Vec<i32>>>()?;
541            let vs: Vec<i32> = vs
542                .iter()
543                .map(|x| TryInto::<i32>::try_into(x).ok())
544                .collect::<Option<Vec<i32>>>()?;
545            let total: i32 = total.try_into().ok()?;
546
547            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
548
549            Some(Lit::Bool(sum >= total))
550        }
551        Expr::FlatAbsEq(_, x, y) => {
552            let x: i32 = x.try_into().ok()?;
553            let y: i32 = y.try_into().ok()?;
554
555            Some(Lit::Bool(x == y.abs()))
556        }
557        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
558            let a: &Atom = a.try_into().ok()?;
559            let a: i32 = a.try_into().ok()?;
560
561            let b: &Atom = b.try_into().ok()?;
562            let b: i32 = b.try_into().ok()?;
563
564            if (a != 0 || b != 0) && b >= 0 {
565                Some(Lit::Int(a.pow(b as u32)))
566            } else {
567                None
568            }
569        }
570        Expr::Scope(_, _) => None,
571        Expr::Metavar(_, _) => None,
572        Expr::MinionElementOne(_, _, _, _) => None,
573        Expr::ToInt(_, expression) => {
574            let lit = (**expression).clone().into_literal()?;
575            match lit {
576                Lit::Int(_) => Some(lit),
577                Lit::Bool(true) => Some(Lit::Int(1)),
578                Lit::Bool(false) => Some(Lit::Int(0)),
579                _ => None,
580            }
581        }
582        Expr::SATInt(_, _) => None,
583        Expr::PairwiseSum(_, a, b) => {
584            match ((**a).clone().into_literal()?, (**b).clone().into_literal()?) {
585                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
586                _ => None,
587            }
588        }
589        Expr::PairwiseProduct(_, a, b) => {
590            match ((**a).clone().into_literal()?, (**b).clone().into_literal()?) {
591                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
592                _ => None,
593            }
594        }
595        Expr::Defined(_, _) => todo!(),
596        Expr::Range(_, _) => todo!(),
597        Expr::Image(_, _, _) => todo!(),
598        Expr::ImageSet(_, _, _) => todo!(),
599        Expr::PreImage(_, _, _) => todo!(),
600        Expr::Inverse(_, _, _) => todo!(),
601        Expr::Restrict(_, _, _) => todo!(),
602        Expr::LexLt(_, a, b) => {
603            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
604                pairs
605                    .iter()
606                    .find_map(|(a, b)| match a.cmp(b) {
607                        CmpOrdering::Less => Some(true),     // First difference is <
608                        CmpOrdering::Greater => Some(false), // First difference is >
609                        CmpOrdering::Equal => None,          // No difference
610                    })
611                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
612            })?;
613            Some(lt.into())
614        }
615        Expr::LexLeq(_, a, b) => {
616            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
617                pairs
618                    .iter()
619                    .find_map(|(a, b)| match a.cmp(b) {
620                        CmpOrdering::Less => Some(true),
621                        CmpOrdering::Greater => Some(false),
622                        CmpOrdering::Equal => None,
623                    })
624                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
625            })?;
626            Some(lt.into())
627        }
628        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
629        Expr::LexGeq(_, a, b) => {
630            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
631        }
632        Expr::FlatLexLt(_, a, b) => {
633            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
634                pairs
635                    .iter()
636                    .find_map(|(a, b)| match a.cmp(b) {
637                        CmpOrdering::Less => Some(true),
638                        CmpOrdering::Greater => Some(false),
639                        CmpOrdering::Equal => None,
640                    })
641                    .unwrap_or(a_len < b_len)
642            })?;
643            Some(lt.into())
644        }
645        Expr::FlatLexLeq(_, a, b) => {
646            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
647                pairs
648                    .iter()
649                    .find_map(|(a, b)| match a.cmp(b) {
650                        CmpOrdering::Less => Some(true),
651                        CmpOrdering::Greater => Some(false),
652                        CmpOrdering::Equal => None,
653                    })
654                    .unwrap_or(a_len <= b_len)
655            })?;
656            Some(lt.into())
657        }
658    }
659}
660
661pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
662where
663    T: TryFrom<Lit>,
664{
665    let a = unwrap_expr::<T>(a)?;
666    Some(f(a))
667}
668
669pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
670where
671    T: TryFrom<Lit>,
672{
673    let a = unwrap_expr::<T>(a)?;
674    let b = unwrap_expr::<T>(b)?;
675    Some(f(a, b))
676}
677
678#[allow(dead_code)]
679pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
680where
681    T: TryFrom<Lit>,
682{
683    let a = unwrap_expr::<T>(a)?;
684    let b = unwrap_expr::<T>(b)?;
685    let c = unwrap_expr::<T>(c)?;
686    Some(f(a, b, c))
687}
688
689pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
690where
691    T: TryFrom<Lit>,
692{
693    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
694    Some(f(a))
695}
696
697pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
698where
699    T: TryFrom<Lit>,
700{
701    // we don't care about preserving indices here, as we will be getting rid of the vector
702    // anyways!
703    let a = a.clone().unwrap_matrix_unchecked()?.0;
704    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
705    Some(f(a))
706}
707
708type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
709
710/// Calls the given function on each consecutive pair of elements in the list expressions.
711/// Also passes the length of the two lists.
712fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
713where
714    T: TryFrom<Lit>,
715{
716    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
717    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
718    let lens = (a_exprs.len(), b_exprs.len());
719
720    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
721        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
722        .collect::<Option<Vec<(T, T)>>>()?;
723    Some(f(lit_pairs, lens))
724}
725
726/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
727fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
728where
729    T: TryFrom<Atom>,
730{
731    let lit_pairs = Iterator::zip(a.iter(), b.iter())
732        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
733        .collect::<Option<Vec<(T, T)>>>()?;
734    Some(f(lit_pairs, (a.len(), b.len())))
735}
736
737pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
738where
739    T: TryFrom<Lit>,
740{
741    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
742    f(a)
743}
744
745pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
746where
747    T: TryFrom<Lit>,
748{
749    let a = a.clone().unwrap_list()?;
750    // FIXME: deal with explicit matrix domains
751    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
752    f(a)
753}
754
755#[allow(dead_code)]
756pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
757where
758    T: TryFrom<Lit>,
759{
760    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
761    let b = unwrap_expr::<T>(b)?;
762    Some(f(a, b))
763}
764
765pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
766    let c = eval_constant(expr)?;
767    TryInto::<T>::try_into(c).ok()
768}