1
#![allow(dead_code)]
2
use crate::ast::{
3
    AbstractLiteral, Atom, DeclarationKind, Expression as Expr, Literal as Lit, Metadata,
4
    comprehension::{Comprehension, ComprehensionQualifier},
5
    matrix,
6
};
7
use crate::into_matrix;
8
use itertools::{Itertools as _, izip};
9
use std::cmp::Ordering as CmpOrdering;
10
use std::collections::HashSet;
11

            
12
/// Simplify an expression to a constant if possible
13
/// Returns:
14
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
15
/// `Some(Const)` if the expression can be simplified to a constant
16
56346138
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
17
26598254
    match expr {
18
640
        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
19
            (
20
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
21
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
22
            ) => {
23
320
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
24
320
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
25

            
26
320
                if a_set.difference(&b_set).count() > 0 {
27
240
                    Some(Lit::Bool(a_set.is_superset(&b_set)))
28
                } else {
29
80
                    Some(Lit::Bool(false))
30
                }
31
            }
32
320
            _ => None,
33
        },
34
640
        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
35
            (
36
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
37
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
38
320
            ) => Some(Lit::Bool(
39
320
                a.iter()
40
320
                    .cloned()
41
320
                    .collect::<HashSet<Lit>>()
42
320
                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
43
320
            )),
44
320
            _ => None,
45
        },
46
1200
        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
47
            (
48
400
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
49
400
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
50
            ) => {
51
400
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
52
400
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
53

            
54
400
                if b_set.difference(&a_set).count() > 0 {
55
320
                    Some(Lit::Bool(a_set.is_subset(&b_set)))
56
                } else {
57
80
                    Some(Lit::Bool(false))
58
                }
59
            }
60
800
            _ => None,
61
        },
62
674
        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
63
            (
64
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
65
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
66
320
            ) => Some(Lit::Bool(
67
320
                a.iter()
68
320
                    .cloned()
69
320
                    .collect::<HashSet<Lit>>()
70
320
                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
71
320
            )),
72
354
            _ => None,
73
        },
74
720
        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
75
            (
76
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
77
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
78
            ) => {
79
160
                let mut res: Vec<Lit> = Vec::new();
80
400
                for lit in a.iter() {
81
400
                    if b.contains(lit) && !res.contains(lit) {
82
320
                        res.push(lit.clone());
83
320
                    }
84
                }
85
160
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
86
            }
87
560
            _ => None,
88
        },
89
396
        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
90
            (
91
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
92
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
93
            ) => {
94
160
                let mut res: Vec<Lit> = Vec::new();
95
480
                for lit in a.iter() {
96
480
                    res.push(lit.clone());
97
480
                }
98
480
                for lit in b.iter() {
99
480
                    if !res.contains(lit) {
100
400
                        res.push(lit.clone());
101
400
                    }
102
                }
103
160
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
104
            }
105
236
            _ => None,
106
        },
107
5184
        Expr::In(_, a, b) => {
108
            if let (
109
80
                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
110
80
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
111
5184
            ) = (a.as_ref(), b.as_ref())
112
            {
113
240
                for lit in d.iter() {
114
240
                    if let Lit::Int(x) = lit
115
240
                        && c == x
116
                    {
117
80
                        return Some(Lit::Bool(true));
118
160
                    }
119
                }
120
                Some(Lit::Bool(false))
121
            } else {
122
5104
                None
123
            }
124
        }
125
        Expr::FromSolution(_, _) => None,
126
        Expr::DominanceRelation(_, _) => None,
127
22428
        Expr::InDomain(_, e, domain) => {
128
22428
            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
129
22284
                return None;
130
            };
131

            
132
144
            domain.contains(lit).ok().map(Into::into)
133
        }
134
10790704
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
135
15807550
        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
136
4754564
        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
137
413056
        Expr::Comprehension(_, comprehension) => {
138
413056
            eval_constant_comprehension(comprehension.as_ref())
139
        }
140
        Expr::AbstractComprehension(_, _) => None,
