1
#![allow(dead_code)]
2
use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3
use crate::into_matrix;
4
use itertools::{Itertools as _, izip};
5
use std::cmp::Ordering as CmpOrdering;
6
use std::collections::HashSet;
7

            
8
/// Simplify an expression to a constant if possible
9
/// Returns:
10
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11
/// `Some(Const)` if the expression can be simplified to a constant
12
20076620
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13
10630520
    match expr {
14
320
        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15
            (
16
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18
            ) => {
19
320
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20
320
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21

            
22
320
                if a_set.difference(&b_set).count() > 0 {
23
240
                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24
                } else {
25
80
                    Some(Lit::Bool(false))
26
                }
27
            }
28
            _ => None,
29
        },
30
320
        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31
            (
32
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33
320
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34
320
            ) => Some(Lit::Bool(
35
320
                a.iter()
36
320
                    .cloned()
37
320
                    .collect::<HashSet<Lit>>()
38
320
                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39
320
            )),
40
            _ => None,
41
        },
42
400
        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43
            (
44
400
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45
400
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46
            ) => {
47
400
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48
400
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49

            
50
400
                if b_set.difference(&a_set).count() > 0 {
51
320
                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52
                } else {
53
80
                    Some(Lit::Bool(false))
54
                }
55
            }
56
            _ => None,
57
        },
58
800
        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59
            (
60
800
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61
800
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62
800
            ) => Some(Lit::Bool(
63
800
                a.iter()
64
800
                    .cloned()
65
800
                    .collect::<HashSet<Lit>>()
66
800
                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67
800
            )),
68
            _ => None,
69
        },
70
240
        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71
            (
72
240
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73
240
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74
            ) => {
75
240
                let mut res: Vec<Lit> = Vec::new();
76
560
                for lit in a.iter() {
77
560
                    if b.contains(lit) && !res.contains(lit) {
78
400
                        res.push(lit.clone());
79
400
                    }
80
                }
81
240
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82
            }
83
            _ => None,
84
        },
85
240
        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86
            (
87
240
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88
240
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89
            ) => {
90
240
                let mut res: Vec<Lit> = Vec::new();
91
640
                for lit in a.iter() {
92
640
                    res.push(lit.clone());
93
640
                }
94
640
                for lit in b.iter() {
95
640
                    if !res.contains(lit) {
96
480
                        res.push(lit.clone());
97
480
                    }
98
                }
99
240
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100
            }
101
            _ => None,
102
        },
103
2280
        Expr::In(_, a, b) => {
104
            if let (
105
80
                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106
80
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107
2280
            ) = (a.as_ref(), b.as_ref())
108
            {
109
240
                for lit in d.iter() {
110
240
                    if let Lit::Int(x) = lit
111
240
                        && c == x
112
                    {
113
80
                        return Some(Lit::Bool(true));
114
160
                    }
115
                }
116
                Some(Lit::Bool(false))
117
            } else {
118
2200
                None
119
            }
120
        }
121
        Expr::FromSolution(_, _) => None,
122
        Expr::DominanceRelation(_, _) => None,
123
45040
        Expr::InDomain(_, e, domain) => {
124
45040
            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125
25840
                return None;
126
            };
127

            
128
19200
            domain.contains(lit).ok().map(Into::into)
129
        }
130
3928820
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131
6701700
        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
132
1614160
        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
133
45960
        Expr::Comprehension(_, _) => None,
134
80
        Expr::AbstractComprehension(_, _) => None,
135
1262920
        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
136
2525080
            let subject: Lit = eval_constant(subject.as_ref())?;
137
4680
            let indices: Vec<Lit> = indices
138
4680
                .iter()
139
4680
                .map(eval_constant)
140
4680
                .collect::<Option<Vec<Lit>>>()?;
141

            
142
760
            match subject {
143
480
                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
144
480
                    matrix::flatten_enumerate(subject)
145
1680
                        .find(|(i, _)| i == &indices)
146
480
                        .map(|(_, x)| x)
147
                }
148
200
                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
149
200
                    let AbstractLiteral::Tuple(elems) = subject else {
150
                        return None;
151
                    };
