1
#![allow(dead_code)]
2
use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3
use crate::into_matrix;
4
use itertools::{Itertools as _, izip};
5
use std::cmp::Ordering as CmpOrdering;
6
use std::collections::HashSet;
7

            
8
/// Simplify an expression to a constant if possible
9
/// Returns:
10
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11
/// `Some(Const)` if the expression can be simplified to a constant
12
30020568
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13
15988630
    match expr {
14
464
        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15
            (
16
464
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17
464
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18
            ) => {
19
464
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20
464
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21

            
22
464
                if a_set.difference(&b_set).count() > 0 {
23
348
                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24
                } else {
25
116
                    Some(Lit::Bool(false))
26
                }
27
            }
28
            _ => None,
29
        },
30
464
        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31
            (
32
464
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33
464
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34
464
            ) => Some(Lit::Bool(
35
464
                a.iter()
36
464
                    .cloned()
37
464
                    .collect::<HashSet<Lit>>()
38
464
                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39
464
            )),
40
            _ => None,
41
        },
42
580
        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43
            (
44
580
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45
580
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46
            ) => {
47
580
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48
580
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49

            
50
580
                if b_set.difference(&a_set).count() > 0 {
51
464
                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52
                } else {
53
116
                    Some(Lit::Bool(false))
54
                }
55
            }
56
            _ => None,
57
        },
58
1160
        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59
            (
60
1160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61
1160
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62
1160
            ) => Some(Lit::Bool(
63
1160
                a.iter()
64
1160
                    .cloned()
65
1160
                    .collect::<HashSet<Lit>>()
66
1160
                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67
1160
            )),
68
            _ => None,
69
        },
70
348
        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71
            (
72
348
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73
348
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74
            ) => {
75
348
                let mut res: Vec<Lit> = Vec::new();
76
812
                for lit in a.iter() {
77
812
                    if b.contains(lit) && !res.contains(lit) {
78
580
                        res.push(lit.clone());
79
580
                    }
80
                }
81
348
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82
            }
83
            _ => None,
84
        },
85
348
        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86
            (
87
348
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88
348
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89
            ) => {
90
348
                let mut res: Vec<Lit> = Vec::new();
91
928
                for lit in a.iter() {
92
928
                    res.push(lit.clone());
93
928
                }
94
928
                for lit in b.iter() {
95
928
                    if !res.contains(lit) {
96
696
                        res.push(lit.clone());
97
696
                    }
98
                }
99
348
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100
            }
101
            _ => None,
102
        },
103
3306
        Expr::In(_, a, b) => {
104
            if let (
105
116
                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106
116
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107
3306
            ) = (a.as_ref(), b.as_ref())
108
            {
109
348
                for lit in d.iter() {
110
348
                    if let Lit::Int(x) = lit
111
348
                        && c == x
112
                    {
113
116
                        return Some(Lit::Bool(true));
114
232
                    }
115
                }
116
                Some(Lit::Bool(false))
117
            } else {
118
3190
                None
119
            }
120
        }
121
        Expr::FromSolution(_, _) => None,
122
        Expr::DominanceRelation(_, _) => None,
123
65428
        Expr::InDomain(_, e, domain) => {
124
65428
            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125
37516
                return None;
126
            };
127

            
128
27912
            domain.contains(lit).ok().map(Into::into)
129
        }
130
5728276
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131
10260354
        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
132
2396068
        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
133
66658
        Expr::Comprehension(_, _) => None,
134
116
        Expr::AbstractComprehension(_, _) => None,
135
1835506
        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
136
3668242
            let subject: Lit = eval_constant(subject.as_ref())?;
137
6786
            let indices: Vec<Lit> = indices
138
6786
                .iter()
139
6786
                .map(eval_constant)
140
6786
                .collect::<Option<Vec<Lit>>>()?;
141

            
142
1102
            match subject {
143
696
                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
144
696
                    matrix::flatten_enumerate(subject)
145
2436
                        .find(|(i, _)| i == &indices)
146
696
                        .map(|(_, x)| x)
147
                }
148
290
                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
149
290
                    let AbstractLiteral::Tuple(elems) = subject else {
150
                        return None;
151
                    };
152

            
153
290
                    assert!(indices.len() == 1, "nested tuples not supported yet");