141
4354972
        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
142
6200250
            let subject: Lit = eval_constant(subject.as_ref())?;
143
89540
            let indices: Vec<Lit> = indices
144
89540
                .iter()
145
89540
                .map(eval_constant)
146
89540
                .collect::<Option<Vec<Lit>>>()?;
147

            
148
17100
            match subject {
149
16540
                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
150
16540
                    matrix::flatten_enumerate(subject)
151
43854
                        .find(|(i, _)| i == &indices)
152
16540
                        .map(|(_, x)| x)
153
                }
154
400
                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
155
400
                    let AbstractLiteral::Tuple(elems) = subject else {
156
                        return None;
157
                    };
158

            
159
400
                    assert!(indices.len() == 1, "nested tuples not supported yet");
160

            
161
400
                    let Lit::Int(index) = indices[0].clone() else {
162
                        return None;
163
                    };
164

            
165
400
                    if elems.len() < index as usize || index < 1 {
166
                        return None;
167
400
                    }
168

            
169
                    // -1 for 0-indexing vs 1-indexing
170
400
                    let item = elems[index as usize - 1].clone();
171

            
172
400
                    Some(item)
173
                }
174
160
                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
175
160
                    let AbstractLiteral::Record(elems) = subject else {
176
                        return None;
177
                    };
178

            
179
160
                    assert!(indices.len() == 1, "nested record not supported yet");
180

            
181
160
                    let Lit::Int(index) = indices[0].clone() else {
182
                        return None;
183
                    };
184

            
185
160
                    if elems.len() < index as usize || index < 1 {
186
                        return None;
187
160
                    }
188

            
189
                    // -1 for 0-indexing vs 1-indexing
190
160
                    let item = elems[index as usize - 1].clone();
191
160
                    Some(item.value)
192
                }
193
                _ => None,
194
            }
195
        }
196
54080
        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
197
102960
            let subject: Lit = eval_constant(subject.as_ref())?;
198
80
            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
199
                return None;
200
            };
201

            
202
80
            let hole_dim = indices
203
80
                .iter()
204
80
                .cloned()
205
160
                .position(|x| x.is_none())
206
80
                .expect("slice expression should have a hole dimension");
207

            
208
80
            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
209

            
210
80
            let indices: Vec<Option<Lit>> = indices
211
80
                .iter()
212
80
                .cloned()
213
160
                .map(|x| {
214
                    // the outer option represents success of this iterator, the inner the index
215
                    // slice.
216
160
                    match x {
217
80
                        Some(x) => eval_constant(&x).map(Some),
218
80
                        None => Some(None),
219
                    }
220
160
                })
221
80
                .collect::<Option<Vec<Option<Lit>>>>()?;
222

            
223
80
            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
224
80
                .values()
225
80
                .ok()?
226
240
                .map(|i| {
227
240
                    let mut indices = indices.clone();
228
240
                    indices[hole_dim] = Some(i);
229
                    // These unwraps will only fail if we have multiple holes.
230
                    // As this is invalid, panicking is fine.
231
480
                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
232
240
                })
233
80
                .collect_vec();
234

            
235
            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
236
80
            let elems = matrix::flatten_enumerate(subject)
237
720
                .filter(|(i, _)| indices_in_slice.contains(i))
238
80
                .map(|(_, elem)| elem)
239
80
                .collect();
240

            
241
80
            Some(Lit::AbstractLiteral(into_matrix![elems]))
242
        }
243
25920
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
244
1535058
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
245
1535058
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
246
1535058
            .map(Lit::Bool),
247
872696
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
248
275142
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
249
15808
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
250
1083166
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
251
360966
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
252
368178
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
253
1441514
        Expr::And(_, e) => {
254
1441514
            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
255
        }
256
1440
        Expr::Table(_, _, _) => None,
257
240
        Expr::NegativeTable(_, _, _) => None,
258
504704
        Expr::Root(_, _) => None,
