1
#![allow(dead_code)]
2
use crate::ast::{
3
    AbstractLiteral, Atom, DeclarationKind, Expression as Expr, Literal as Lit, Metadata,
4
    comprehension::{Comprehension, ComprehensionQualifier},
5
    matrix,
6
};
7
use crate::into_matrix;
8
use itertools::{Itertools as _, izip};
9
use std::cmp::Ordering as CmpOrdering;
10
use std::collections::HashSet;
11

            
12
/// Simplify an expression to a constant if possible
13
/// Returns:
14
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
15
/// `Some(Const)` if the expression can be simplified to a constant
16
56579166
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
17
26743656
    match expr {
18
640
        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
19
            (
20
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
21
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
22
            ) => {
23
320
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
24
320
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
25

            
26
320
                if a_set.difference(&b_set).count() > 0 {
27
240
                    Some(Lit::Bool(a_set.is_superset(&b_set)))
28
                } else {
29
80
                    Some(Lit::Bool(false))
30
                }
31
            }
32
320
            _ => None,
33
        },
34
640
        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
35
            (
36
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
37
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
38
320
            ) => Some(Lit::Bool(
39
320
                a.iter()
40
320
                    .cloned()
41
320
                    .collect::<HashSet<Lit>>()
42
320
                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
43
320
            )),
44
320
            _ => None,
45
        },
46
1200
        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
47
            (
48
400
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
49
400
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
50
            ) => {
51
400
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
52
400
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
53

            
54
400
                if b_set.difference(&a_set).count() > 0 {
55
320
                    Some(Lit::Bool(a_set.is_subset(&b_set)))
56
                } else {
57
80
                    Some(Lit::Bool(false))
58
                }
59
            }
60
800
            _ => None,
61
        },
62
674
        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
63
            (
64
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
65
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
66
320
            ) => Some(Lit::Bool(
67
320
                a.iter()
68
320
                    .cloned()
69
320
                    .collect::<HashSet<Lit>>()
70
320
                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
71
320
            )),
72
354
            _ => None,
73
        },
74
720
        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
75
            (
76
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
77
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
78
            ) => {
79
160
                let mut res: Vec<Lit> = Vec::new();
80
400
                for lit in a.iter() {
81
400
                    if b.contains(lit) && !res.contains(lit) {
82
320
                        res.push(lit.clone());
83
320
                    }
84
                }
85
160
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
86
            }
87
560
            _ => None,
88
        },
89
396
        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
90
            (
91
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
92
160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
93
            ) => {
94
160
                let mut res: Vec<Lit> = Vec::new();
95
480
                for lit in a.iter() {
96
480
                    res.push(lit.clone());
97
480
                }
98
480
                for lit in b.iter() {
99
480
                    if !res.contains(lit) {
100
400
                        res.push(lit.clone());
101
400
                    }
102
                }
103
160
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
104
            }
105
236
            _ => None,
106
        },
107
5180
        Expr::In(_, a, b) => {
108
            if let (
109
80
                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
110
80
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
111
5180
            ) = (a.as_ref(), b.as_ref())
112
            {
113
240
                for lit in d.iter() {
114
240
                    if let Lit::Int(x) = lit
115
240
                        && c == x
116
                    {
117
80
                        return Some(Lit::Bool(true));
118
160
                    }
119
                }
120
                Some(Lit::Bool(false))
121
            } else {
122
5100
                None
123
            }
124
        }
125
        Expr::FromSolution(_, _) => None,
126
        Expr::DominanceRelation(_, _) => None,
127
22680
        Expr::InDomain(_, e, domain) => {
128
22680
            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
129
22536
                return None;
130
            };
131

            
132
144
            domain.contains(lit).ok().map(Into::into)
133
        }
134
10983512
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
135
15760144
        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
136
4738824
        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
137
414128
        Expr::Comprehension(_, comprehension) => {
138
414128
            eval_constant_comprehension(comprehension.as_ref())
139
        }
140
        Expr::AbstractComprehension(_, _) => None,
141
4354976
        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
142
6199740
            let subject: Lit = eval_constant(subject.as_ref())?;
143
90080
            let indices: Vec<Lit> = indices
144
90080
                .iter()
145
90080
                .map(eval_constant)
146
90080
                .collect::<Option<Vec<Lit>>>()?;