154

            
155
290
                    let Lit::Int(index) = indices[0].clone() else {
156
                        return None;
157
                    };
158

            
159
290
                    if elems.len() < index as usize || index < 1 {
160
                        return None;
161
290
                    }
162

            
163
                    // -1 for 0-indexing vs 1-indexing
164
290
                    let item = elems[index as usize - 1].clone();
165

            
166
290
                    Some(item)
167
                }
168
116
                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
169
116
                    let AbstractLiteral::Record(elems) = subject else {
170
                        return None;
171
                    };
172

            
173
116
                    assert!(indices.len() == 1, "nested record not supported yet");
174

            
175
116
                    let Lit::Int(index) = indices[0].clone() else {
176
                        return None;
177
                    };
178

            
179
116
                    if elems.len() < index as usize || index < 1 {
180
                        return None;
181
116
                    }
182

            
183
                    // -1 for 0-indexing vs 1-indexing
184
116
                    let item = elems[index as usize - 1].clone();
185
116
                    Some(item.value)
186
                }
187
                _ => None,
188
            }
189
        }
190
66120
        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
191
128644
            let subject: Lit = eval_constant(subject.as_ref())?;
192
58
            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
193
                return None;
194
            };
195

            
196
58
            let hole_dim = indices
197
58
                .iter()
198
58
                .cloned()
199
116
                .position(|x| x.is_none())
200
58
                .expect("slice expression should have a hole dimension");
201

            
202
58
            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
203

            
204
58
            let indices: Vec<Option<Lit>> = indices
205
58
                .iter()
206
58
                .cloned()
207
116
                .map(|x| {
208
                    // the outer option represents success of this iterator, the inner the index
209
                    // slice.
210
116
                    match x {
211
58
                        Some(x) => eval_constant(&x).map(Some),
212
58
                        None => Some(None),
213
                    }
214
116
                })
215
58
                .collect::<Option<Vec<Option<Lit>>>>()?;
216

            
217
58
            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
218
58
                .values()
219
58
                .ok()?
220
174
                .map(|i| {
221
174
                    let mut indices = indices.clone();
222
174
                    indices[hole_dim] = Some(i);
223
                    // These unwraps will only fail if we have multiple holes.
224
                    // As this is invalid, panicking is fine.
225
348
                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
226
174
                })
227
58
                .collect_vec();
228

            
229
            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
230
58
            let elems = matrix::flatten_enumerate(subject)
231
522
                .filter(|(i, _)| indices_in_slice.contains(i))
232
58
                .map(|(_, elem)| elem)
233
58
                .collect();
234

            
235
58
            Some(Lit::AbstractLiteral(into_matrix![elems]))
236
        }
237
12064
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
238
966808
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
239
966808
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
240
966808
            .map(Lit::Bool),
241
590708
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
242
95716
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
243
6960
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
244
447312
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
245
130926
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
246
28804
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
247
572318
        Expr::And(_, e) => {
248
572318
            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
249
        }
250
1392
        Expr::Table(_, _, _) => None,
251
232
        Expr::NegativeTable(_, _, _) => None,
252
395690
        Expr::Root(_, _) => None,
253
251938
        Expr::Or(_, es) => {
254
            // possibly cheating; definitely should be in partial eval instead
255
251938
            for e in (**es).clone().unwrap_list()? {
256
1474
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
257
894
                    return Some(Lit::Bool(true));
258
226196
                };
259
            }
260

            
261
106640
            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
262
        }
263
169542
        Expr::Imply(_, box1, box2) => {
264
169542
            let a: &Atom = (&**box1).try_into().ok()?;
265
128224
            let b: &Atom = (&**box2).try_into().ok()?;
266

            
267
6280
            let a: bool = a.try_into().ok()?;
268
3380
            let b: bool = b.try_into().ok()?;
269

            
270
3380
            if a {
271
                // true -> b ~> b
272
1820
                Some(Lit::Bool(b))
273
            } else {
274
                // false -> b ~> true
275
1560
                Some(Lit::Bool(true))
276
            }
277
        }