259
949692
        Expr::Or(_, es) => {
260
            // possibly cheating; definitely should be in partial eval instead
261
1524712
            for e in (**es).clone().unwrap_list()? {
262
4518
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
263
2832
                    return Some(Lit::Bool(true));
264
1521880
                };
265
            }
266

            
267
734164
            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
268
        }
269
470424
        Expr::Imply(_, box1, box2) => {
270
470424
            let a: &Atom = (&**box1).try_into().ok()?;
271
305960
            let b: &Atom = (&**box2).try_into().ok()?;
272

            
273
134628
            let a: bool = a.try_into().ok()?;
274
123400
            let b: bool = b.try_into().ok()?;
275

            
276
123400
            if a {
277
                // true -> b ~> b
278
60500
                Some(Lit::Bool(b))
279
            } else {
280
                // false -> b ~> true
281
62900
                Some(Lit::Bool(true))
282
            }
283
        }
284
15452
        Expr::Iff(_, box1, box2) => {
285
15452
            let a: &Atom = (&**box1).try_into().ok()?;
286
4248
            let b: &Atom = (&**box2).try_into().ok()?;
287

            
288
328
            let a: bool = a.try_into().ok()?;
289
88
            let b: bool = b.try_into().ok()?;
290

            
291
8
            Some(Lit::Bool(a == b))
292
        }
293
2676430
        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
294
958564
        Expr::Product(_, exprs) => {
295
958564
            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
296
        }
297
341128
        Expr::FlatIneq(_, a, b, c) => {
298
341128
            let a: i32 = a.try_into().ok()?;
299
216134
            let b: i32 = b.try_into().ok()?;
300
162800
            let c: i32 = c.try_into().ok()?;
301

            
302
162800
            Some(Lit::Bool(a <= b + c))
303
        }
304
409114
        Expr::FlatSumGeq(_, exprs, a) => {
305
768884
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
306
768884
                let n: i32 = atom.try_into().ok()?;
307
359770
                let acc = acc + n;
308
359770
                Some(acc)
309
768884
            })?;
310

            
311
            Some(Lit::Bool(sum >= a.try_into().ok()?))
312
        }
313
457670
        Expr::FlatSumLeq(_, exprs, a) => {
314
829552
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
315
829552
                let n: i32 = atom.try_into().ok()?;
316
372042
                let acc = acc + n;
317
372042
                Some(acc)
318
829552
            })?;
319

            
320
160
            Some(Lit::Bool(sum >= a.try_into().ok()?))
321
        }
322
66120
        Expr::Min(_, e) => {
323
66120
            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
324
        }
325
68026
        Expr::Max(_, e) => {
326
68026
            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
327
        }
328
150120
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
329
251520
            if unwrap_expr::<i32>(b)? == 0 {
330
                return None;
331
21680
            }
332
21680
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
333
        }
334
40800
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
335
60080
            if unwrap_expr::<i32>(b)? == 0 {
336
                return None;
337
4480
            }
338
4480
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
339
4480
                .map(Lit::Int)
340
        }
341
        Expr::Substring(_, s, t) => match (s.as_ref(), t.as_ref()) {
342
            (
343
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(s)))),
344
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(t)))),
345
            ) => {
346
                if s.len() > t.len() {
347
                    return Some(Lit::Bool(false));
348
                }
349

            
350
                let found = t.windows(s.len()).any(|window| window == s.as_slice());
351
                Some(Lit::Bool(found))
352
            }
353
            _ => None,
354
        },
355
        Expr::Subsequence(_, s, t) => match (s.as_ref(), t.as_ref()) {
356
            (
357
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(s)))),
358
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Sequence(t)))),
359
            ) => {
360
                let mut i = 0;
361
                let mut j = 0;
362

            
363
                while i < s.len() && j < t.len() {
364
                    if s[i] == t[j] {
365
                        i += 1;
366
                    }
367
                    j += 1;
368
                }
369

            
370
                Some(Lit::Bool(i == s.len()))
371
            }
372
            _ => None,
373
        },
374
12040
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
375
            // div always rounds down