147

            
148
17136
            match subject {
149
16576
                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
150
16576
                    matrix::flatten_enumerate(subject)
151
43974
                        .find(|(i, _)| i == &indices)
152
16576
                        .map(|(_, x)| x)
153
                }
154
400
                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
155
400
                    let AbstractLiteral::Tuple(elems) = subject else {
156
                        return None;
157
                    };
158

            
159
400
                    assert!(indices.len() == 1, "nested tuples not supported yet");
160

            
161
400
                    let Lit::Int(index) = indices[0].clone() else {
162
                        return None;
163
                    };
164

            
165
400
                    if elems.len() < index as usize || index < 1 {
166
                        return None;
167
400
                    }
168

            
169
                    // -1 for 0-indexing vs 1-indexing
170
400
                    let item = elems[index as usize - 1].clone();
171

            
172
400
                    Some(item)
173
                }
174
160
                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
175
160
                    let AbstractLiteral::Record(elems) = subject else {
176
                        return None;
177
                    };
178

            
179
160
                    assert!(indices.len() == 1, "nested record not supported yet");
180

            
181
160
                    let Lit::Int(index) = indices[0].clone() else {
182
                        return None;
183
                    };
184

            
185
160
                    if elems.len() < index as usize || index < 1 {
186
                        return None;
187
160
                    }
188

            
189
                    // -1 for 0-indexing vs 1-indexing
190
160
                    let item = elems[index as usize - 1].clone();
191
160
                    Some(item.value)
192
                }
193
                _ => None,
194
            }
195
        }
196
54080
        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
197
102960
            let subject: Lit = eval_constant(subject.as_ref())?;
198
80
            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
199
                return None;
200
            };
201

            
202
80
            let hole_dim = indices
203
80
                .iter()
204
80
                .cloned()
205
160
                .position(|x| x.is_none())
206
80
                .expect("slice expression should have a hole dimension");
207

            
208
80
            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
209

            
210
80
            let indices: Vec<Option<Lit>> = indices
211
80
                .iter()
212
80
                .cloned()
213
160
                .map(|x| {
214
                    // the outer option represents success of this iterator, the inner the index
215
                    // slice.
216
160
                    match x {
217
80
                        Some(x) => eval_constant(&x).map(Some),
218
80
                        None => Some(None),
219
                    }
220
160
                })
221
80
                .collect::<Option<Vec<Option<Lit>>>>()?;
222

            
223
80
            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
224
80
                .values()
225
80
                .ok()?
226
240
                .map(|i| {
227
240
                    let mut indices = indices.clone();
228
240
                    indices[hole_dim] = Some(i);
229
                    // These unwraps will only fail if we have multiple holes.
230
                    // As this is invalid, panicking is fine.
231
480
                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
232
240
                })
233
80
                .collect_vec();
234

            
235
            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
236
80
            let elems = matrix::flatten_enumerate(subject)
237
720
                .filter(|(i, _)| indices_in_slice.contains(i))
238
80
                .map(|(_, elem)| elem)
239
80
                .collect();
240

            
241
80
            Some(Lit::AbstractLiteral(into_matrix![elems]))
242
        }
243
23560
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
244
1530298
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
245
1530298
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
246
1530298
            .map(Lit::Bool),
247
872698
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
248
310422
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
249
15808
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
250
1110448
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
251
350886
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
252
382336
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
253
1488100
        Expr::And(_, e) => {
254
1488100
            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
255
        }
256
1440
        Expr::Table(_, _, _) => None,
257
240
        Expr::NegativeTable(_, _, _) => None,
258
499682
        Expr::Root(_, _) => None,
259
974874
        Expr::Or(_, es) => {
260
            // possibly cheating; definitely should be in partial eval instead
261
1577636
            for e in (**es).clone().unwrap_list()? {
262
4518
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
263
2832
                    return Some(Lit::Bool(true));
264
1574804
                };
265
            }
266

            
267
759346
            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
268
        }
269
484836
        Expr::Imply(_, box1, box2) => {
270
484836
            let a: &Atom = (&**box1).try_into().ok()?;
271
320120
            let b: &Atom = (&**box2).try_into().ok()?;
272

            
273
148788
            let a: bool = a.try_into().ok()?;
274
137560
            let b: bool = b.try_into().ok()?;
275

            
276
137560
            if a {
277
                // true -> b ~> b
278
68300
                Some(Lit::Bool(b))
279
            } else {
280
                // false -> b ~> true
281
69260
                Some(Lit::Bool(true))
282
            }
283
        }