278
2842
        Expr::Iff(_, box1, box2) => {
279
2842
            let a: &Atom = (&**box1).try_into().ok()?;
280
2378
            let b: &Atom = (&**box2).try_into().ok()?;
281

            
282
174
            let a: bool = a.try_into().ok()?;
283
58
            let b: bool = b.try_into().ok()?;
284

            
285
            Some(Lit::Bool(a == b))
286
        }
287
1182378
        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
288
343418
        Expr::Product(_, exprs) => {
289
343418
            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
290
        }
291
100592
        Expr::FlatIneq(_, a, b, c) => {
292
100592
            let a: i32 = a.try_into().ok()?;
293
33524
            let b: i32 = b.try_into().ok()?;
294
            let c: i32 = c.try_into().ok()?;
295

            
296
            Some(Lit::Bool(a <= b + c))
297
        }
298
170712
        Expr::FlatSumGeq(_, exprs, a) => {
299
323238
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
300
323238
                let n: i32 = atom.try_into().ok()?;
301
152526
                let acc = acc + n;
302
152526
                Some(acc)
303
323238
            })?;
304

            
305
            Some(Lit::Bool(sum >= a.try_into().ok()?))
306
        }
307
201276
        Expr::FlatSumLeq(_, exprs, a) => {
308
371192
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
309
371192
                let n: i32 = atom.try_into().ok()?;
310
169916
                let acc = acc + n;
311
169916
                Some(acc)
312
371192
            })?;
313

            
314
            Some(Lit::Bool(sum >= a.try_into().ok()?))
315
        }
316
27492
        Expr::Min(_, e) => {
317
27492
            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
318
        }
319
25868
        Expr::Max(_, e) => {
320
25868
            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
321
        }
322
66410
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
323
114550
            if unwrap_expr::<i32>(b)? == 0 {
324
116
                return None;
325
10150
            }
326
10150
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
327
        }
328
26332
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
329
34974
            if unwrap_expr::<i32>(b)? == 0 {
330
                return None;
331
3770
            }
332
3770
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
333
3770
                .map(Lit::Int)
334
        }
335
6786
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
336
            // div always rounds down
337
6786
            let a: i32 = a.try_into().ok()?;
338
116
            let b: i32 = b.try_into().ok()?;
339
            let c: i32 = c.try_into().ok()?;
340

            
341
            if b == 0 {
342
                return None;
343
            }
344

            
345
            let a = a as f32;
346
            let b = b as f32;
347
            let div: i32 = (a / b).floor() as i32;
348
            Some(Lit::Bool(div == c))
349
        }
350
57448
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
351
71050
        Expr::MinionReify(_, a, b) => {
352
71050
            let result = eval_constant(a)?;
353

            
354
9512
            let result: bool = result.try_into().ok()?;
355
9512
            let b: bool = b.try_into().ok()?;
356

            
357
            Some(Lit::Bool(b == result))
358
        }
359
31196
        Expr::MinionReifyImply(_, a, b) => {
360
31196
            let result = eval_constant(a)?;
361

            
362
            let result: bool = result.try_into().ok()?;
363
            let b: bool = b.try_into().ok()?;
364

            
365
            if b {
366
                Some(Lit::Bool(result))
367
            } else {
368
                Some(Lit::Bool(true))
369
            }
370
        }
371
2146
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
372
            // From Savile Row. Same semantics as division.
373
            //
374
            //   a - (b * floor(a/b))
375
            //
376
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
377
            // down instead, not towards zero.
378

            
379
2146
            let a: i32 = a.try_into().ok()?;
380
116
            let b: i32 = b.try_into().ok()?;
381
            let c: i32 = c.try_into().ok()?;
382

            
383
            if b == 0 {
384
                return None;
385
            }
386

            
387
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
388
            Some(Lit::Bool(modulo == c))
389
        }
390
4048
        Expr::MinionPow(_, a, b, c) => {
391
            // only available for positive a b c
392

            
393
4048
            let a: i32 = a.try_into().ok()?;
394
            let b: i32 = b.try_into().ok()?;
395
            let c: i32 = c.try_into().ok()?;
396

            
397
            if a <= 0 {
398
                return None;
399
            }
400

            
401
            if b <= 0 {
402
                return None;
403
            }
404

            
405
            if c <= 0 {
406
                return None;
407
            }
408

            
409
            Some(Lit::Bool(a ^ b == c))
410
        }