376
12040
            let a: i32 = a.try_into().ok()?;
377
240
            let b: i32 = b.try_into().ok()?;
378
            let c: i32 = c.try_into().ok()?;
379

            
380
            if b == 0 {
381
                return None;
382
            }
383

            
384
            let a = a as f32;
385
            let b = b as f32;
386
            let div: i32 = (a / b).floor() as i32;
387
            Some(Lit::Bool(div == c))
388
        }
389
53756
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
390
157680
        Expr::MinionReify(_, a, b) => {
391
157680
            let result = eval_constant(a)?;
392

            
393
15400
            let result: bool = result.try_into().ok()?;
394
15400
            let b: bool = b.try_into().ok()?;
395

            
396
            Some(Lit::Bool(b == result))
397
        }
398
73956
        Expr::MinionReifyImply(_, a, b) => {
399
73956
            let result = eval_constant(a)?;
400

            
401
            let result: bool = result.try_into().ok()?;
402
            let b: bool = b.try_into().ok()?;
403

            
404
            if b {
405
                Some(Lit::Bool(result))
406
            } else {
407
                Some(Lit::Bool(true))
408
            }
409
        }
410
2960
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
411
            // From Savile Row. Same semantics as division.
412
            //
413
            //   a - (b * floor(a/b))
414
            //
415
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
416
            // down instead, not towards zero.
417

            
418
2960
            let a: i32 = a.try_into().ok()?;
419
240
            let b: i32 = b.try_into().ok()?;
420
            let c: i32 = c.try_into().ok()?;
421

            
422
            if b == 0 {
423
                return None;
424
            }
425

            
426
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
427
            Some(Lit::Bool(modulo == c))
428
        }
429
4950
        Expr::MinionPow(_, a, b, c) => {
430
            // only available for positive a b c
431

            
432
4950
            let a: i32 = a.try_into().ok()?;
433
            let b: i32 = b.try_into().ok()?;
434
            let c: i32 = c.try_into().ok()?;
435

            
436
            if a <= 0 {
437
                return None;
438
            }
439

            
440
            if b <= 0 {
441
                return None;
442
            }
443

            
444
            if c <= 0 {
445
                return None;
446
            }
447

            
448
            Some(Lit::Bool(a ^ b == c))
449
        }
450
640
        Expr::MinionWInSet(_, _, _) => None,
451
2040
        Expr::MinionWInIntervalSet(_, x, intervals) => {
452
2040
            let x_lit: &Lit = x.try_into().ok()?;
453

            
454
            let x_lit = match x_lit.clone() {
455
                Lit::Int(i) => Some(i),
456
                Lit::Bool(true) => Some(1),
457
                Lit::Bool(false) => Some(0),
458
                _ => None,
459
            }?;
460

            
461
            let mut intervals = intervals.iter();
462
            while let Some(lower) = intervals.next() {
463
                let Some(upper) = intervals.next() else {
464
                    break;
465
                };
466
                if &x_lit >= lower && &x_lit <= upper {
467
                    return Some(Lit::Bool(true));
468
                }
469
            }
470

            
471
            Some(Lit::Bool(false))
472
        }
473
        Expr::Flatten(_, _, _) => {
474
            // TODO
475
22196
            None
476
        }
477
106766
        Expr::AllDiff(_, e) => {
478
106766
            let es = (**e).clone().unwrap_list()?;
479
12294
            let mut lits: HashSet<Lit> = HashSet::new();
480
13414
            for expr in es {
481
6014
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
482
11974
                    return None;
483
                };
484
1440
                match x {
485
                    Lit::Int(_) | Lit::Bool(_) => {
486
1440
                        if lits.contains(&x) {
487
                            return Some(Lit::Bool(false));
488
1440
                        } else {
489
1440
                            lits.insert(x.clone());
490
1440
                        }
491
                    }
492
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
493
                }
494
            }
495
320
            Some(Lit::Bool(true))
496
        }