152

            
153
200
                    assert!(indices.len() == 1, "nested tuples not supported yet");
154

            
155
200
                    let Lit::Int(index) = indices[0].clone() else {
156
                        return None;
157
                    };
158

            
159
200
                    if elems.len() < index as usize || index < 1 {
160
                        return None;
161
200
                    }
162

            
163
                    // -1 for 0-indexing vs 1-indexing
164
200
                    let item = elems[index as usize - 1].clone();
165

            
166
200
                    Some(item)
167
                }
168
80
                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
169
80
                    let AbstractLiteral::Record(elems) = subject else {
170
                        return None;
171
                    };
172

            
173
80
                    assert!(indices.len() == 1, "nested record not supported yet");
174

            
175
80
                    let Lit::Int(index) = indices[0].clone() else {
176
                        return None;
177
                    };
178

            
179
80
                    if elems.len() < index as usize || index < 1 {
180
                        return None;
181
80
                    }
182

            
183
                    // -1 for 0-indexing vs 1-indexing
184
80
                    let item = elems[index as usize - 1].clone();
185
80
                    Some(item.value)
186
                }
187
                _ => None,
188
            }
189
        }
190
45600
        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
191
88720
            let subject: Lit = eval_constant(subject.as_ref())?;
192
40
            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
193
                return None;
194
            };
195

            
196
40
            let hole_dim = indices
197
40
                .iter()
198
40
                .cloned()
199
80
                .position(|x| x.is_none())
200
40
                .expect("slice expression should have a hole dimension");
201

            
202
40
            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
203

            
204
40
            let indices: Vec<Option<Lit>> = indices
205
40
                .iter()
206
40
                .cloned()
207
80
                .map(|x| {
208
                    // the outer option represents success of this iterator, the inner the index
209
                    // slice.
210
80
                    match x {
211
40
                        Some(x) => eval_constant(&x).map(Some),
212
40
                        None => Some(None),
213
                    }
214
80
                })
215
40
                .collect::<Option<Vec<Option<Lit>>>>()?;
216

            
217
40
            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
218
40
                .values()
219
40
                .ok()?
220
120
                .map(|i| {
221
120
                    let mut indices = indices.clone();
222
120
                    indices[hole_dim] = Some(i);
223
                    // These unwraps will only fail if we have multiple holes.
224
                    // As this is invalid, panicking is fine.
225
240
                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
226
120
                })
227
40
                .collect_vec();
228

            
229
            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
230
40
            let elems = matrix::flatten_enumerate(subject)
231
360
                .filter(|(i, _)| indices_in_slice.contains(i))
232
40
                .map(|(_, elem)| elem)
233
40
                .collect();
234

            
235
40
            Some(Lit::AbstractLiteral(into_matrix![elems]))
236
        }
237
8320
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
238
654000
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
239
654000
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
240
654000
            .map(Lit::Bool),
241
402040
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
242
63600
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
243
4560
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
244
305240
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
245
89440
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
246
14920
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
247
373960
        Expr::And(_, e) => {
248
373960
            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
249
        }
250
480
        Expr::Table(_, _, _) => None,
251
259660
        Expr::Root(_, _) => None,
252
164800
        Expr::Or(_, es) => {
253
            // possibly cheating; definitely should be in partial eval instead
254
164800
            for e in (**es).clone().unwrap_list()? {
255
1000
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
256
600
                    return Some(Lit::Bool(true));
257
145040
                };
258
            }
259

            
260
65720
            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
261
        }
262
114360
        Expr::Imply(_, box1, box2) => {
263
114360
            let a: &Atom = (&**box1).try_into().ok()?;
264
86080
            let b: &Atom = (&**box2).try_into().ok()?;
265

            
266
2000
            let a: bool = a.try_into().ok()?;
267
            let b: bool = b.try_into().ok()?;
268

            
269
            if a {
270
                // true -> b ~> b
271
                Some(Lit::Bool(b))
272
            } else {
273
                // false -> b ~> true
274
                Some(Lit::Bool(true))
275
            }
276
        }
