Skip to main content

conjure_cp_core/ast/
eval.rs

1#![allow(dead_code)]
2use crate::ast::{
3    AbstractLiteral, Atom, DeclarationKind, Expression as Expr, Literal as Lit, Metadata,
4    comprehension::{Comprehension, ComprehensionQualifier},
5    matrix,
6};
7use crate::into_matrix;
8use itertools::{Itertools as _, izip};
9use std::cmp::Ordering as CmpOrdering;
10use std::collections::HashSet;
11
12/// Simplify an expression to a constant if possible
13/// Returns:
14/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
15/// `Some(Const)` if the expression can be simplified to a constant
16pub fn eval_constant(expr: &Expr) -> Option<Lit> {
17    match expr {
18        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
19            (
20                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
21                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
22            ) => {
23                let a_set: HashSet<Lit> = a.iter().cloned().collect();
24                let b_set: HashSet<Lit> = b.iter().cloned().collect();
25
26                if a_set.difference(&b_set).count() > 0 {
27                    Some(Lit::Bool(a_set.is_superset(&b_set)))
28                } else {
29                    Some(Lit::Bool(false))
30                }
31            }
32            _ => None,
33        },
34        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
35            (
36                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
37                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
38            ) => Some(Lit::Bool(
39                a.iter()
40                    .cloned()
41                    .collect::<HashSet<Lit>>()
42                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
43            )),
44            _ => None,
45        },
46        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
47            (
48                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
49                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
50            ) => {
51                let a_set: HashSet<Lit> = a.iter().cloned().collect();
52                let b_set: HashSet<Lit> = b.iter().cloned().collect();
53
54                if b_set.difference(&a_set).count() > 0 {
55                    Some(Lit::Bool(a_set.is_subset(&b_set)))
56                } else {
57                    Some(Lit::Bool(false))
58                }
59            }
60            _ => None,
61        },
62        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
63            (
64                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
65                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
66            ) => Some(Lit::Bool(
67                a.iter()
68                    .cloned()
69                    .collect::<HashSet<Lit>>()
70                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
71            )),
72            _ => None,
73        },
74        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
75            (
76                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
77                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
78            ) => {
79                let mut res: Vec<Lit> = Vec::new();
80                for lit in a.iter() {
81                    if b.contains(lit) && !res.contains(lit) {
82                        res.push(lit.clone());
83                    }
84                }
85                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
86            }
87            _ => None,
88        },
89        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
90            (
91                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
92                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
93            ) => {
94                let mut res: Vec<Lit> = Vec::new();
95                for lit in a.iter() {
96                    res.push(lit.clone());
97                }
98                for lit in b.iter() {
99                    if !res.contains(lit) {
100                        res.push(lit.clone());
101                    }
102                }
103                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
104            }
105            _ => None,
106        },
107        Expr::In(_, a, b) => {
108            if let (
109                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
110                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
111            ) = (a.as_ref(), b.as_ref())
112            {
113                for lit in d.iter() {
114                    if let Lit::Int(x) = lit
115                        && c == x
116                    {
117                        return Some(Lit::Bool(true));
118                    }
119                }
120                Some(Lit::Bool(false))
121            } else {
122                None
123            }
124        }
125        Expr::FromSolution(_, _) => None,
126        Expr::DominanceRelation(_, _) => None,
127        Expr::InDomain(_, e, domain) => {
128            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
129                return None;
130            };
131
132            domain.contains(lit).ok().map(Into::into)
133        }
134        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
135        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
136        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
137        Expr::Comprehension(_, comprehension) => {
138            eval_constant_comprehension(comprehension.as_ref())
139        }
140        Expr::AbstractComprehension(_, _) => None,
141        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
142            let subject: Lit = eval_constant(subject.as_ref())?;
143            let indices: Vec<Lit> = indices
144                .iter()
145                .map(eval_constant)
146                .collect::<Option<Vec<Lit>>>()?;
147
148            match subject {
149                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
150                    matrix::flatten_enumerate(subject)
151                        .find(|(i, _)| i == &indices)
152                        .map(|(_, x)| x)
153                }
154                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
155                    let AbstractLiteral::Tuple(elems) = subject else {
156                        return None;
157                    };
158
159                    assert!(indices.len() == 1, "nested tuples not supported yet");
160
161                    let Lit::Int(index) = indices[0].clone() else {
162                        return None;
163                    };
164
165                    if elems.len() < index as usize || index < 1 {
166                        return None;
167                    }
168
169                    // -1 for 0-indexing vs 1-indexing
170                    let item = elems[index as usize - 1].clone();
171
172                    Some(item)
173                }
174                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
175                    let AbstractLiteral::Record(elems) = subject else {
176                        return None;
177                    };
178
179                    assert!(indices.len() == 1, "nested record not supported yet");
180
181                    let Lit::Int(index) = indices[0].clone() else {
182                        return None;
183                    };
184
185                    if elems.len() < index as usize || index < 1 {
186                        return None;
187                    }
188
189                    // -1 for 0-indexing vs 1-indexing
190                    let item = elems[index as usize - 1].clone();
191                    Some(item.value)
192                }
193                _ => None,
194            }
195        }
196        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
197            let subject: Lit = eval_constant(subject.as_ref())?;
198            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
199                return None;
200            };
201
202            let hole_dim = indices
203                .iter()
204                .cloned()
205                .position(|x| x.is_none())
206                .expect("slice expression should have a hole dimension");
207
208            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
209
210            let indices: Vec<Option<Lit>> = indices
211                .iter()
212                .cloned()
213                .map(|x| {
214                    // the outer option represents success of this iterator, the inner the index
215                    // slice.
216                    match x {
217                        Some(x) => eval_constant(&x).map(Some),
218                        None => Some(None),
219                    }
220                })
221                .collect::<Option<Vec<Option<Lit>>>>()?;
222
223            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
224                .values()
225                .ok()?
226                .map(|i| {
227                    let mut indices = indices.clone();
228                    indices[hole_dim] = Some(i);
229                    // These unwraps will only fail if we have multiple holes.
230                    // As this is invalid, panicking is fine.
231                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
232                })
233                .collect_vec();
234
235            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
236            let elems = matrix::flatten_enumerate(subject)
237                .filter(|(i, _)| indices_in_slice.contains(i))
238                .map(|(_, elem)| elem)
239                .collect();
240
241            Some(Lit::AbstractLiteral(into_matrix![elems]))
242        }
243        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
244        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
245            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
246            .map(Lit::Bool),
247        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
248        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
249        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
250        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
251        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
252        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
253        Expr::And(_, e) => {
254            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
255        }
256        Expr::Table(_, _, _) => None,
257        Expr::NegativeTable(_, _, _) => None,
258        Expr::Root(_, _) => None,
259        Expr::Or(_, es) => {
260            // possibly cheating; definitely should be in partial eval instead
261            for e in (**es).clone().unwrap_list()? {
262                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
263                    return Some(Lit::Bool(true));
264                };
265            }
266
267            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
268        }
269        Expr::Imply(_, box1, box2) => {
270            let a: &Atom = (&**box1).try_into().ok()?;
271            let b: &Atom = (&**box2).try_into().ok()?;
272
273            let a: bool = a.try_into().ok()?;
274            let b: bool = b.try_into().ok()?;
275
276            if a {
277                // true -> b ~> b
278                Some(Lit::Bool(b))
279            } else {
280                // false -> b ~> true
281                Some(Lit::Bool(true))
282            }
283        }
284        Expr::Iff(_, box1, box2) => {
285            let a: &Atom = (&**box1).try_into().ok()?;
286            let b: &Atom = (&**box2).try_into().ok()?;
287
288            let a: bool = a.try_into().ok()?;
289            let b: bool = b.try_into().ok()?;
290
291            Some(Lit::Bool(a == b))
292        }
293        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
294        Expr::Product(_, exprs) => {
295            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
296        }
297        Expr::FlatIneq(_, a, b, c) => {
298            let a: i32 = a.try_into().ok()?;
299            let b: i32 = b.try_into().ok()?;
300            let c: i32 = c.try_into().ok()?;
301
302            Some(Lit::Bool(a <= b + c))
303        }
304        Expr::FlatSumGeq(_, exprs, a) => {
305            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
306                let n: i32 = atom.try_into().ok()?;
307                let acc = acc + n;
308                Some(acc)
309            })?;
310
311            Some(Lit::Bool(sum >= a.try_into().ok()?))
312        }
313        Expr::FlatSumLeq(_, exprs, a) => {
314            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
315                let n: i32 = atom.try_into().ok()?;
316                let acc = acc + n;
317                Some(acc)
318            })?;
319
320            Some(Lit::Bool(sum >= a.try_into().ok()?))
321        }
322        Expr::Min(_, e) => {
323            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
324        }
325        Expr::Max(_, e) => {
326            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
327        }
328        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
329            if unwrap_expr::<i32>(b)? == 0 {
330                return None;
331            }
332            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
333        }
334        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
335            if unwrap_expr::<i32>(b)? == 0 {
336                return None;
337            }
338            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
339                .map(Lit::Int)
340        }
341        Expr::Substring(_, s, t) => match (s.as_ref(), t.as_ref()) {
342            (
343                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(s)))),
344                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(t)))),
345            ) => {
346                if s.len() > t.len() {
347                    return Some(Lit::Bool(false));
348                }
349
350                let found = t.windows(s.len()).any(|window| window == s.as_slice());
351                Some(Lit::Bool(found))
352            }
353            _ => None,
354        },
355        Expr::Subsequence(_, s, t) => match (s.as_ref(), t.as_ref()) {
356            (
357                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(s)))),
358                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(t)))),
359            ) => {
360                let mut i = 0;
361                let mut j = 0;
362
363                while i < s.len() && j < t.len() {
364                    if s[i] == t[j] {
365                        i += 1;
366                    }
367                    j += 1;
368                }
369
370                Some(Lit::Bool(i == s.len()))
371            }
372            _ => None,
373        },
374        Expr::MinionDivEqUndefZero(_, a, b, c) => {
375            // div always rounds down
376            let a: i32 = a.try_into().ok()?;
377            let b: i32 = b.try_into().ok()?;
378            let c: i32 = c.try_into().ok()?;
379
380            if b == 0 {
381                return None;
382            }
383
384            let a = a as f32;
385            let b = b as f32;
386            let div: i32 = (a / b).floor() as i32;
387            Some(Lit::Bool(div == c))
388        }
389        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
390        Expr::MinionReify(_, a, b) => {
391            let result = eval_constant(a)?;
392
393            let result: bool = result.try_into().ok()?;
394            let b: bool = b.try_into().ok()?;
395
396            Some(Lit::Bool(b == result))
397        }
398        Expr::MinionReifyImply(_, a, b) => {
399            let result = eval_constant(a)?;
400
401            let result: bool = result.try_into().ok()?;
402            let b: bool = b.try_into().ok()?;
403
404            if b {
405                Some(Lit::Bool(result))
406            } else {
407                Some(Lit::Bool(true))
408            }
409        }
410        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
411            // From Savile Row. Same semantics as division.
412            //
413            //   a - (b * floor(a/b))
414            //
415            // We don't use % as it has the same semantics as /. We don't use / as we want to round
416            // down instead, not towards zero.
417
418            let a: i32 = a.try_into().ok()?;
419            let b: i32 = b.try_into().ok()?;
420            let c: i32 = c.try_into().ok()?;
421
422            if b == 0 {
423                return None;
424            }
425
426            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
427            Some(Lit::Bool(modulo == c))
428        }
429        Expr::MinionPow(_, a, b, c) => {
430            // only available for positive a b c
431
432            let a: i32 = a.try_into().ok()?;
433            let b: i32 = b.try_into().ok()?;
434            let c: i32 = c.try_into().ok()?;
435
436            if a <= 0 {
437                return None;
438            }
439
440            if b <= 0 {
441                return None;
442            }
443
444            if c <= 0 {
445                return None;
446            }
447
448            Some(Lit::Bool(a ^ b == c))
449        }
450        Expr::MinionWInSet(_, _, _) => None,
451        Expr::MinionWInIntervalSet(_, x, intervals) => {
452            let x_lit: &Lit = x.try_into().ok()?;
453
454            let x_lit = match x_lit.clone() {
455                Lit::Int(i) => Some(i),
456                Lit::Bool(true) => Some(1),
457                Lit::Bool(false) => Some(0),
458                _ => None,
459            }?;
460
461            let mut intervals = intervals.iter();
462            while let Some(lower) = intervals.next() {
463                let Some(upper) = intervals.next() else {
464                    break;
465                };
466                if &x_lit >= lower && &x_lit <= upper {
467                    return Some(Lit::Bool(true));
468                }
469            }
470
471            Some(Lit::Bool(false))
472        }
473        Expr::Flatten(_, _, _) => {
474            // TODO
475            None
476        }
477        Expr::AllDiff(_, e) => {
478            let es = (**e).clone().unwrap_list()?;
479            let mut lits: HashSet<Lit> = HashSet::new();
480            for expr in es {
481                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
482                    return None;
483                };
484                match x {
485                    Lit::Int(_) | Lit::Bool(_) => {
486                        if lits.contains(&x) {
487                            return Some(Lit::Bool(false));
488                        } else {
489                            lits.insert(x.clone());
490                        }
491                    }
492                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
493                }
494            }
495            Some(Lit::Bool(true))
496        }
497        Expr::FlatAllDiff(_, es) => {
498            let mut lits: HashSet<Lit> = HashSet::new();
499            for atom in es {
500                let Atom::Literal(x) = atom else {
501                    return None;
502                };
503
504                match x {
505                    Lit::Int(_) | Lit::Bool(_) => {
506                        if lits.contains(x) {
507                            return Some(Lit::Bool(false));
508                        } else {
509                            lits.insert(x.clone());
510                        }
511                    }
512                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
513                }
514            }
515            Some(Lit::Bool(true))
516        }
517        Expr::FlatWatchedLiteral(_, _, _) => None,
518        Expr::AuxDeclaration(_, _, _) => None,
519        Expr::Neg(_, a) => match eval_constant(a.as_ref())? {
520            Lit::Int(a) => Some(Lit::Int(-a)),
521            _ => None,
522        },
523        Expr::Factorial(_, _) => None,
524        Expr::Minus(_, a, b) => bin_op::<i32, i32>(|a, b| a - b, a, b).map(Lit::Int),
525        Expr::FlatMinusEq(_, a, b) => {
526            let a: i32 = a.try_into().ok()?;
527            let b: i32 = b.try_into().ok()?;
528            Some(Lit::Bool(a == -b))
529        }
530        Expr::FlatProductEq(_, a, b, c) => {
531            let a: i32 = a.try_into().ok()?;
532            let b: i32 = b.try_into().ok()?;
533            let c: i32 = c.try_into().ok()?;
534            Some(Lit::Bool(a * b == c))
535        }
536        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
537            let cs: Vec<i32> = cs
538                .iter()
539                .map(|x| TryInto::<i32>::try_into(x).ok())
540                .collect::<Option<Vec<i32>>>()?;
541            let vs: Vec<i32> = vs
542                .iter()
543                .map(|x| TryInto::<i32>::try_into(x).ok())
544                .collect::<Option<Vec<i32>>>()?;
545            let total: i32 = total.try_into().ok()?;
546
547            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
548
549            Some(Lit::Bool(sum <= total))
550        }
551        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
552            let cs: Vec<i32> = cs
553                .iter()
554                .map(|x| TryInto::<i32>::try_into(x).ok())
555                .collect::<Option<Vec<i32>>>()?;
556            let vs: Vec<i32> = vs
557                .iter()
558                .map(|x| TryInto::<i32>::try_into(x).ok())
559                .collect::<Option<Vec<i32>>>()?;
560            let total: i32 = total.try_into().ok()?;
561
562            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
563
564            Some(Lit::Bool(sum >= total))
565        }
566        Expr::FlatAbsEq(_, x, y) => {
567            let x: i32 = x.try_into().ok()?;
568            let y: i32 = y.try_into().ok()?;
569
570            Some(Lit::Bool(x == y.abs()))
571        }
572        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
573            let a: &Atom = a.try_into().ok()?;
574            let a: i32 = a.try_into().ok()?;
575
576            let b: &Atom = b.try_into().ok()?;
577            let b: i32 = b.try_into().ok()?;
578
579            if (a != 0 || b != 0) && b >= 0 {
580                Some(Lit::Int(a.pow(b as u32)))
581            } else {
582                None
583            }
584        }
585        Expr::Metavar(_, _) => None,
586        Expr::MinionElementOne(_, _, _, _) => None,
587        Expr::ToInt(_, expression) => {
588            let lit = eval_constant(expression.as_ref())?;
589            match lit {
590                Lit::Int(_) => Some(lit),
591                Lit::Bool(true) => Some(Lit::Int(1)),
592                Lit::Bool(false) => Some(Lit::Int(0)),
593                _ => None,
594            }
595        }
596        Expr::SATInt(_, _, _, _) => {
597            // TODO: If this SATInt is composed of literals, we should evaluate it back to an
598            // integer literal.
599            //
600            // This is important because `is_all_constant` currently returns true for SATInts
601            // containing no references. If we don't evaluate them here, bubble rules will skip
602            // them (thinking they'll be constant-folded later), but they'll actually reach
603            // the solver adaptors as un-encoded unsafe operations, causing panics.
604            None
605        }
606        Expr::PairwiseSum(_, a, b) => {
607            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
608                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
609                _ => None,
610            }
611        }
612        Expr::PairwiseProduct(_, a, b) => {
613            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
614                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
615                _ => None,
616            }
617        }
618        Expr::Defined(_, _) => todo!(),
619        Expr::Range(_, _) => todo!(),
620        Expr::Image(_, _, _) => todo!(),
621        Expr::ImageSet(_, _, _) => todo!(),
622        Expr::PreImage(_, _, _) => todo!(),
623        Expr::Inverse(_, _, _) => todo!(),
624        Expr::Restrict(_, _, _) => todo!(),
625        Expr::Active(_, _, _) => todo!(),
626        Expr::ToSet(_, _) => todo!(),
627        Expr::ToMSet(_, _) => todo!(),
628        Expr::ToRelation(_, _) => todo!(),
629        Expr::RelationProj(_, _, _) => todo!(),
630        Expr::Apart(_, _, _) => todo!(),
631        Expr::Together(_, _, _) => todo!(),
632        Expr::Participants(_, _) => todo!(),
633        Expr::Party(_, _, _) => todo!(),
634        Expr::Parts(_, _) => todo!(),
635        Expr::Card(_, _) => todo!(),
636        Expr::LexLt(_, a, b) => {
637            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
638                pairs
639                    .iter()
640                    .find_map(|(a, b)| match a.cmp(b) {
641                        CmpOrdering::Less => Some(true),     // First difference is <
642                        CmpOrdering::Greater => Some(false), // First difference is >
643                        CmpOrdering::Equal => None,          // No difference
644                    })
645                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
646            })?;
647            Some(lt.into())
648        }
649        Expr::LexLeq(_, a, b) => {
650            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
651                pairs
652                    .iter()
653                    .find_map(|(a, b)| match a.cmp(b) {
654                        CmpOrdering::Less => Some(true),
655                        CmpOrdering::Greater => Some(false),
656                        CmpOrdering::Equal => None,
657                    })
658                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
659            })?;
660            Some(lt.into())
661        }
662        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
663        Expr::LexGeq(_, a, b) => {
664            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
665        }
666        Expr::FlatLexLt(_, a, b) => {
667            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
668                pairs
669                    .iter()
670                    .find_map(|(a, b)| match a.cmp(b) {
671                        CmpOrdering::Less => Some(true),
672                        CmpOrdering::Greater => Some(false),
673                        CmpOrdering::Equal => None,
674                    })
675                    .unwrap_or(a_len < b_len)
676            })?;
677            Some(lt.into())
678        }
679        Expr::FlatLexLeq(_, a, b) => {
680            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
681                pairs
682                    .iter()
683                    .find_map(|(a, b)| match a.cmp(b) {
684                        CmpOrdering::Less => Some(true),
685                        CmpOrdering::Greater => Some(false),
686                        CmpOrdering::Equal => None,
687                    })
688                    .unwrap_or(a_len <= b_len)
689            })?;
690            Some(lt.into())
691        }
692    }
693}
694
695pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
696where
697    T: TryFrom<Lit>,
698{
699    let a = unwrap_expr::<T>(a)?;
700    Some(f(a))
701}
702
703pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
704where
705    T: TryFrom<Lit>,
706{
707    let a = unwrap_expr::<T>(a)?;
708    let b = unwrap_expr::<T>(b)?;
709    Some(f(a, b))
710}
711
712#[allow(dead_code)]
713pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
714where
715    T: TryFrom<Lit>,
716{
717    let a = unwrap_expr::<T>(a)?;
718    let b = unwrap_expr::<T>(b)?;
719    let c = unwrap_expr::<T>(c)?;
720    Some(f(a, b, c))
721}
722
723pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
724where
725    T: TryFrom<Lit>,
726{
727    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
728    Some(f(a))
729}
730
731pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
732where
733    T: TryFrom<Lit>,
734{
735    Some(f(eval_list_items(a)?))
736}
737
738type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
739
740/// Calls the given function on each consecutive pair of elements in the list expressions.
741/// Also passes the length of the two lists.
742fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
743where
744    T: TryFrom<Lit>,
745{
746    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
747    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
748    let lens = (a_exprs.len(), b_exprs.len());
749
750    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
751        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
752        .collect::<Option<Vec<(T, T)>>>()?;
753    Some(f(lit_pairs, lens))
754}
755
756/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
757fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
758where
759    T: TryFrom<Atom>,
760{
761    let lit_pairs = Iterator::zip(a.iter(), b.iter())
762        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
763        .collect::<Option<Vec<(T, T)>>>()?;
764    Some(f(lit_pairs, (a.len(), b.len())))
765}
766
767pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
768where
769    T: TryFrom<Lit>,
770{
771    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
772    f(a)
773}
774
775pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
776where
777    T: TryFrom<Lit>,
778{
779    f(eval_list_items(a)?)
780}
781
782#[allow(dead_code)]
783pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
784where
785    T: TryFrom<Lit>,
786{
787    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
788    let b = unwrap_expr::<T>(b)?;
789    Some(f(a, b))
790}
791
792pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
793    let c = eval_constant(expr)?;
794    TryInto::<T>::try_into(c).ok()
795}
796
797fn eval_list_items<T>(expr: &Expr) -> Option<Vec<T>>
798where
799    T: TryFrom<Lit>,
800{
801    if let Some(items) = expr
802        .clone()
803        .unwrap_matrix_unchecked()
804        .map(|(items, _)| items)
805    {
806        return items.iter().map(unwrap_expr).collect();
807    }
808
809    let Lit::AbstractLiteral(list) = eval_constant(expr)? else {
810        return None;
811    };
812
813    let items = list.unwrap_list()?;
814    items
815        .iter()
816        .cloned()
817        .map(TryInto::try_into)
818        .collect::<Result<Vec<_>, _>>()
819        .ok()
820}
821
822fn eval_constant_comprehension(comprehension: &Comprehension) -> Option<Lit> {
823    let mut values = Vec::new();
824    eval_comprehension_qualifiers(comprehension, 0, &mut values)?;
825    Some(Lit::AbstractLiteral(
826        AbstractLiteral::matrix_implied_indices(values),
827    ))
828}
829
830fn eval_comprehension_qualifiers(
831    comprehension: &Comprehension,
832    qualifier_index: usize,
833    values: &mut Vec<Lit>,
834) -> Option<()> {
835    if qualifier_index == comprehension.qualifiers.len() {
836        values.push(eval_constant(&comprehension.return_expression)?);
837        return Some(());
838    }
839
840    match &comprehension.qualifiers[qualifier_index] {
841        ComprehensionQualifier::Generator { ptr } => {
842            let domain = ptr.domain()?;
843            let generator_values = domain.resolve()?.values().ok()?.collect_vec();
844
845            for value in generator_values {
846                with_temporary_quantified_binding(ptr, &value, || {
847                    eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)
848                })?;
849            }
850        }
851        ComprehensionQualifier::ExpressionGenerator { ptr } => {
852            // clone immediately so the read lock guard is dropped
853            let expr = ptr.as_quantified_expr()?.clone();
854            let generator_values = generator_values_from_expr(&expr)?;
855
856            for value in generator_values {
857                with_temporary_quantified_binding(ptr, &value, || {
858                    eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)
859                })?;
860            }
861        }
862        ComprehensionQualifier::Condition(condition) => match eval_constant(condition)? {
863            Lit::Bool(true) => {
864                eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)?
865            }
866            Lit::Bool(false) => {}
867            _ => return None,
868        },
869    }
870
871    Some(())
872}
873
874fn generator_values_from_expr(expr: &Expr) -> Option<Vec<Lit>> {
875    match eval_constant(expr)? {
876        Lit::AbstractLiteral(AbstractLiteral::Set(values))
877        | Lit::AbstractLiteral(AbstractLiteral::MSet(values))
878        | Lit::AbstractLiteral(AbstractLiteral::Tuple(values)) => Some(values),
879        Lit::AbstractLiteral(list) => list.unwrap_list().cloned(),
880        _ => None,
881    }
882}
883
884fn with_temporary_quantified_binding<T>(
885    quantified: &crate::ast::DeclarationPtr,
886    value: &Lit,
887    f: impl FnOnce() -> Option<T>,
888) -> Option<T> {
889    let mut targets = vec![quantified.clone()];
890    if let DeclarationKind::Quantified(inner) = &*quantified.kind()
891        && let Some(generator) = inner.generator()
892    {
893        targets.push(generator.clone());
894    }
895
896    let mut originals = Vec::with_capacity(targets.len());
897    for mut target in targets {
898        let old_kind = target.replace_kind(DeclarationKind::TemporaryValueLetting(Expr::Atomic(
899            Metadata::new(),
900            Atom::Literal(value.clone()),
901        )));
902        originals.push((target, old_kind));
903    }
904
905    let result = f();
906
907    for (mut target, old_kind) in originals.into_iter().rev() {
908        let _ = target.replace_kind(old_kind);
909    }
910
911    result
912}