411
464
        Expr::MinionWInSet(_, _, _) => None,
412
1102
        Expr::MinionWInIntervalSet(_, x, intervals) => {
413
1102
            let x_lit: &Lit = x.try_into().ok()?;
414

            
415
            let x_lit = match x_lit.clone() {
416
                Lit::Int(i) => Some(i),
417
                Lit::Bool(true) => Some(1),
418
                Lit::Bool(false) => Some(0),
419
                _ => None,
420
            }?;
421

            
422
            let mut intervals = intervals.iter();
423
            loop {
424
                let Some(lower) = intervals.next() else {
425
                    break;
426
                };
427

            
428
                let Some(upper) = intervals.next() else {
429
                    break;
430
                };
431
                if &x_lit >= lower && &x_lit <= upper {
432
                    return Some(Lit::Bool(true));
433
                }
434
            }
435

            
436
            Some(Lit::Bool(false))
437
        }
438
        Expr::Flatten(_, _, _) => {
439
            // TODO
440
11600
            None
441
        }
442
70760
        Expr::AllDiff(_, e) => {
443
70760
            let es = (**e).clone().unwrap_list()?;
444
6670
            let mut lits: HashSet<Lit> = HashSet::new();
445
7482
            for expr in es {
446
2958
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
447
6438
                    return None;
448
                };
449
1044
                match x {
450
                    Lit::Int(_) | Lit::Bool(_) => {
451
1044
                        if lits.contains(&x) {
452
                            return Some(Lit::Bool(false));
453
1044
                        } else {
454
1044
                            lits.insert(x.clone());
455
1044
                        }
456
                    }
457
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
458
                }
459
            }
460
232
            Some(Lit::Bool(true))
461
        }
462
44022
        Expr::FlatAllDiff(_, es) => {
463
44022
            let mut lits: HashSet<Lit> = HashSet::new();
464
44022
            for atom in es {
465
44022
                let Atom::Literal(x) = atom else {
466
44022
                    return None;
467
                };
468

            
469
                match x {
470
                    Lit::Int(_) | Lit::Bool(_) => {
471
                        if lits.contains(x) {
472
                            return Some(Lit::Bool(false));
473
                        } else {
474
                            lits.insert(x.clone());
475
                        }
476
                    }
477
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
478
                }
479
            }
480
            Some(Lit::Bool(true))
481
        }
482
5162
        Expr::FlatWatchedLiteral(_, _, _) => None,
483
77944
        Expr::AuxDeclaration(_, _, _) => None,
484
79392
        Expr::Neg(_, a) => {
485
79392
            let a: &Atom = a.try_into().ok()?;
486
38212
            let a: i32 = a.try_into().ok()?;
487
13852
            Some(Lit::Int(-a))
488
        }
489
165588
        Expr::Minus(_, a, b) => {
490
165588
            let a: &Atom = a.try_into().ok()?;
491
156308
            let a: i32 = a.try_into().ok()?;
492

            
493
14964
            let b: &Atom = b.try_into().ok()?;
494
13920
            let b: i32 = b.try_into().ok()?;
495

            
496
10324
            Some(Lit::Int(a - b))
497
        }
498
3480
        Expr::FlatMinusEq(_, a, b) => {
499
3480
            let a: i32 = a.try_into().ok()?;
500
2784
            let b: i32 = b.try_into().ok()?;
501
            Some(Lit::Bool(a == -b))
502
        }
503
1856
        Expr::FlatProductEq(_, a, b, c) => {
504
1856
            let a: i32 = a.try_into().ok()?;
505
            let b: i32 = b.try_into().ok()?;
506
            let c: i32 = c.try_into().ok()?;
507
            Some(Lit::Bool(a * b == c))
508
        }
509
39150
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
510
39150
            let cs: Vec<i32> = cs
511
39150
                .iter()
512
82940
                .map(|x| TryInto::<i32>::try_into(x).ok())
513
39150
                .collect::<Option<Vec<i32>>>()?;
514
39150
            let vs: Vec<i32> = vs
515
39150
                .iter()
516
50286
                .map(|x| TryInto::<i32>::try_into(x).ok())
517
39150
                .collect::<Option<Vec<i32>>>()?;