277
1960
        Expr::Iff(_, box1, box2) => {
278
1960
            let a: &Atom = (&**box1).try_into().ok()?;
279
1640
            let b: &Atom = (&**box2).try_into().ok()?;
280

            
281
120
            let a: bool = a.try_into().ok()?;
282
40
            let b: bool = b.try_into().ok()?;
283

            
284
            Some(Lit::Bool(a == b))
285
        }
286
811940
        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
287
236840
        Expr::Product(_, exprs) => {
288
236840
            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
289
        }
290
69280
        Expr::FlatIneq(_, a, b, c) => {
291
69280
            let a: i32 = a.try_into().ok()?;
292
23120
            let b: i32 = b.try_into().ok()?;
293
            let c: i32 = c.try_into().ok()?;
294

            
295
            Some(Lit::Bool(a <= b + c))
296
        }
297
116800
        Expr::FlatSumGeq(_, exprs, a) => {
298
221080
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
299
221080
                let n: i32 = atom.try_into().ok()?;
300
104280
                let acc = acc + n;
301
104280
                Some(acc)
302
221080
            })?;
303

            
304
            Some(Lit::Bool(sum >= a.try_into().ok()?))
305
        }
306
137840
        Expr::FlatSumLeq(_, exprs, a) => {
307
254080
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
308
254080
                let n: i32 = atom.try_into().ok()?;
309
116240
                let acc = acc + n;
310
116240
                Some(acc)
311
254080
            })?;
312

            
313
            Some(Lit::Bool(sum >= a.try_into().ok()?))
314
        }
315
18960
        Expr::Min(_, e) => {
316
18960
            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
317
        }
318
17840
        Expr::Max(_, e) => {
319
17840
            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
320
        }
321
31240
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
322
39720
            if unwrap_expr::<i32>(b)? == 0 {
323
80
                return None;
324
7000
            }
325
7000
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
326
        }
327
18160
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
328
24120
            if unwrap_expr::<i32>(b)? == 0 {
329
                return None;
330
2600
            }
331
2600
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
332
2600
                .map(Lit::Int)
333
        }
334
4480
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
335
            // div always rounds down
336
4480
            let a: i32 = a.try_into().ok()?;
337
80
            let b: i32 = b.try_into().ok()?;
338
            let c: i32 = c.try_into().ok()?;
339

            
340
            if b == 0 {
341
                return None;
342
            }
343

            
344
            let a = a as f32;
345
            let b = b as f32;
346
            let div: i32 = (a / b).floor() as i32;
347
            Some(Lit::Bool(div == c))
348
        }
349
36080
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
350
48520
        Expr::MinionReify(_, a, b) => {
351
48520
            let result = eval_constant(a)?;
352

            
353
6560
            let result: bool = result.try_into().ok()?;
354
6560
            let b: bool = b.try_into().ok()?;
355

            
356
            Some(Lit::Bool(b == result))
357
        }
358
21040
        Expr::MinionReifyImply(_, a, b) => {
359
21040
            let result = eval_constant(a)?;
360

            
361
            let result: bool = result.try_into().ok()?;
362
            let b: bool = b.try_into().ok()?;
363

            
364
            if b {
365
                Some(Lit::Bool(result))
366
            } else {
367
                Some(Lit::Bool(true))
368
            }
369
        }
370
1480
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
371
            // From Savile Row. Same semantics as division.
372
            //
373
            //   a - (b * floor(a/b))
374
            //
375
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
376
            // down instead, not towards zero.
377

            
378
1480
            let a: i32 = a.try_into().ok()?;
379
80
            let b: i32 = b.try_into().ok()?;
380
            let c: i32 = c.try_into().ok()?;
381

            
382
            if b == 0 {
383
                return None;
384
            }
385

            
386
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
387
            Some(Lit::Bool(modulo == c))
388
        }
389
2720
        Expr::MinionPow(_, a, b, c) => {
390
            // only available for positive a b c
391

            
392
2720
            let a: i32 = a.try_into().ok()?;
393
            let b: i32 = b.try_into().ok()?;
394
            let c: i32 = c.try_into().ok()?;
395

            
396
            if a <= 0 {
397
                return None;
398
            }
399

            
400
            if b <= 0 {
401
                return None;
402
            }
403

            
404
            if c <= 0 {
405
                return None;
406
            }
407

            
408
            Some(Lit::Bool(a ^ b == c))
409
        }