284
15452
        Expr::Iff(_, box1, box2) => {
285
15452
            let a: &Atom = (&**box1).try_into().ok()?;
286
4248
            let b: &Atom = (&**box2).try_into().ok()?;
287

            
288
328
            let a: bool = a.try_into().ok()?;
289
88
            let b: bool = b.try_into().ok()?;
290

            
291
8
            Some(Lit::Bool(a == b))
292
        }
293
2699954
        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
294
958564
        Expr::Product(_, exprs) => {
295
958564
            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
296
        }
297
341128
        Expr::FlatIneq(_, a, b, c) => {
298
341128
            let a: i32 = a.try_into().ok()?;
299
216134
            let b: i32 = b.try_into().ok()?;
300
162800
            let c: i32 = c.try_into().ok()?;
301

            
302
162800
            Some(Lit::Bool(a <= b + c))
303
        }
304
409114
        Expr::FlatSumGeq(_, exprs, a) => {
305
768884
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
306
768884
                let n: i32 = atom.try_into().ok()?;
307
359770
                let acc = acc + n;
308
359770
                Some(acc)
309
768884
            })?;
310

            
311
            Some(Lit::Bool(sum >= a.try_into().ok()?))
312
        }
313
457670
        Expr::FlatSumLeq(_, exprs, a) => {
314
829552
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
315
829552
                let n: i32 = atom.try_into().ok()?;
316
372042
                let acc = acc + n;
317
372042
                Some(acc)
318
829552
            })?;
319

            
320
160
            Some(Lit::Bool(sum >= a.try_into().ok()?))
321
        }
322
66120
        Expr::Min(_, e) => {
323
66120
            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
324
        }
325
68022
        Expr::Max(_, e) => {
326
68022
            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
327
        }
328
150120
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
329
251520
            if unwrap_expr::<i32>(b)? == 0 {
330
                return None;
331
21680
            }
332
21680
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
333
        }
334
40800
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
335
60080
            if unwrap_expr::<i32>(b)? == 0 {
336
                return None;
337
4480
            }
338
4480
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
339
4480
                .map(Lit::Int)
340
        }
341
12040
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
342
            // div always rounds down
343
12040
            let a: i32 = a.try_into().ok()?;
344
240
            let b: i32 = b.try_into().ok()?;
345
            let c: i32 = c.try_into().ok()?;
346

            
347
            if b == 0 {
348
                return None;
349
            }
350

            
351
            let a = a as f32;
352
            let b = b as f32;
353
            let div: i32 = (a / b).floor() as i32;
354
            Some(Lit::Bool(div == c))
355
        }
356
53756
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
357
157680
        Expr::MinionReify(_, a, b) => {
358
157680
            let result = eval_constant(a)?;
359

            
360
15400
            let result: bool = result.try_into().ok()?;
361
15400
            let b: bool = b.try_into().ok()?;
362

            
363
            Some(Lit::Bool(b == result))
364
        }
365
73956
        Expr::MinionReifyImply(_, a, b) => {
366
73956
            let result = eval_constant(a)?;
367

            
368
            let result: bool = result.try_into().ok()?;
369
            let b: bool = b.try_into().ok()?;
370

            
371
            if b {
372
                Some(Lit::Bool(result))
373
            } else {
374
                Some(Lit::Bool(true))
375
            }
376
        }
377
2960
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
378
            // From Savile Row. Same semantics as division.
379
            //
380
            //   a - (b * floor(a/b))
381
            //
382
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
383
            // down instead, not towards zero.
384

            
385
2960
            let a: i32 = a.try_into().ok()?;
386
240
            let b: i32 = b.try_into().ok()?;
387
            let c: i32 = c.try_into().ok()?;
388

            
389
            if b == 0 {
390
                return None;
391
            }
392

            
393
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
394
            Some(Lit::Bool(modulo == c))
395
        }
396
4950
        Expr::MinionPow(_, a, b, c) => {
397
            // only available for positive a b c
398

            
399
4950
            let a: i32 = a.try_into().ok()?;
400
            let b: i32 = b.try_into().ok()?;
401
            let c: i32 = c.try_into().ok()?;
402

            
403
            if a <= 0 {
404
                return None;
405
            }
406

            
407
            if b <= 0 {
408
                return None;
409
            }
410

            
411
            if c <= 0 {
412
                return None;
413
            }
414

            
415
            Some(Lit::Bool(a ^ b == c))
416
        }