518
            let total: i32 = total.try_into().ok()?;
519

            
520
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
521

            
522
            Some(Lit::Bool(sum <= total))
523
        }
524
36714
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
525
36714
            let cs: Vec<i32> = cs
526
36714
                .iter()
527
76734
                .map(|x| TryInto::<i32>::try_into(x).ok())
528
36714
                .collect::<Option<Vec<i32>>>()?;
529
36714
            let vs: Vec<i32> = vs
530
36714
                .iter()
531
46922
                .map(|x| TryInto::<i32>::try_into(x).ok())
532
36714
                .collect::<Option<Vec<i32>>>()?;
533
            let total: i32 = total.try_into().ok()?;
534

            
535
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
536

            
537
            Some(Lit::Bool(sum >= total))
538
        }
539
1218
        Expr::FlatAbsEq(_, x, y) => {
540
1218
            let x: i32 = x.try_into().ok()?;
541
58
            let y: i32 = y.try_into().ok()?;
542

            
543
            Some(Lit::Bool(x == y.abs()))
544
        }
545
360858
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
546
377318
            let a: &Atom = a.try_into().ok()?;
547
375404
            let a: i32 = a.try_into().ok()?;
548

            
549
348732
            let b: &Atom = b.try_into().ok()?;
550
348732
            let b: i32 = b.try_into().ok()?;
551

            
552
348732
            if (a != 0 || b != 0) && b >= 0 {
553
348732
                Some(Lit::Int(a.pow(b as u32)))
554
            } else {
555
                None
556
            }
557
        }
558
        Expr::Metavar(_, _) => None,
559
38180
        Expr::MinionElementOne(_, _, _, _) => None,
560
1508
        Expr::ToInt(_, expression) => {
561
1508
            let lit = eval_constant(expression.as_ref())?;
562
            match lit {
563
                Lit::Int(_) => Some(lit),
564
                Lit::Bool(true) => Some(Lit::Int(1)),
565
                Lit::Bool(false) => Some(Lit::Int(0)),
566
                _ => None,
567
            }
568
        }
569
        Expr::SATInt(_, _, _, _) => {
570
            // TODO: If this SATInt is composed of literals, we should evaluate it back to an
571
            // integer literal.
572
            //
573
            // This is important because `is_all_constant` currently returns true for SATInts
574
            // containing no references. If we don't evaluate them here, bubble rules will skip
575
            // them (thinking they'll be constant-folded later), but they'll actually reach
576
            // the solver adaptors as un-encoded unsafe operations, causing panics.
577
625838
            None
578
        }
579
        Expr::PairwiseSum(_, a, b) => {
580
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
581
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
582
                _ => None,
583
            }
584
        }
585
        Expr::PairwiseProduct(_, a, b) => {
586
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
587
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
588
                _ => None,
589
            }
590
        }
591
        Expr::Defined(_, _) => todo!(),
592
        Expr::Range(_, _) => todo!(),
593
        Expr::Image(_, _, _) => todo!(),
594
        Expr::ImageSet(_, _, _) => todo!(),
595
        Expr::PreImage(_, _, _) => todo!(),
596
        Expr::Inverse(_, _, _) => todo!(),
597
        Expr::Restrict(_, _, _) => todo!(),
598
812
        Expr::LexLt(_, a, b) => {
599
812
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
600
232
                pairs
601
232
                    .iter()
602
464
                    .find_map(|(a, b)| match a.cmp(b) {
603
116
                        CmpOrdering::Less => Some(true),     // First difference is <
604
                        CmpOrdering::Greater => Some(false), // First difference is >
605
348
                        CmpOrdering::Equal => None,          // No difference
606
464
                    })
607
232
                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
608
580
            })?;
609
232
            Some(lt.into())
610
        }
611
70470
        Expr::LexLeq(_, a, b) => {
612
70470
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
613
116
                pairs
614
116
                    .iter()
615
232
                    .find_map(|(a, b)| match a.cmp(b) {
616
116
                        CmpOrdering::Less => Some(true),
617
                        CmpOrdering::Greater => Some(false),
618
116
                        CmpOrdering::Equal => None,
619
232
                    })
620
116
                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
621
70354
            })?;
622
116
            Some(lt.into())
623
        }
624
        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