410
320
        Expr::MinionWInSet(_, _, _) => None,
411
760
        Expr::MinionWInIntervalSet(_, x, intervals) => {
412
760
            let x_lit: &Lit = x.try_into().ok()?;
413

            
414
            let x_lit = match x_lit.clone() {
415
                Lit::Int(i) => Some(i),
416
                Lit::Bool(true) => Some(1),
417
                Lit::Bool(false) => Some(0),
418
                _ => None,
419
            }?;
420

            
421
            let mut intervals = intervals.iter();
422
            loop {
423
                let Some(lower) = intervals.next() else {
424
                    break;
425
                };
426

            
427
                let Some(upper) = intervals.next() else {
428
                    break;
429
                };
430
                if &x_lit >= lower && &x_lit <= upper {
431
                    return Some(Lit::Bool(true));
432
                }
433
            }
434

            
435
            Some(Lit::Bool(false))
436
        }
437
        Expr::Flatten(_, _, _) => {
438
            // TODO
439
8000
            None
440
        }
441
48800
        Expr::AllDiff(_, e) => {
442
48800
            let es = (**e).clone().unwrap_list()?;
443
4600
            let mut lits: HashSet<Lit> = HashSet::new();
444
5160
            for expr in es {
445
2040
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
446
4440
                    return None;
447
                };
448
720
                match x {
449
                    Lit::Int(_) | Lit::Bool(_) => {
450
720
                        if lits.contains(&x) {
451
                            return Some(Lit::Bool(false));
452
720
                        } else {
453
720
                            lits.insert(x.clone());
454
720
                        }
455
                    }
456
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
457
                }
458
            }
459
160
            Some(Lit::Bool(true))
460
        }
461
30360
        Expr::FlatAllDiff(_, es) => {
462
30360
            let mut lits: HashSet<Lit> = HashSet::new();
463
30360
            for atom in es {
464
30360
                let Atom::Literal(x) = atom else {
465
30360
                    return None;
466
                };
467

            
468
                match x {
469
                    Lit::Int(_) | Lit::Bool(_) => {
470
                        if lits.contains(x) {
471
                            return Some(Lit::Bool(false));
472
                        } else {
473
                            lits.insert(x.clone());
474
                        }
475
                    }
476
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
477
                }
478
            }
479
            Some(Lit::Bool(true))
480
        }
481
3560
        Expr::FlatWatchedLiteral(_, _, _) => None,
482
53640
        Expr::AuxDeclaration(_, _, _) => None,
483
56000
        Expr::Neg(_, a) => {
484
56000
            let a: &Atom = a.try_into().ok()?;
485
26520
            let a: i32 = a.try_into().ok()?;
486
9620
            Some(Lit::Int(-a))
487
        }
488
113980
        Expr::Minus(_, a, b) => {
489
113980
            let a: &Atom = a.try_into().ok()?;
490
107580
            let a: i32 = a.try_into().ok()?;
491

            
492
10300
            let b: &Atom = b.try_into().ok()?;
493
9580
            let b: i32 = b.try_into().ok()?;
494

            
495
7100
            Some(Lit::Int(a - b))
496
        }
497
2440
        Expr::FlatMinusEq(_, a, b) => {
498
2440
            let a: i32 = a.try_into().ok()?;
499
1920
            let b: i32 = b.try_into().ok()?;
500
            Some(Lit::Bool(a == -b))
501
        }
502
1280
        Expr::FlatProductEq(_, a, b, c) => {
503
1280
            let a: i32 = a.try_into().ok()?;
504
            let b: i32 = b.try_into().ok()?;
505
            let c: i32 = c.try_into().ok()?;
506
            Some(Lit::Bool(a * b == c))
507
        }
508
27000
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
509
27000
            let cs: Vec<i32> = cs
510
27000
                .iter()
511
57200
                .map(|x| TryInto::<i32>::try_into(x).ok())
512
27000
                .collect::<Option<Vec<i32>>>()?;