417
640
        Expr::MinionWInSet(_, _, _) => None,
418
2040
        Expr::MinionWInIntervalSet(_, x, intervals) => {
419
2040
            let x_lit: &Lit = x.try_into().ok()?;
420

            
421
            let x_lit = match x_lit.clone() {
422
                Lit::Int(i) => Some(i),
423
                Lit::Bool(true) => Some(1),
424
                Lit::Bool(false) => Some(0),
425
                _ => None,
426
            }?;
427

            
428
            let mut intervals = intervals.iter();
429
            loop {
430
                let Some(lower) = intervals.next() else {
431
                    break;
432
                };
433

            
434
                let Some(upper) = intervals.next() else {
435
                    break;
436
                };
437
                if &x_lit >= lower && &x_lit <= upper {
438
                    return Some(Lit::Bool(true));
439
                }
440
            }
441

            
442
            Some(Lit::Bool(false))
443
        }
444
        Expr::Flatten(_, _, _) => {
445
            // TODO
446
22196
            None
447
        }
448
106766
        Expr::AllDiff(_, e) => {
449
106766
            let es = (**e).clone().unwrap_list()?;
450
12294
            let mut lits: HashSet<Lit> = HashSet::new();
451
13414
            for expr in es {
452
6014
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
453
11974
                    return None;
454
                };
455
1440
                match x {
456
                    Lit::Int(_) | Lit::Bool(_) => {
457
1440
                        if lits.contains(&x) {
458
                            return Some(Lit::Bool(false));
459
1440
                        } else {
460
1440
                            lits.insert(x.clone());
461
1440
                        }
462
                    }
463
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
464
                }
465
            }
466
320
            Some(Lit::Bool(true))
467
        }
468
61554
        Expr::FlatAllDiff(_, es) => {
469
61554
            let mut lits: HashSet<Lit> = HashSet::new();
470
61554
            for atom in es {
471
61554
                let Atom::Literal(x) = atom else {
472
61554
                    return None;
473
                };
474

            
475
                match x {
476
                    Lit::Int(_) | Lit::Bool(_) => {
477
                        if lits.contains(x) {
478
                            return Some(Lit::Bool(false));
479
                        } else {
480
                            lits.insert(x.clone());
481
                        }
482
                    }
483
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
484
                }
485
            }
486
            Some(Lit::Bool(true))
487
        }
488
50216
        Expr::FlatWatchedLiteral(_, _, _) => None,
489
198446
        Expr::AuxDeclaration(_, _, _) => None,
490
122490
        Expr::Neg(_, a) => match eval_constant(a.as_ref())? {
491
17840
            Lit::Int(a) => Some(Lit::Int(-a)),
492
            _ => None,
493
        },
494
        Expr::Factorial(_, _) => None,
495
684746
        Expr::Minus(_, a, b) => bin_op::<i32, i32>(|a, b| a - b, a, b).map(Lit::Int),
496
1680
        Expr::FlatMinusEq(_, a, b) => {
497
1680
            let a: i32 = a.try_into().ok()?;
498
            let b: i32 = b.try_into().ok()?;
499
            Some(Lit::Bool(a == -b))
500
        }
501
6560
        Expr::FlatProductEq(_, a, b, c) => {
502
6560
            let a: i32 = a.try_into().ok()?;
503
            let b: i32 = b.try_into().ok()?;
504
            let c: i32 = c.try_into().ok()?;
505
            Some(Lit::Bool(a * b == c))
506
        }
507
155800
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
508
155800
            let cs: Vec<i32> = cs
509
155800
                .iter()
510
321160
                .map(|x| TryInto::<i32>::try_into(x).ok())
511
155800
                .collect::<Option<Vec<i32>>>()?;
512
155800
            let vs: Vec<i32> = vs
513
155800
                .iter()
514
259560
                .map(|x| TryInto::<i32>::try_into(x).ok())
515
155800
                .collect::<Option<Vec<i32>>>()?;
516
2640
            let total: i32 = total.try_into().ok()?;
517

            
518
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
519

            
520
            Some(Lit::Bool(sum <= total))
521
        }
522
152280
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
523
152280
            let cs: Vec<i32> = cs
524
152280
                .iter()