497
61554
        Expr::FlatAllDiff(_, es) => {
498
61554
            let mut lits: HashSet<Lit> = HashSet::new();
499
61554
            for atom in es {
500
61554
                let Atom::Literal(x) = atom else {
501
61554
                    return None;
502
                };
503

            
504
                match x {
505
                    Lit::Int(_) | Lit::Bool(_) => {
506
                        if lits.contains(x) {
507
                            return Some(Lit::Bool(false));
508
                        } else {
509
                            lits.insert(x.clone());
510
                        }
511
                    }
512
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
513
                }
514
            }
515
            Some(Lit::Bool(true))
516
        }
517
50216
        Expr::FlatWatchedLiteral(_, _, _) => None,
518
198446
        Expr::AuxDeclaration(_, _, _) => None,
519
122792
        Expr::Neg(_, a) => match eval_constant(a.as_ref())? {
520
17878
            Lit::Int(a) => Some(Lit::Int(-a)),
521
            _ => None,
522
        },
523
        Expr::Factorial(_, _) => None,
524
684846
        Expr::Minus(_, a, b) => bin_op::<i32, i32>(|a, b| a - b, a, b).map(Lit::Int),
525
1680
        Expr::FlatMinusEq(_, a, b) => {
526
1680
            let a: i32 = a.try_into().ok()?;
527
            let b: i32 = b.try_into().ok()?;
528
            Some(Lit::Bool(a == -b))
529
        }
530
6560
        Expr::FlatProductEq(_, a, b, c) => {
531
6560
            let a: i32 = a.try_into().ok()?;
532
            let b: i32 = b.try_into().ok()?;
533
            let c: i32 = c.try_into().ok()?;
534
            Some(Lit::Bool(a * b == c))
535
        }
536
155800
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
537
155800
            let cs: Vec<i32> = cs
538
155800
                .iter()
539
321160
                .map(|x| TryInto::<i32>::try_into(x).ok())
540
155800
                .collect::<Option<Vec<i32>>>()?;
541
155800
            let vs: Vec<i32> = vs
542
155800
                .iter()
543
259560
                .map(|x| TryInto::<i32>::try_into(x).ok())
544
155800
                .collect::<Option<Vec<i32>>>()?;
545
2640
            let total: i32 = total.try_into().ok()?;
546

            
547
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
548

            
549
            Some(Lit::Bool(sum <= total))
550
        }
551
152280
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
552
152280
            let cs: Vec<i32> = cs
553
152280
                .iter()
554
310160
                .map(|x| TryInto::<i32>::try_into(x).ok())
555
152280
                .collect::<Option<Vec<i32>>>()?;
556
152280
            let vs: Vec<i32> = vs
557
152280
                .iter()
558
253560
                .map(|x| TryInto::<i32>::try_into(x).ok())
559
152280
                .collect::<Option<Vec<i32>>>()?;
560
2120
            let total: i32 = total.try_into().ok()?;
561

            
562
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
563

            
564
            Some(Lit::Bool(sum >= total))
565
        }
566
3760
        Expr::FlatAbsEq(_, x, y) => {
567
3760
            let x: i32 = x.try_into().ok()?;
568
160
            let y: i32 = y.try_into().ok()?;
569

            
570
            Some(Lit::Bool(x == y.abs()))
571
        }
572
287068
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
573
554700
            let a: &Atom = a.try_into().ok()?;
574
550300
            let a: i32 = a.try_into().ok()?;
575

            
576
481626
            let b: &Atom = b.try_into().ok()?;
577
481626
            let b: i32 = b.try_into().ok()?;
578

            
579
481626
            if (a != 0 || b != 0) && b >= 0 {
580
481626
                Some(Lit::Int(a.pow(b as u32)))
581
            } else {
582
                None
583
            }
584
        }
585
        Expr::Metavar(_, _) => None,
586
215804
        Expr::MinionElementOne(_, _, _, _) => None,