513
27000
            let vs: Vec<i32> = vs
514
27000
                .iter()
515
34680
                .map(|x| TryInto::<i32>::try_into(x).ok())
516
27000
                .collect::<Option<Vec<i32>>>()?;
517
            let total: i32 = total.try_into().ok()?;
518

            
519
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
520

            
521
            Some(Lit::Bool(sum <= total))
522
        }
523
25320
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
524
25320
            let cs: Vec<i32> = cs
525
25320
                .iter()
526
52920
                .map(|x| TryInto::<i32>::try_into(x).ok())
527
25320
                .collect::<Option<Vec<i32>>>()?;
528
25320
            let vs: Vec<i32> = vs
529
25320
                .iter()
530
32360
                .map(|x| TryInto::<i32>::try_into(x).ok())
531
25320
                .collect::<Option<Vec<i32>>>()?;
532
            let total: i32 = total.try_into().ok()?;
533

            
534
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
535

            
536
            Some(Lit::Bool(sum >= total))
537
        }
538
840
        Expr::FlatAbsEq(_, x, y) => {
539
840
            let x: i32 = x.try_into().ok()?;
540
40
            let y: i32 = y.try_into().ok()?;
541

            
542
            Some(Lit::Bool(x == y.abs()))
543
        }
544
245480
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
545
256520
            let a: &Atom = a.try_into().ok()?;
546
255200
            let a: i32 = a.try_into().ok()?;
547

            
548
240480
            let b: &Atom = b.try_into().ok()?;
549
240480
            let b: i32 = b.try_into().ok()?;
550

            
551
240480
            if (a != 0 || b != 0) && b >= 0 {
552
240480
                Some(Lit::Int(a.pow(b as u32)))
553
            } else {
554
                None
555
            }
556
        }
557
        Expr::Metavar(_, _) => None,
558
25520
        Expr::MinionElementOne(_, _, _, _) => None,
559
1040
        Expr::ToInt(_, expression) => {
560
1040
            let lit = eval_constant(expression.as_ref())?;
561
            match lit {
562
                Lit::Int(_) => Some(lit),
563
                Lit::Bool(true) => Some(Lit::Int(1)),
564
                Lit::Bool(false) => Some(Lit::Int(0)),
565
                _ => None,
566
            }
567
        }
568
372680
        Expr::SATInt(..) => None,
569
        Expr::PairwiseSum(_, a, b) => {
570
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
571
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
572
                _ => None,
573
            }
574
        }
575
        Expr::PairwiseProduct(_, a, b) => {
576
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
577
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
578
                _ => None,
579
            }
580
        }
581
        Expr::Defined(_, _) => todo!(),
582
        Expr::Range(_, _) => todo!(),
583
        Expr::Image(_, _, _) => todo!(),
584
        Expr::ImageSet(_, _, _) => todo!(),
585
        Expr::PreImage(_, _, _) => todo!(),
586
        Expr::Inverse(_, _, _) => todo!(),
587
        Expr::Restrict(_, _, _) => todo!(),
588
560
        Expr::LexLt(_, a, b) => {
589
560
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
590
160
                pairs
591
160
                    .iter()
592
320
                    .find_map(|(a, b)| match a.cmp(b) {
593
80
                        CmpOrdering::Less => Some(true),     // First difference is <
594
                        CmpOrdering::Greater => Some(false), // First difference is >
595
240
                        CmpOrdering::Equal => None,          // No difference
596
320
                    })
597
160
                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
598
400
            })?;
599
160
            Some(lt.into())
600
        }
601
48600
        Expr::LexLeq(_, a, b) => {
602
48600
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
603
80
                pairs
604
80
                    .iter()
605
160
                    .find_map(|(a, b)| match a.cmp(b) {
606
80
                        CmpOrdering::Less => Some(true),
607
                        CmpOrdering::Greater => Some(false),
608
80
                        CmpOrdering::Equal => None,
609
160
                    })
610
80
                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
611
48520
            })?;
612
80
            Some(lt.into())
613
        }
614
        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
615
        Expr::LexGeq(_, a, b) => {
616
            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
617
        }
