Skip to main content

conjure_cp_core/ast/
eval.rs

1#![allow(dead_code)]
2use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3use crate::into_matrix;
4use itertools::{Itertools as _, izip};
5use std::cmp::Ordering as CmpOrdering;
6use std::collections::HashSet;
7
8/// Simplify an expression to a constant if possible
9/// Returns:
10/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11/// `Some(Const)` if the expression can be simplified to a constant
12pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13    match expr {
14        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15            (
16                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18            ) => {
19                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21
22                if a_set.difference(&b_set).count() > 0 {
23                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24                } else {
25                    Some(Lit::Bool(false))
26                }
27            }
28            _ => None,
29        },
30        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31            (
32                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34            ) => Some(Lit::Bool(
35                a.iter()
36                    .cloned()
37                    .collect::<HashSet<Lit>>()
38                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39            )),
40            _ => None,
41        },
42        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43            (
44                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46            ) => {
47                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49
50                if b_set.difference(&a_set).count() > 0 {
51                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52                } else {
53                    Some(Lit::Bool(false))
54                }
55            }
56            _ => None,
57        },
58        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59            (
60                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62            ) => Some(Lit::Bool(
63                a.iter()
64                    .cloned()
65                    .collect::<HashSet<Lit>>()
66                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67            )),
68            _ => None,
69        },
70        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71            (
72                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74            ) => {
75                let mut res: Vec<Lit> = Vec::new();
76                for lit in a.iter() {
77                    if b.contains(lit) && !res.contains(lit) {
78                        res.push(lit.clone());
79                    }
80                }
81                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82            }
83            _ => None,
84        },
85        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86            (
87                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89            ) => {
90                let mut res: Vec<Lit> = Vec::new();
91                for lit in a.iter() {
92                    res.push(lit.clone());
93                }
94                for lit in b.iter() {
95                    if !res.contains(lit) {
96                        res.push(lit.clone());
97                    }
98                }
99                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100            }
101            _ => None,
102        },
103        Expr::In(_, a, b) => {
104            if let (
105                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107            ) = (a.as_ref(), b.as_ref())
108            {
109                for lit in d.iter() {
110                    if let Lit::Int(x) = lit
111                        && c == x
112                    {
113                        return Some(Lit::Bool(true));
114                    }
115                }
116                Some(Lit::Bool(false))
117            } else {
118                None
119            }
120        }
121        Expr::FromSolution(_, _) => None,
122        Expr::DominanceRelation(_, _) => None,
123        Expr::InDomain(_, e, domain) => {
124            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125                return None;
126            };
127
128            domain.contains(lit).ok().map(Into::into)
129        }
130        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
132        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
133        Expr::Comprehension(_, _) => None,
134        Expr::AbstractComprehension(_, _) => None,
135        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
136            let subject: Lit = eval_constant(subject.as_ref())?;
137            let indices: Vec<Lit> = indices
138                .iter()
139                .map(eval_constant)
140                .collect::<Option<Vec<Lit>>>()?;
141
142            match subject {
143                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
144                    matrix::flatten_enumerate(subject)
145                        .find(|(i, _)| i == &indices)
146                        .map(|(_, x)| x)
147                }
148                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
149                    let AbstractLiteral::Tuple(elems) = subject else {
150                        return None;
151                    };
152
153                    assert!(indices.len() == 1, "nested tuples not supported yet");
154
155                    let Lit::Int(index) = indices[0].clone() else {
156                        return None;
157                    };
158
159                    if elems.len() < index as usize || index < 1 {
160                        return None;
161                    }
162
163                    // -1 for 0-indexing vs 1-indexing
164                    let item = elems[index as usize - 1].clone();
165
166                    Some(item)
167                }
168                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
169                    let AbstractLiteral::Record(elems) = subject else {
170                        return None;
171                    };
172
173                    assert!(indices.len() == 1, "nested record not supported yet");
174
175                    let Lit::Int(index) = indices[0].clone() else {
176                        return None;
177                    };
178
179                    if elems.len() < index as usize || index < 1 {
180                        return None;
181                    }
182
183                    // -1 for 0-indexing vs 1-indexing
184                    let item = elems[index as usize - 1].clone();
185                    Some(item.value)
186                }
187                _ => None,
188            }
189        }
190        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
191            let subject: Lit = eval_constant(subject.as_ref())?;
192            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
193                return None;
194            };
195
196            let hole_dim = indices
197                .iter()
198                .cloned()
199                .position(|x| x.is_none())
200                .expect("slice expression should have a hole dimension");
201
202            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
203
204            let indices: Vec<Option<Lit>> = indices
205                .iter()
206                .cloned()
207                .map(|x| {
208                    // the outer option represents success of this iterator, the inner the index
209                    // slice.
210                    match x {
211                        Some(x) => eval_constant(&x).map(Some),
212                        None => Some(None),
213                    }
214                })
215                .collect::<Option<Vec<Option<Lit>>>>()?;
216
217            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
218                .values()
219                .ok()?
220                .map(|i| {
221                    let mut indices = indices.clone();
222                    indices[hole_dim] = Some(i);
223                    // These unwraps will only fail if we have multiple holes.
224                    // As this is invalid, panicking is fine.
225                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
226                })
227                .collect_vec();
228
229            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
230            let elems = matrix::flatten_enumerate(subject)
231                .filter(|(i, _)| indices_in_slice.contains(i))
232                .map(|(_, elem)| elem)
233                .collect();
234
235            Some(Lit::AbstractLiteral(into_matrix![elems]))
236        }
237        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
238        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
239            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
240            .map(Lit::Bool),
241        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
242        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
243        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
244        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
245        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
246        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
247        Expr::And(_, e) => {
248            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
249        }
250        Expr::Root(_, _) => None,
251        Expr::Or(_, es) => {
252            // possibly cheating; definitely should be in partial eval instead
253            for e in (**es).clone().unwrap_list()? {
254                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
255                    return Some(Lit::Bool(true));
256                };
257            }
258
259            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
260        }
261        Expr::Imply(_, box1, box2) => {
262            let a: &Atom = (&**box1).try_into().ok()?;
263            let b: &Atom = (&**box2).try_into().ok()?;
264
265            let a: bool = a.try_into().ok()?;
266            let b: bool = b.try_into().ok()?;
267
268            if a {
269                // true -> b ~> b
270                Some(Lit::Bool(b))
271            } else {
272                // false -> b ~> true
273                Some(Lit::Bool(true))
274            }
275        }
276        Expr::Iff(_, box1, box2) => {
277            let a: &Atom = (&**box1).try_into().ok()?;
278            let b: &Atom = (&**box2).try_into().ok()?;
279
280            let a: bool = a.try_into().ok()?;
281            let b: bool = b.try_into().ok()?;
282
283            Some(Lit::Bool(a == b))
284        }
285        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
286        Expr::Product(_, exprs) => {
287            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
288        }
289        Expr::FlatIneq(_, a, b, c) => {
290            let a: i32 = a.try_into().ok()?;
291            let b: i32 = b.try_into().ok()?;
292            let c: i32 = c.try_into().ok()?;
293
294            Some(Lit::Bool(a <= b + c))
295        }
296        Expr::FlatSumGeq(_, exprs, a) => {
297            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
298                let n: i32 = atom.try_into().ok()?;
299                let acc = acc + n;
300                Some(acc)
301            })?;
302
303            Some(Lit::Bool(sum >= a.try_into().ok()?))
304        }
305        Expr::FlatSumLeq(_, exprs, a) => {
306            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
307                let n: i32 = atom.try_into().ok()?;
308                let acc = acc + n;
309                Some(acc)
310            })?;
311
312            Some(Lit::Bool(sum >= a.try_into().ok()?))
313        }
314        Expr::Min(_, e) => {
315            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
316        }
317        Expr::Max(_, e) => {
318            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
319        }
320        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
321            if unwrap_expr::<i32>(b)? == 0 {
322                return None;
323            }
324            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
325        }
326        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
327            if unwrap_expr::<i32>(b)? == 0 {
328                return None;
329            }
330            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
331                .map(Lit::Int)
332        }
333        Expr::MinionDivEqUndefZero(_, a, b, c) => {
334            // div always rounds down
335            let a: i32 = a.try_into().ok()?;
336            let b: i32 = b.try_into().ok()?;
337            let c: i32 = c.try_into().ok()?;
338
339            if b == 0 {
340                return None;
341            }
342
343            let a = a as f32;
344            let b = b as f32;
345            let div: i32 = (a / b).floor() as i32;
346            Some(Lit::Bool(div == c))
347        }
348        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
349        Expr::MinionReify(_, a, b) => {
350            let result = eval_constant(a)?;
351
352            let result: bool = result.try_into().ok()?;
353            let b: bool = b.try_into().ok()?;
354
355            Some(Lit::Bool(b == result))
356        }
357        Expr::MinionReifyImply(_, a, b) => {
358            let result = eval_constant(a)?;
359
360            let result: bool = result.try_into().ok()?;
361            let b: bool = b.try_into().ok()?;
362
363            if b {
364                Some(Lit::Bool(result))
365            } else {
366                Some(Lit::Bool(true))
367            }
368        }
369        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
370            // From Savile Row. Same semantics as division.
371            //
372            //   a - (b * floor(a/b))
373            //
374            // We don't use % as it has the same semantics as /. We don't use / as we want to round
375            // down instead, not towards zero.
376
377            let a: i32 = a.try_into().ok()?;
378            let b: i32 = b.try_into().ok()?;
379            let c: i32 = c.try_into().ok()?;
380
381            if b == 0 {
382                return None;
383            }
384
385            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
386            Some(Lit::Bool(modulo == c))
387        }
388        Expr::MinionPow(_, a, b, c) => {
389            // only available for positive a b c
390
391            let a: i32 = a.try_into().ok()?;
392            let b: i32 = b.try_into().ok()?;
393            let c: i32 = c.try_into().ok()?;
394
395            if a <= 0 {
396                return None;
397            }
398
399            if b <= 0 {
400                return None;
401            }
402
403            if c <= 0 {
404                return None;
405            }
406
407            Some(Lit::Bool(a ^ b == c))
408        }
409        Expr::MinionWInSet(_, _, _) => None,
410        Expr::MinionWInIntervalSet(_, x, intervals) => {
411            let x_lit: &Lit = x.try_into().ok()?;
412
413            let x_lit = match x_lit.clone() {
414                Lit::Int(i) => Some(i),
415                Lit::Bool(true) => Some(1),
416                Lit::Bool(false) => Some(0),
417                _ => None,
418            }?;
419
420            let mut intervals = intervals.iter();
421            loop {
422                let Some(lower) = intervals.next() else {
423                    break;
424                };
425
426                let Some(upper) = intervals.next() else {
427                    break;
428                };
429                if &x_lit >= lower && &x_lit <= upper {
430                    return Some(Lit::Bool(true));
431                }
432            }
433
434            Some(Lit::Bool(false))
435        }
436        Expr::Flatten(_, _, _) => {
437            // TODO
438            None
439        }
440        Expr::AllDiff(_, e) => {
441            let es = (**e).clone().unwrap_list()?;
442            let mut lits: HashSet<Lit> = HashSet::new();
443            for expr in es {
444                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
445                    return None;
446                };
447                match x {
448                    Lit::Int(_) | Lit::Bool(_) => {
449                        if lits.contains(&x) {
450                            return Some(Lit::Bool(false));
451                        } else {
452                            lits.insert(x.clone());
453                        }
454                    }
455                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
456                }
457            }
458            Some(Lit::Bool(true))
459        }
460        Expr::FlatAllDiff(_, es) => {
461            let mut lits: HashSet<Lit> = HashSet::new();
462            for atom in es {
463                let Atom::Literal(x) = atom else {
464                    return None;
465                };
466
467                match x {
468                    Lit::Int(_) | Lit::Bool(_) => {
469                        if lits.contains(x) {
470                            return Some(Lit::Bool(false));
471                        } else {
472                            lits.insert(x.clone());
473                        }
474                    }
475                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
476                }
477            }
478            Some(Lit::Bool(true))
479        }
480        Expr::FlatWatchedLiteral(_, _, _) => None,
481        Expr::AuxDeclaration(_, _, _) => None,
482        Expr::Neg(_, a) => {
483            let a: &Atom = a.try_into().ok()?;
484            let a: i32 = a.try_into().ok()?;
485            Some(Lit::Int(-a))
486        }
487        Expr::Minus(_, a, b) => {
488            let a: &Atom = a.try_into().ok()?;
489            let a: i32 = a.try_into().ok()?;
490
491            let b: &Atom = b.try_into().ok()?;
492            let b: i32 = b.try_into().ok()?;
493
494            Some(Lit::Int(a - b))
495        }
496        Expr::FlatMinusEq(_, a, b) => {
497            let a: i32 = a.try_into().ok()?;
498            let b: i32 = b.try_into().ok()?;
499            Some(Lit::Bool(a == -b))
500        }
501        Expr::FlatProductEq(_, a, b, c) => {
502            let a: i32 = a.try_into().ok()?;
503            let b: i32 = b.try_into().ok()?;
504            let c: i32 = c.try_into().ok()?;
505            Some(Lit::Bool(a * b == c))
506        }
507        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
508            let cs: Vec<i32> = cs
509                .iter()
510                .map(|x| TryInto::<i32>::try_into(x).ok())
511                .collect::<Option<Vec<i32>>>()?;
512            let vs: Vec<i32> = vs
513                .iter()
514                .map(|x| TryInto::<i32>::try_into(x).ok())
515                .collect::<Option<Vec<i32>>>()?;
516            let total: i32 = total.try_into().ok()?;
517
518            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
519
520            Some(Lit::Bool(sum <= total))
521        }
522        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
523            let cs: Vec<i32> = cs
524                .iter()
525                .map(|x| TryInto::<i32>::try_into(x).ok())
526                .collect::<Option<Vec<i32>>>()?;
527            let vs: Vec<i32> = vs
528                .iter()
529                .map(|x| TryInto::<i32>::try_into(x).ok())
530                .collect::<Option<Vec<i32>>>()?;
531            let total: i32 = total.try_into().ok()?;
532
533            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
534
535            Some(Lit::Bool(sum >= total))
536        }
537        Expr::FlatAbsEq(_, x, y) => {
538            let x: i32 = x.try_into().ok()?;
539            let y: i32 = y.try_into().ok()?;
540
541            Some(Lit::Bool(x == y.abs()))
542        }
543        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
544            let a: &Atom = a.try_into().ok()?;
545            let a: i32 = a.try_into().ok()?;
546
547            let b: &Atom = b.try_into().ok()?;
548            let b: i32 = b.try_into().ok()?;
549
550            if (a != 0 || b != 0) && b >= 0 {
551                Some(Lit::Int(a.pow(b as u32)))
552            } else {
553                None
554            }
555        }
556        Expr::Scope(_, _) => None,
557        Expr::Metavar(_, _) => None,
558        Expr::MinionElementOne(_, _, _, _) => None,
559        Expr::ToInt(_, expression) => {
560            let lit = eval_constant(expression.as_ref())?;
561            match lit {
562                Lit::Int(_) => Some(lit),
563                Lit::Bool(true) => Some(Lit::Int(1)),
564                Lit::Bool(false) => Some(Lit::Int(0)),
565                _ => None,
566            }
567        }
568        Expr::SATInt(..) => None,
569        Expr::PairwiseSum(_, a, b) => {
570            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
571                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
572                _ => None,
573            }
574        }
575        Expr::PairwiseProduct(_, a, b) => {
576            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
577                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
578                _ => None,
579            }
580        }
581        Expr::Defined(_, _) => todo!(),
582        Expr::Range(_, _) => todo!(),
583        Expr::Image(_, _, _) => todo!(),
584        Expr::ImageSet(_, _, _) => todo!(),
585        Expr::PreImage(_, _, _) => todo!(),
586        Expr::Inverse(_, _, _) => todo!(),
587        Expr::Restrict(_, _, _) => todo!(),
588        Expr::LexLt(_, a, b) => {
589            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
590                pairs
591                    .iter()
592                    .find_map(|(a, b)| match a.cmp(b) {
593                        CmpOrdering::Less => Some(true),     // First difference is <
594                        CmpOrdering::Greater => Some(false), // First difference is >
595                        CmpOrdering::Equal => None,          // No difference
596                    })
597                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
598            })?;
599            Some(lt.into())
600        }
601        Expr::LexLeq(_, a, b) => {
602            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
603                pairs
604                    .iter()
605                    .find_map(|(a, b)| match a.cmp(b) {
606                        CmpOrdering::Less => Some(true),
607                        CmpOrdering::Greater => Some(false),
608                        CmpOrdering::Equal => None,
609                    })
610                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
611            })?;
612            Some(lt.into())
613        }
614        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
615        Expr::LexGeq(_, a, b) => {
616            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
617        }
618        Expr::FlatLexLt(_, a, b) => {
619            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
620                pairs
621                    .iter()
622                    .find_map(|(a, b)| match a.cmp(b) {
623                        CmpOrdering::Less => Some(true),
624                        CmpOrdering::Greater => Some(false),
625                        CmpOrdering::Equal => None,
626                    })
627                    .unwrap_or(a_len < b_len)
628            })?;
629            Some(lt.into())
630        }
631        Expr::FlatLexLeq(_, a, b) => {
632            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
633                pairs
634                    .iter()
635                    .find_map(|(a, b)| match a.cmp(b) {
636                        CmpOrdering::Less => Some(true),
637                        CmpOrdering::Greater => Some(false),
638                        CmpOrdering::Equal => None,
639                    })
640                    .unwrap_or(a_len <= b_len)
641            })?;
642            Some(lt.into())
643        }
644    }
645}
646
647pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
648where
649    T: TryFrom<Lit>,
650{
651    let a = unwrap_expr::<T>(a)?;
652    Some(f(a))
653}
654
655pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
656where
657    T: TryFrom<Lit>,
658{
659    let a = unwrap_expr::<T>(a)?;
660    let b = unwrap_expr::<T>(b)?;
661    Some(f(a, b))
662}
663
664#[allow(dead_code)]
665pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
666where
667    T: TryFrom<Lit>,
668{
669    let a = unwrap_expr::<T>(a)?;
670    let b = unwrap_expr::<T>(b)?;
671    let c = unwrap_expr::<T>(c)?;
672    Some(f(a, b, c))
673}
674
675pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
676where
677    T: TryFrom<Lit>,
678{
679    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
680    Some(f(a))
681}
682
683pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
684where
685    T: TryFrom<Lit>,
686{
687    // we don't care about preserving indices here, as we will be getting rid of the vector
688    // anyways!
689    let a = a.clone().unwrap_matrix_unchecked()?.0;
690    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
691    Some(f(a))
692}
693
694type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
695
696/// Calls the given function on each consecutive pair of elements in the list expressions.
697/// Also passes the length of the two lists.
698fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
699where
700    T: TryFrom<Lit>,
701{
702    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
703    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
704    let lens = (a_exprs.len(), b_exprs.len());
705
706    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
707        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
708        .collect::<Option<Vec<(T, T)>>>()?;
709    Some(f(lit_pairs, lens))
710}
711
712/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
713fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
714where
715    T: TryFrom<Atom>,
716{
717    let lit_pairs = Iterator::zip(a.iter(), b.iter())
718        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
719        .collect::<Option<Vec<(T, T)>>>()?;
720    Some(f(lit_pairs, (a.len(), b.len())))
721}
722
723pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
724where
725    T: TryFrom<Lit>,
726{
727    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
728    f(a)
729}
730
731pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
732where
733    T: TryFrom<Lit>,
734{
735    let a = a.clone().unwrap_list()?;
736    // FIXME: deal with explicit matrix domains
737    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
738    f(a)
739}
740
741#[allow(dead_code)]
742pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
743where
744    T: TryFrom<Lit>,
745{
746    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
747    let b = unwrap_expr::<T>(b)?;
748    Some(f(a, b))
749}
750
751pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
752    let c = eval_constant(expr)?;
753    TryInto::<T>::try_into(c).ok()
754}