587
50436
        Expr::ToInt(_, expression) => {
588
50436
            let lit = eval_constant(expression.as_ref())?;
589
1612
            match lit {
590
                Lit::Int(_) => Some(lit),
591
1610
                Lit::Bool(true) => Some(Lit::Int(1)),
592
2
                Lit::Bool(false) => Some(Lit::Int(0)),
593
                _ => None,
594
            }
595
        }
596
        Expr::SATInt(_, _, _, _) => {
597
            // TODO: If this SATInt is composed of literals, we should evaluate it back to an
598
            // integer literal.
599
            //
600
            // This is important because `is_all_constant` currently returns true for SATInts
601
            // containing no references. If we don't evaluate them here, bubble rules will skip
602
            // them (thinking they'll be constant-folded later), but they'll actually reach
603
            // the solver adaptors as un-encoded unsafe operations, causing panics.
604
2254340
            None
605
        }
606
        Expr::PairwiseSum(_, a, b) => {
607
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
608
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
609
                _ => None,
610
            }
611
        }
612
        Expr::PairwiseProduct(_, a, b) => {
613
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
614
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
615
                _ => None,
616
            }
617
        }
618
        Expr::Defined(_, _) => todo!(),
619
        Expr::Range(_, _) => todo!(),
620
        Expr::Image(_, _, _) => todo!(),
621
        Expr::ImageSet(_, _, _) => todo!(),
622
        Expr::PreImage(_, _, _) => todo!(),
623
        Expr::Inverse(_, _, _) => todo!(),
624
        Expr::Restrict(_, _, _) => todo!(),
625
        Expr::Active(_, _, _) => todo!(),
626
        Expr::ToSet(_, _) => todo!(),
627
        Expr::ToMSet(_, _) => todo!(),
628
        Expr::ToRelation(_, _) => todo!(),
629
        Expr::RelationProj(_, _, _) => todo!(),
630
        Expr::Apart(_, _, _) => todo!(),
631
        Expr::Together(_, _, _) => todo!(),
632
        Expr::Participants(_, _) => todo!(),
633
        Expr::Party(_, _, _) => todo!(),
634
        Expr::Parts(_, _) => todo!(),
635
        Expr::Card(_, _) => todo!(),
636
2002
        Expr::LexLt(_, a, b) => {
637
2002
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
638
320
                pairs
639
320
                    .iter()
640
640
                    .find_map(|(a, b)| match a.cmp(b) {
641
160
                        CmpOrdering::Less => Some(true),     // First difference is <
642
                        CmpOrdering::Greater => Some(false), // First difference is >
643
480
                        CmpOrdering::Equal => None,          // No difference
644
640
                    })
645
320
                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
646
1682
            })?;
647
320
            Some(lt.into())
648
        }
649
82760
        Expr::LexLeq(_, a, b) => {
650
82760
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
651
160
                pairs
652
160
                    .iter()
653
320
                    .find_map(|(a, b)| match a.cmp(b) {
654
160
                        CmpOrdering::Less => Some(true),
655
                        CmpOrdering::Greater => Some(false),
656
160
                        CmpOrdering::Equal => None,
657
320
                    })
658
160
                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
659
82600
            })?;
660
160
            Some(lt.into())
661
        }
662
120
        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
663
240
        Expr::LexGeq(_, a, b) => {
664
240
            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
665
        }
666
320
        Expr::FlatLexLt(_, a, b) => {
667
320
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
668
                pairs
669
                    .iter()
670
                    .find_map(|(a, b)| match a.cmp(b) {
671
                        CmpOrdering::Less => Some(true),
672
                        CmpOrdering::Greater => Some(false),
673
                        CmpOrdering::Equal => None,
674
                    })
675
                    .unwrap_or(a_len < b_len)
676
320
            })?;
677
            Some(lt.into())
678
        }
679
480
        Expr::FlatLexLeq(_, a, b) => {
680
480
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
681
                pairs
682
                    .iter()
683
                    .find_map(|(a, b)| match a.cmp(b) {
684
                        CmpOrdering::Less => Some(true),
685
                        CmpOrdering::Greater => Some(false),
686
                        CmpOrdering::Equal => None,
687
                    })
688
                    .unwrap_or(a_len <= b_len)
689
480
            })?;