525
310160
                .map(|x| TryInto::<i32>::try_into(x).ok())
526
152280
                .collect::<Option<Vec<i32>>>()?;
527
152280
            let vs: Vec<i32> = vs
528
152280
                .iter()
529
253560
                .map(|x| TryInto::<i32>::try_into(x).ok())
530
152280
                .collect::<Option<Vec<i32>>>()?;
531
2120
            let total: i32 = total.try_into().ok()?;
532

            
533
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
534

            
535
            Some(Lit::Bool(sum >= total))
536
        }
537
3440
        Expr::FlatAbsEq(_, x, y) => {
538
3440
            let x: i32 = x.try_into().ok()?;
539
160
            let y: i32 = y.try_into().ok()?;
540

            
541
            Some(Lit::Bool(x == y.abs()))
542
        }
543
287068
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
544
554700
            let a: &Atom = a.try_into().ok()?;
545
550300
            let a: i32 = a.try_into().ok()?;
546

            
547
481626
            let b: &Atom = b.try_into().ok()?;
548
481626
            let b: i32 = b.try_into().ok()?;
549

            
550
481626
            if (a != 0 || b != 0) && b >= 0 {
551
481626
                Some(Lit::Int(a.pow(b as u32)))
552
            } else {
553
                None
554
            }
555
        }
556
        Expr::Metavar(_, _) => None,
557
215804
        Expr::MinionElementOne(_, _, _, _) => None,
558
50034
        Expr::ToInt(_, expression) => {
559
50034
            let lit = eval_constant(expression.as_ref())?;
560
1632
            match lit {
561
                Lit::Int(_) => Some(lit),
562
1610
                Lit::Bool(true) => Some(Lit::Int(1)),
563
22
                Lit::Bool(false) => Some(Lit::Int(0)),
564
                _ => None,
565
            }
566
        }
567
        Expr::SATInt(_, _, _, _) => {
568
            // TODO: If this SATInt is composed of literals, we should evaluate it back to an
569
            // integer literal.
570
            //
571
            // This is important because `is_all_constant` currently returns true for SATInts
572
            // containing no references. If we don't evaluate them here, bubble rules will skip
573
            // them (thinking they'll be constant-folded later), but they'll actually reach
574
            // the solver adaptors as un-encoded unsafe operations, causing panics.
575
2193820
            None
576
        }
577
        Expr::PairwiseSum(_, a, b) => {
578
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
579
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
580
                _ => None,
581
            }
582
        }
583
        Expr::PairwiseProduct(_, a, b) => {
584
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
585
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
586
                _ => None,
587
            }
588
        }
589
        Expr::Defined(_, _) => todo!(),
590
        Expr::Range(_, _) => todo!(),
591
        Expr::Image(_, _, _) => todo!(),
592
        Expr::ImageSet(_, _, _) => todo!(),
593
        Expr::PreImage(_, _, _) => todo!(),
594
        Expr::Inverse(_, _, _) => todo!(),
595
        Expr::Restrict(_, _, _) => todo!(),
596
2002
        Expr::LexLt(_, a, b) => {
597
2002
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
598
320
                pairs
599
320
                    .iter()
600
640
                    .find_map(|(a, b)| match a.cmp(b) {
601
160
                        CmpOrdering::Less => Some(true),     // First difference is <
602
                        CmpOrdering::Greater => Some(false), // First difference is >
603
480
                        CmpOrdering::Equal => None,          // No difference
604
640
                    })
605
320
                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
606
1682
            })?;
607
320
            Some(lt.into())
608
        }
609
82760
        Expr::LexLeq(_, a, b) => {
610
82760
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
611
160
                pairs
612
160
                    .iter()
613
320
                    .find_map(|(a, b)| match a.cmp(b) {
614
160
                        CmpOrdering::Less => Some(true),
615
                        CmpOrdering::Greater => Some(false),
616
160
                        CmpOrdering::Equal => None,
617
320
                    })
618
160
                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
619
82600
            })?;
620
160
            Some(lt.into())
621
        }
622
120
        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
623
240
        Expr::LexGeq(_, a, b) => {
624
240
            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
625
        }
626
320
        Expr::FlatLexLt(_, a, b) => {
627
320
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
628
                pairs
629
                    .iter()
630
                    .find_map(|(a, b)| match a.cmp(b) {
631
                        CmpOrdering::Less => Some(true),
632
                        CmpOrdering::Greater => Some(false),
633
                        CmpOrdering::Equal => None,
634
                    })
635
                    .unwrap_or(a_len < b_len)
636
320
            })?;