625
        Expr::LexGeq(_, a, b) => {
626
            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
627
        }
628
116
        Expr::FlatLexLt(_, a, b) => {
629
116
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
630
                pairs
631
                    .iter()
632
                    .find_map(|(a, b)| match a.cmp(b) {
633
                        CmpOrdering::Less => Some(true),
634
                        CmpOrdering::Greater => Some(false),
635
                        CmpOrdering::Equal => None,
636
                    })
637
                    .unwrap_or(a_len < b_len)
638
116
            })?;
639
            Some(lt.into())
640
        }
641
232
        Expr::FlatLexLeq(_, a, b) => {
642
232
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
643
                pairs
644
                    .iter()
645
                    .find_map(|(a, b)| match a.cmp(b) {
646
                        CmpOrdering::Less => Some(true),
647
                        CmpOrdering::Greater => Some(false),
648
                        CmpOrdering::Equal => None,
649
                    })
650
                    .unwrap_or(a_len <= b_len)
651
232
            })?;
652
            Some(lt.into())
653
        }
654
    }
655
30020568
}
656

            
657
40868
pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
658
40868
where
659
40868
    T: TryFrom<Lit>,
660
{
661
40868
    let a = unwrap_expr::<T>(a)?;
662
3630
    Some(f(a))
663
40868
}
664

            
665
3156360
pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
666
3156360
where
667
3156360
    T: TryFrom<Lit>,
668
{
669
3156360
    let a = unwrap_expr::<T>(a)?;
670
440452
    let b = unwrap_expr::<T>(b)?;
671
364124
    Some(f(a, b))
672
3156360
}
673

            
674
#[allow(dead_code)]
675
pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
676
where
677
    T: TryFrom<Lit>,
678
{
679
    let a = unwrap_expr::<T>(a)?;
680
    let b = unwrap_expr::<T>(b)?;
681
    let c = unwrap_expr::<T>(c)?;
682
    Some(f(a, b, c))
683
}
684

            
685
40993
pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
686
40993
where
687
40993
    T: TryFrom<Lit>,
688
{
689
40993
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
690
18
    Some(f(a))
691
40993
}
692

            
693
2204754
pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
694
2204754
where
695
2204754
    T: TryFrom<Lit>,
696
{
697
    // we don't care about preserving indices here, as we will be getting rid of the vector
698
    // anyways!
699
2204754
    let a = a.clone().unwrap_matrix_unchecked()?.0;
700
2094074
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
701
399972
    Some(f(a))
702
2204754
}
703

            
704
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
705

            
706
/// Calls the given function on each consecutive pair of elements in the list expressions.
707
/// Also passes the length of the two lists.
708
71282
fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
709
71282
where
710
71282
    T: TryFrom<Lit>,
711
{
712
71282
    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
713
1044
    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
714
696
    let lens = (a_exprs.len(), b_exprs.len());
715

            
716
696
    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
717
1044
        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
718
696
        .collect::<Option<Vec<(T, T)>>>()?;
719
348
    Some(f(lit_pairs, lens))
720
71282
}
721

            
722
/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
723
348
fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
724
348
where
725
348
    T: TryFrom<Atom>,
726
{
727
348
    let lit_pairs = Iterator::zip(a.iter(), b.iter())
728
348
        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
729
348
        .collect::<Option<Vec<(T, T)>>>()?;
730
    Some(f(lit_pairs, (a.len(), b.len())))
731
348
}
732

            
733
pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
734
where
735
    T: TryFrom<Lit>,
736
{
737
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
738
    f(a)
739
}
740

            
741
53360
pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
742
53360
where
743
53360
    T: TryFrom<Lit>,
744
{
745
53360
    let a = a.clone().unwrap_list()?;
746
    // FIXME: deal with explicit matrix domains
747
3886
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
748
58
    f(a)
749
53360
}
750

            
751
#[allow(dead_code)]
752
pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
753
where
754
    T: TryFrom<Lit>,
755
{
756
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
757
    let b = unwrap_expr::<T>(b)?;
758
    Some(f(a, b))
759
}
760

            
761
6749062
pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
762
6749062
    let c = eval_constant(expr)?;
763
1943060
    TryInto::<T>::try_into(c).ok()
764
6749062
}