690
            Some(lt.into())
691
        }
692
    }
693
56346138
}
694

            
695
394098
pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
696
394098
where
697
394098
    T: TryFrom<Lit>,
698
{
699
394098
    let a = unwrap_expr::<T>(a)?;
700
124028
    Some(f(a))
701
394098
}
702

            
703
6274372
pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
704
6274372
where
705
6274372
    T: TryFrom<Lit>,
706
{
707
6274372
    let a = unwrap_expr::<T>(a)?;
708
733824
    let b = unwrap_expr::<T>(b)?;
709
608196
    Some(f(a, b))
710
6274372
}
711

            
712
#[allow(dead_code)]
713
pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
714
where
715
    T: TryFrom<Lit>,
716
{
717
    let a = unwrap_expr::<T>(a)?;
718
    let b = unwrap_expr::<T>(b)?;
719
    let c = unwrap_expr::<T>(c)?;
720
    Some(f(a, b, c))
721
}
722

            
723
53046
pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
724
53046
where
725
53046
    T: TryFrom<Lit>,
726
{
727
53046
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
728
24
    Some(f(a))
729
53046
}
730

            
731
5810672
pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
732
5810672
where
733
5810672
    T: TryFrom<Lit>,
734
{
735
5810672
    Some(f(eval_list_items(a)?))
736
5810672
}
737

            
738
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
739

            
740
/// Calls the given function on each consecutive pair of elements in the list expressions.
741
/// Also passes the length of the two lists.
742
84762
fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
743
84762
where
744
84762
    T: TryFrom<Lit>,
745
{
746
84762
    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
747
2080
    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
748
1280
    let lens = (a_exprs.len(), b_exprs.len());
749

            
750
1280
    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
751
1760
        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
752
1280
        .collect::<Option<Vec<(T, T)>>>()?;
753
480
    Some(f(lit_pairs, lens))
754
84762
}
755

            
756
/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
757
800
fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
758
800
where
759
800
    T: TryFrom<Atom>,
760
{
761
800
    let lit_pairs = Iterator::zip(a.iter(), b.iter())
762
800
        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
763
800
        .collect::<Option<Vec<(T, T)>>>()?;
764
    Some(f(lit_pairs, (a.len(), b.len())))
765
800
}
766

            
767
pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
768
where
769
    T: TryFrom<Lit>,
770
{
771
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
772
    f(a)
773
}
774

            
775
134146
pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
776
134146
where
777
134146
    T: TryFrom<Lit>,
778
{
779
134146
    f(eval_list_items(a)?)
780
134146
}
781

            
782
#[allow(dead_code)]
783
pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
784
where
785
    T: TryFrom<Lit>,
786
{
787
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
788
    let b = unwrap_expr::<T>(b)?;
789
    Some(f(a, b))
790
}
791

            
792
15520892
pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
793
15520892
    let c = eval_constant(expr)?;
794
4398812
    TryInto::<T>::try_into(c).ok()
795
15520892
}
796

            
797
5944818
fn eval_list_items<T>(expr: &Expr) -> Option<Vec<T>>
798
5944818
where
799
5944818
    T: TryFrom<Lit>,