637
            Some(lt.into())
638
        }
639
480
        Expr::FlatLexLeq(_, a, b) => {
640
480
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
641
                pairs
642
                    .iter()
643
                    .find_map(|(a, b)| match a.cmp(b) {
644
                        CmpOrdering::Less => Some(true),
645
                        CmpOrdering::Greater => Some(false),
646
                        CmpOrdering::Equal => None,
647
                    })
648
                    .unwrap_or(a_len <= b_len)
649
480
            })?;
650
            Some(lt.into())
651
        }
652
    }
653
56579166
}
654

            
655
405896
pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
656
405896
where
657
405896
    T: TryFrom<Lit>,
658
{
659
405896
    let a = unwrap_expr::<T>(a)?;
660
138188
    Some(f(a))
661
405896
}
662

            
663
6317236
pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
664
6317236
where
665
6317236
    T: TryFrom<Lit>,
666
{
667
6317236
    let a = unwrap_expr::<T>(a)?;
668
804308
    let b = unwrap_expr::<T>(b)?;
669
678680
    Some(f(a, b))
670
6317236
}
671

            
672
#[allow(dead_code)]
673
pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
674
where
675
    T: TryFrom<Lit>,
676
{
677
    let a = unwrap_expr::<T>(a)?;
678
    let b = unwrap_expr::<T>(b)?;
679
    let c = unwrap_expr::<T>(c)?;
680
    Some(f(a, b, c))
681
}
682

            
683
52444
pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
684
52444
where
685
52444
    T: TryFrom<Lit>,
686
{
687
52444
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
688
24
    Some(f(a))
689
52444
}
690

            
691
5905964
pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
692
5905964
where
693
5905964
    T: TryFrom<Lit>,
694
{
695
5905964
    Some(f(eval_list_items(a)?))
696
5905964
}
697

            
698
/// Callback used by `vec_expr_pairs_op` and `atoms_pairs_op`: receives the element-wise pairs
/// of two lists (as produced by zipping them) and the lengths `(a_len, b_len)` of the two
/// original lists.
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
699

            
700
/// Calls the given function on each consecutive pair of elements in the list expressions.
701
/// Also passes the length of the two lists.
702
84762
fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
703
84762
where
704
84762
    T: TryFrom<Lit>,
705
{
706
84762
    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
707
2080
    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
708
1280
    let lens = (a_exprs.len(), b_exprs.len());
709

            
710
1280
    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
711
1760
        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
712
1280
        .collect::<Option<Vec<(T, T)>>>()?;
713
480
    Some(f(lit_pairs, lens))
714
84762
}
715

            
716
/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
717
800
fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
718
800
where
719
800
    T: TryFrom<Atom>,
720
{
721
800
    let lit_pairs = Iterator::zip(a.iter(), b.iter())
722
800
        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
723
800
        .collect::<Option<Vec<(T, T)>>>()?;
724
    Some(f(lit_pairs, (a.len(), b.len())))
725
800
}
726

            
727
pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
728
where
729
    T: TryFrom<Lit>,
730
{
731
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
732
    f(a)
733
}
734

            
735
134142
pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
736
134142
where
737
134142
    T: TryFrom<Lit>,
738
{
739
134142
    f(eval_list_items(a)?)
740
134142
}
741

            
742
#[allow(dead_code)]
743
pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
744
where
745
    T: TryFrom<Lit>,
746
{
747
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
748
    let b = unwrap_expr::<T>(b)?;
749
    Some(f(a, b))
750
}
751

            
752
15851472
pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
753
15851472
    let c = eval_constant(expr)?;
754
4771918
    TryInto::<T>::try_into(c).ok()
755
15851472
}
756

            
757
6040106
fn eval_list_items<T>(expr: &Expr) -> Option<Vec<T>>
758
6040106
where
759
6040106
    T: TryFrom<Lit>,