618
80
        Expr::FlatLexLt(_, a, b) => {
619
80
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
620
                pairs
621
                    .iter()
622
                    .find_map(|(a, b)| match a.cmp(b) {
623
                        CmpOrdering::Less => Some(true),
624
                        CmpOrdering::Greater => Some(false),
625
                        CmpOrdering::Equal => None,
626
                    })
627
                    .unwrap_or(a_len < b_len)
628
80
            })?;
629
            Some(lt.into())
630
        }
631
160
        Expr::FlatLexLeq(_, a, b) => {
632
160
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
633
                pairs
634
                    .iter()
635
                    .find_map(|(a, b)| match a.cmp(b) {
636
                        CmpOrdering::Less => Some(true),
637
                        CmpOrdering::Greater => Some(false),
638
                        CmpOrdering::Equal => None,
639
                    })
640
                    .unwrap_or(a_len <= b_len)
641
160
            })?;
642
            Some(lt.into())
643
        }
644
    }
645
20076620
}
646

            
647
23240
pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
648
23240
where
649
23240
    T: TryFrom<Lit>,
650
{
651
23240
    let a = unwrap_expr::<T>(a)?;
652
200
    Some(f(a))
653
23240
}
654

            
655
2135640
pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
656
2135640
where
657
2135640
    T: TryFrom<Lit>,
658
{
659
2135640
    let a = unwrap_expr::<T>(a)?;
660
299080
    let b = unwrap_expr::<T>(b)?;
661
247200
    Some(f(a, b))
662
2135640
}
663

            
664
#[allow(dead_code)]
665
pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
666
where
667
    T: TryFrom<Lit>,
668
{
669
    let a = unwrap_expr::<T>(a)?;
670
    let b = unwrap_expr::<T>(b)?;
671
    let c = unwrap_expr::<T>(c)?;
672
    Some(f(a, b, c))
673
}
674

            
675
25815
pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
676
25815
where
677
25815
    T: TryFrom<Lit>,
678
{
679
25815
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
680
12
    Some(f(a))
681
25815
}
682

            
683
1488460
pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
684
1488460
where
685
1488460
    T: TryFrom<Lit>,
686
{
687
    // we don't care about preserving indices here, as we will be getting rid of the vector
688
    // anyways!
689
1488460
    let a = a.clone().unwrap_matrix_unchecked()?.0;
690
1412140
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
691
266540
    Some(f(a))
692
1488460
}
693

            
694
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
695

            
696
/// Calls the given function on each consecutive pair of elements in the list expressions.
697
/// Also passes the length of the two lists.
698
49160
fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
699
49160
where
700
49160
    T: TryFrom<Lit>,
701
{
702
49160
    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
703
720
    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
704
480
    let lens = (a_exprs.len(), b_exprs.len());
705

            
706
480
    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
707
720
        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
708
480
        .collect::<Option<Vec<(T, T)>>>()?;
709
240
    Some(f(lit_pairs, lens))
710
49160
}
711

            
712
/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
713
240
fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
714
240
where
715
240
    T: TryFrom<Atom>,
716
{
717
240
    let lit_pairs = Iterator::zip(a.iter(), b.iter())
718
240
        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
719
240
        .collect::<Option<Vec<(T, T)>>>()?;
720
    Some(f(lit_pairs, (a.len(), b.len())))
721
240
}
722

            
723
pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
724
where
725
    T: TryFrom<Lit>,
726
{
727
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
728
    f(a)
729
}
730

            
731
36800
pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
732
36800
where
733
36800
    T: TryFrom<Lit>,
734
{
735
36800
    let a = a.clone().unwrap_list()?;
736
    // FIXME: deal with explicit matrix domains
737
2680
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
738
40
    f(a)
739
36800
}
740

            
741
#[allow(dead_code)]
742
pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
743
where
744
    T: TryFrom<Lit>,
745
{
746
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
747
    let b = unwrap_expr::<T>(b)?;
748
    Some(f(a, b))
749
}
750

            
751
4514900
pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
752
4514900
    let c = eval_constant(expr)?;
753
1310000
    TryInto::<T>::try_into(c).ok()
754
4514900
}