800
{
801
5944818
    if let Some(items) = expr
802
5944818
        .clone()
803
5944818
        .unwrap_matrix_unchecked()
804
5944818
        .map(|(items, _)| items)
805
    {
806
5688760
        return items.iter().map(unwrap_expr).collect();
807
256058
    }
808

            
809
256058
    let Lit::AbstractLiteral(list) = eval_constant(expr)? else {
810
        return None;
811
    };
812

            
813
4828
    let items = list.unwrap_list()?;
814
4828
    items
815
4828
        .iter()
816
4828
        .cloned()
817
4828
        .map(TryInto::try_into)
818
4828
        .collect::<Result<Vec<_>, _>>()
819
4828
        .ok()
820
5944818
}
821

            
822
413056
fn eval_constant_comprehension(comprehension: &Comprehension) -> Option<Lit> {
823
413056
    let mut values = Vec::new();
824
413056
    eval_comprehension_qualifiers(comprehension, 0, &mut values)?;
825
5908
    Some(Lit::AbstractLiteral(
826
5908
        AbstractLiteral::matrix_implied_indices(values),
827
5908
    ))
828
413056
}
829

            
830
958786
fn eval_comprehension_qualifiers(
831
958786
    comprehension: &Comprehension,
832
958786
    qualifier_index: usize,
833
958786
    values: &mut Vec<Lit>,
834
958786
) -> Option<()> {
835
958786
    if qualifier_index == comprehension.qualifiers.len() {
836
404546
        values.push(eval_constant(&comprehension.return_expression)?);
837
12400
        return Some(());
838
554240
    }
839

            
840
554240
    match &comprehension.qualifiers[qualifier_index] {
841
482952
        ComprehensionQualifier::Generator { ptr } => {
842
482952
            let domain = ptr.domain()?;
843
482952
            let generator_values = domain.resolve()?.values().ok()?.collect_vec();
844

            
845
493418
            for value in generator_values {
846
493418
                with_temporary_quantified_binding(ptr, &value, || {
847
493418
                    eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)
848
493418
                })?;
849
            }
850
        }
851
610
        ComprehensionQualifier::ExpressionGenerator { ptr } => {
852
            // clone immediately so the read lock guard is dropped
853
610
            let expr = ptr.as_quantified_expr()?.clone();
854
610
            let generator_values = generator_values_from_expr(&expr)?;
855

            
856
8
            for value in generator_values {
857
8
                with_temporary_quantified_binding(ptr, &value, || {
858
8
                    eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)
859
8
                })?;
860
            }
861
        }
862
70678
        ComprehensionQualifier::Condition(condition) => match eval_constant(condition)? {
863
            Lit::Bool(true) => {
864
52304
                eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)?
865
            }
866
5420
            Lit::Bool(false) => {}
867
            _ => return None,
868
        },
869
    }
870

            
871
14040
    Some(())
872
958786
}
873

            
874
610
fn generator_values_from_expr(expr: &Expr) -> Option<Vec<Lit>> {
875
610
    match eval_constant(expr)? {
876
8
        Lit::AbstractLiteral(AbstractLiteral::Set(values))
877
        | Lit::AbstractLiteral(AbstractLiteral::MSet(values))
878
8
        | Lit::AbstractLiteral(AbstractLiteral::Tuple(values)) => Some(values),
879
        Lit::AbstractLiteral(list) => list.unwrap_list().cloned(),
880
        _ => None,
881
    }
882
610
}
883

            
884
493426
fn with_temporary_quantified_binding<T>(
885
493426
    quantified: &crate::ast::DeclarationPtr,
886
493426
    value: &Lit,
887
493426
    f: impl FnOnce() -> Option<T>,
888
493426
) -> Option<T> {
889
493426
    let mut targets = vec![quantified.clone()];
890
493426
    if let DeclarationKind::Quantified(inner) = &*quantified.kind()
891
493418
        && let Some(generator) = inner.generator()
892
    {
893
        targets.push(generator.clone());
894
493426
    }
895

            
896
493426
    let mut originals = Vec::with_capacity(targets.len());
897
493426
    for mut target in targets {
898
493426
        let old_kind = target.replace_kind(DeclarationKind::TemporaryValueLetting(Expr::Atomic(
899
493426
            Metadata::new(),
900
493426
            Atom::Literal(value.clone()),
901
493426
        )));
902
493426
        originals.push((target, old_kind));
903
493426
    }
904

            
905
493426
    let result = f();
906

            
907
493426
    for (mut target, old_kind) in originals.into_iter().rev() {
908
493426
        let _ = target.replace_kind(old_kind);
909
493426
    }
910

            
911
493426
    result
912
493426
}