760
{
761
6040106
    if let Some(items) = expr
762
6040106
        .clone()
763
6040106
        .unwrap_matrix_unchecked()
764
6040106
        .map(|(items, _)| items)
765
    {
766
5783396
        return items.iter().map(unwrap_expr).collect();
767
256710
    }
768

            
769
256710
    let Lit::AbstractLiteral(list) = eval_constant(expr)? else {
770
        return None;
771
    };
772

            
773
4830
    let items = list.unwrap_list()?;
774
4830
    items
775
4830
        .iter()
776
4830
        .cloned()
777
4830
        .map(TryInto::try_into)
778
4830
        .collect::<Result<Vec<_>, _>>()
779
4830
        .ok()
780
6040106
}
781

            
782
414128
fn eval_constant_comprehension(comprehension: &Comprehension) -> Option<Lit> {
783
414128
    let mut values = Vec::new();
784
414128
    eval_comprehension_qualifiers(comprehension, 0, &mut values)?;
785
5910
    Some(Lit::AbstractLiteral(
786
5910
        AbstractLiteral::matrix_implied_indices(values),
787
5910
    ))
788
414128
}
789

            
790
962082
fn eval_comprehension_qualifiers(
791
962082
    comprehension: &Comprehension,
792
962082
    qualifier_index: usize,
793
962082
    values: &mut Vec<Lit>,
794
962082
) -> Option<()> {
795
962082
    if qualifier_index == comprehension.qualifiers.len() {
796
405656
        values.push(eval_constant(&comprehension.return_expression)?);
797
12402
        return Some(());
798
556426
    }
799

            
800
556426
    match &comprehension.qualifiers[qualifier_index] {
801
484410
        ComprehensionQualifier::Generator { ptr } => {
802
484410
            let domain = ptr.domain()?;
803
484410
            let generator_values = domain.resolve()?.values().ok()?.collect_vec();
804

            
805
494882
            for value in generator_values {
806
494882
                with_temporary_quantified_binding(ptr, &value, || {
807
494882
                    eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)
808
494882
                })?;
809
            }
810
        }
811
578
        ComprehensionQualifier::ExpressionGenerator { ptr } => {
812
            // clone immediately so the read lock guard is dropped
813
578
            let expr = ptr.as_quantified_expr()?.clone();
814
578
            let generator_values = generator_values_from_expr(&expr)?;
815

            
816
8
            for value in generator_values {
817
8
                with_temporary_quantified_binding(ptr, &value, || {
818
8
                    eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)
819
8
                })?;
820
            }
821
        }
822
71438
        ComprehensionQualifier::Condition(condition) => match eval_constant(condition)? {
823
            Lit::Bool(true) => {
824
53064
                eval_comprehension_qualifiers(comprehension, qualifier_index + 1, values)?
825
            }
826
5420
            Lit::Bool(false) => {}
827
            _ => return None,
828
        },
829
    }
830

            
831
14048
    Some(())
832
962082
}
833

            
834
578
fn generator_values_from_expr(expr: &Expr) -> Option<Vec<Lit>> {
835
578
    match eval_constant(expr)? {
836
8
        Lit::AbstractLiteral(AbstractLiteral::Set(values))
837
        | Lit::AbstractLiteral(AbstractLiteral::MSet(values))
838
8
        | Lit::AbstractLiteral(AbstractLiteral::Tuple(values)) => Some(values),
839
        Lit::AbstractLiteral(list) => list.unwrap_list().cloned(),
840
        _ => None,
841
    }
842
578
}
843

            
844
494890
fn with_temporary_quantified_binding<T>(
845
494890
    quantified: &crate::ast::DeclarationPtr,
846
494890
    value: &Lit,
847
494890
    f: impl FnOnce() -> Option<T>,
848
494890
) -> Option<T> {
849
494890
    let mut targets = vec![quantified.clone()];
850
494890
    if let DeclarationKind::Quantified(inner) = &*quantified.kind()
851
494882
        && let Some(generator) = inner.generator()
852
    {
853
        targets.push(generator.clone());
854
494890
    }
855

            
856
494890
    let mut originals = Vec::with_capacity(targets.len());
857
494890
    for mut target in targets {
858
494890
        let old_kind = target.replace_kind(DeclarationKind::TemporaryValueLetting(Expr::Atomic(
859
494890
            Metadata::new(),
860
494890
            Atom::Literal(value.clone()),
861
494890
        )));
862
494890
        originals.push((target, old_kind));
863
494890
    }
864

            
865
494890
    let result = f();
866

            
867
494890
    for (mut target, old_kind) in originals.into_iter().rev() {
868
494890
        let _ = target.replace_kind(old_kind);
869
494890
    }
870

            
871
494890
    result
872
494890
}