1
#![allow(dead_code)]
2
use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3
use crate::into_matrix;
4
use itertools::{Itertools as _, izip};
5
use std::cmp::Ordering as CmpOrdering;
6
use std::collections::HashSet;
7

            
8
/// Simplify an expression to a constant if possible
9
/// Returns:
10
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11
/// `Some(Const)` if the expression can be simplified to a constant
12
26973898
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13
13683107
    match expr {
14
312
        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15
            (
16
312
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17
312
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18
            ) => {
19
312
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20
312
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21

            
22
312
                if a_set.difference(&b_set).count() > 0 {
23
234
                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24
                } else {
25
78
                    Some(Lit::Bool(false))
26
                }
27
            }
28
            _ => None,
29
        },
30
312
        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31
            (
32
312
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33
312
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34
312
            ) => Some(Lit::Bool(
35
312
                a.iter()
36
312
                    .cloned()
37
312
                    .collect::<HashSet<Lit>>()
38
312
                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39
312
            )),
40
            _ => None,
41
        },
42
390
        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43
            (
44
390
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45
390
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46
            ) => {
47
390
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48
390
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49

            
50
390
                if b_set.difference(&a_set).count() > 0 {
51
312
                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52
                } else {
53
78
                    Some(Lit::Bool(false))
54
                }
55
            }
56
            _ => None,
57
        },
58
780
        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59
            (
60
780
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61
780
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62
780
            ) => Some(Lit::Bool(
63
780
                a.iter()
64
780
                    .cloned()
65
780
                    .collect::<HashSet<Lit>>()
66
780
                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67
780
            )),
68
            _ => None,
69
        },
70
234
        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71
            (
72
234
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73
234
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74
            ) => {
75
234
                let mut res: Vec<Lit> = Vec::new();
76
546
                for lit in a.iter() {
77
546
                    if b.contains(lit) && !res.contains(lit) {
78
390
                        res.push(lit.clone());
79
390
                    }
80
                }
81
234
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82
            }
83
            _ => None,
84
        },
85
234
        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86
            (
87
234
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88
234
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89
            ) => {
90
234
                let mut res: Vec<Lit> = Vec::new();
91
624
                for lit in a.iter() {
92
624
                    res.push(lit.clone());
93
624
                }
94
624
                for lit in b.iter() {
95
624
                    if !res.contains(lit) {
96
468
                        res.push(lit.clone());
97
468
                    }
98
                }
99
234
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100
            }
101
            _ => None,
102
        },
103
2223
        Expr::In(_, a, b) => {
104
            if let (
105
78
                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106
78
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107
2223
            ) = (a.as_ref(), b.as_ref())
108
            {
109
234
                for lit in d.iter() {
110
234
                    if let Lit::Int(x) = lit
111
234
                        && c == x
112
                    {
113
78
                        return Some(Lit::Bool(true));
114
156
                    }
115
                }
116
                Some(Lit::Bool(false))
117
            } else {
118
2145
                None
119
            }
120
        }
121
        Expr::FromSolution(_, _) => None,
122
        Expr::DominanceRelation(_, _) => None,
123
74607
        Expr::InDomain(_, e, domain) => {
124
74607
            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125
53079
                return None;
126
            };
127

            
128
21528
            domain.contains(lit).ok().map(Into::into)
129
        }
130
4390558
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131
9292549
        Expr::Atomic(_, Atom::Reference(reference)) => reference.resolve_constant(),
132
2919503
        Expr::AbstractLiteral(_, a) => Some(Lit::AbstractLiteral(a.clone().into_literals()?)),
133
47610
        Expr::Comprehension(_, _) => None,
134
78
        Expr::AbstractComprehension(_, _) => None,
135
1890798
        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
136
3198351
            let subject: Lit = eval_constant(subject.as_ref())?;
137
4563
            let indices: Vec<Lit> = indices
138
4563
                .iter()
139
4563
                .map(eval_constant)
140
4563
                .collect::<Option<Vec<Lit>>>()?;
141

            
142
741
            match subject {
143
468
                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
144
468
                    matrix::flatten_enumerate(subject)
145
1638
                        .find(|(i, _)| i == &indices)
146
468
                        .map(|(_, x)| x)
147
                }
148
195
                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
149
195
                    let AbstractLiteral::Tuple(elems) = subject else {
150
                        return None;
151
                    };
152

            
153
195
                    assert!(indices.len() == 1, "nested tuples not supported yet");
154

            
155
195
                    let Lit::Int(index) = indices[0].clone() else {
156
                        return None;
157
                    };
158

            
159
195
                    if elems.len() < index as usize || index < 1 {
160
                        return None;
161
195
                    }
162

            
163
                    // -1 for 0-indexing vs 1-indexing
164
195
                    let item = elems[index as usize - 1].clone();
165

            
166
195
                    Some(item)
167
                }
168
78
                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
169
78
                    let AbstractLiteral::Record(elems) = subject else {
170
                        return None;
171
                    };
172

            
173
78
                    assert!(indices.len() == 1, "nested record not supported yet");
174

            
175
78
                    let Lit::Int(index) = indices[0].clone() else {
176
                        return None;
177
                    };
178

            
179
78
                    if elems.len() < index as usize || index < 1 {
180
                        return None;
181
78
                    }
182

            
183
                    // -1 for 0-indexing vs 1-indexing
184
78
                    let item = elems[index as usize - 1].clone();
185
78
                    Some(item.value)
186
                }
187
                _ => None,
188
            }
189
        }
190
72540
        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
191
117390
            let subject: Lit = eval_constant(subject.as_ref())?;
192
39
            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
193
                return None;
194
            };
195

            
196
39
            let hole_dim = indices
197
39
                .iter()
198
39
                .cloned()
199
78
                .position(|x| x.is_none())
200
39
                .expect("slice expression should have a hole dimension");
201

            
202
39
            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
203

            
204
39
            let indices: Vec<Option<Lit>> = indices
205
39
                .iter()
206
39
                .cloned()
207
78
                .map(|x| {
208
                    // the outer option represents success of this iterator, the inner the index
209
                    // slice.
210
78
                    match x {
211
39
                        Some(x) => eval_constant(&x).map(Some),
212
39
                        None => Some(None),
213
                    }
214
78
                })
215
39
                .collect::<Option<Vec<Option<Lit>>>>()?;
216

            
217
39
            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
218
39
                .values()
219
39
                .ok()?
220
117
                .map(|i| {
221
117
                    let mut indices = indices.clone();
222
117
                    indices[hole_dim] = Some(i);
223
                    // These unwraps will only fail if we have multiple holes.
224
                    // As this is invalid, panicking is fine.
225
234
                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
226
117
                })
227
39
                .collect_vec();
228

            
229
            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
230
39
            let elems = matrix::flatten_enumerate(subject)
231
351
                .filter(|(i, _)| indices_in_slice.contains(i))
232
39
                .map(|(_, elem)| elem)
233
39
                .collect();
234

            
235
39
            Some(Lit::AbstractLiteral(into_matrix![elems]))
236
        }
237
6630
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
238
717656
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
239
717656
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
240
717656
            .map(Lit::Bool),
241
549114
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
242
60989
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
243
4714
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
244
237860
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
245
21593
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
246
14118
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
247
348400
        Expr::And(_, e) => {
248
348400
            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
249
        }
250
313568
        Expr::Root(_, _) => None,
251
150900
        Expr::Or(_, es) => {
252
            // possibly cheating; definitely should be in partial eval instead
253
150900
            for e in (**es).clone().unwrap_list()? {
254
1326
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
255
936
                    return Some(Lit::Bool(true));
256
147819
                };
257
            }
258

            
259
54219
            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
260
        }
261
113997
        Expr::Imply(_, box1, box2) => {
262
113997
            let a: &Atom = (&**box1).try_into().ok()?;
263
84630
            let b: &Atom = (&**box2).try_into().ok()?;
264

            
265
1599
            let a: bool = a.try_into().ok()?;
266
            let b: bool = b.try_into().ok()?;
267

            
268
            if a {
269
                // true -> b ~> b
270
                Some(Lit::Bool(b))
271
            } else {
272
                // false -> b ~> true
273
                Some(Lit::Bool(true))
274
            }
275
        }
276
1911
        Expr::Iff(_, box1, box2) => {
277
1911
            let a: &Atom = (&**box1).try_into().ok()?;
278
1599
            let b: &Atom = (&**box2).try_into().ok()?;
279

            
280
117
            let a: bool = a.try_into().ok()?;
281
39
            let b: bool = b.try_into().ok()?;
282

            
283
            Some(Lit::Bool(a == b))
284
        }
285
1564278
        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
286
915993
        Expr::Product(_, exprs) => {
287
915993
            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
288
        }
289
65286
        Expr::FlatIneq(_, a, b, c) => {
290
65286
            let a: i32 = a.try_into().ok()?;
291
21762
            let b: i32 = b.try_into().ok()?;
292
            let c: i32 = c.try_into().ok()?;
293

            
294
            Some(Lit::Bool(a <= b + c))
295
        }
296
216596
        Expr::FlatSumGeq(_, exprs, a) => {
297
423406
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
298
423406
                let n: i32 = atom.try_into().ok()?;
299
206810
                let acc = acc + n;
300
206810
                Some(acc)
301
423406
            })?;
302

            
303
            Some(Lit::Bool(sum >= a.try_into().ok()?))
304
        }
305
163845
        Expr::FlatSumLeq(_, exprs, a) => {
306
312292
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
307
312292
                let n: i32 = atom.try_into().ok()?;
308
148447
                let acc = acc + n;
309
148447
                Some(acc)
310
312292
            })?;
311

            
312
            Some(Lit::Bool(sum >= a.try_into().ok()?))
313
        }
314
3393
        Expr::Min(_, e) => {
315
3393
            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
316
        }
317
1716
        Expr::Max(_, e) => {
318
1716
            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
319
        }
320
30459
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
321
38727
            if unwrap_expr::<i32>(b)? == 0 {
322
78
                return None;
323
6825
            }
324
6825
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
325
        }
326
17706
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
327
23517
            if unwrap_expr::<i32>(b)? == 0 {
328
                return None;
329
2535
            }
330
2535
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
331
2535
                .map(Lit::Int)
332
        }
333
4368
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
334
            // div always rounds down
335
4368
            let a: i32 = a.try_into().ok()?;
336
78
            let b: i32 = b.try_into().ok()?;
337
            let c: i32 = c.try_into().ok()?;
338

            
339
            if b == 0 {
340
                return None;
341
            }
342

            
343
            let a = a as f32;
344
            let b = b as f32;
345
            let div: i32 = (a / b).floor() as i32;
346
            Some(Lit::Bool(div == c))
347
        }
348
45747
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
349
98474
        Expr::MinionReify(_, a, b) => {
350
98474
            let result = eval_constant(a)?;
351

            
352
6396
            let result: bool = result.try_into().ok()?;
353
6396
            let b: bool = b.try_into().ok()?;
354

            
355
            Some(Lit::Bool(b == result))
356
        }
357
60996
        Expr::MinionReifyImply(_, a, b) => {
358
60996
            let result = eval_constant(a)?;
359

            
360
            let result: bool = result.try_into().ok()?;
361
            let b: bool = b.try_into().ok()?;
362

            
363
            if b {
364
                Some(Lit::Bool(result))
365
            } else {
366
                Some(Lit::Bool(true))
367
            }
368
        }
369
1443
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
370
            // From Savile Row. Same semantics as division.
371
            //
372
            //   a - (b * floor(a/b))
373
            //
374
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
375
            // down instead, not towards zero.
376

            
377
1443
            let a: i32 = a.try_into().ok()?;
378
78
            let b: i32 = b.try_into().ok()?;
379
            let c: i32 = c.try_into().ok()?;
380

            
381
            if b == 0 {
382
                return None;
383
            }
384

            
385
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
386
            Some(Lit::Bool(modulo == c))
387
        }
388
2652
        Expr::MinionPow(_, a, b, c) => {
389
            // only available for positive a b c
390

            
391
2652
            let a: i32 = a.try_into().ok()?;
392
            let b: i32 = b.try_into().ok()?;
393
            let c: i32 = c.try_into().ok()?;
394

            
395
            if a <= 0 {
396
                return None;
397
            }
398

            
399
            if b <= 0 {
400
                return None;
401
            }
402

            
403
            if c <= 0 {
404
                return None;
405
            }
406

            
407
            Some(Lit::Bool(a ^ b == c))
408
        }
409
312
        Expr::MinionWInSet(_, _, _) => None,
410
741
        Expr::MinionWInIntervalSet(_, x, intervals) => {
411
741
            let x_lit: &Lit = x.try_into().ok()?;
412

            
413
            let x_lit = match x_lit.clone() {
414
                Lit::Int(i) => Some(i),
415
                Lit::Bool(true) => Some(1),
416
                Lit::Bool(false) => Some(0),
417
                _ => None,
418
            }?;
419

            
420
            let mut intervals = intervals.iter();
421
            loop {
422
                let Some(lower) = intervals.next() else {
423
                    break;
424
                };
425

            
426
                let Some(upper) = intervals.next() else {
427
                    break;
428
                };
429
                if &x_lit >= lower && &x_lit <= upper {
430
                    return Some(Lit::Bool(true));
431
                }
432
            }
433

            
434
            Some(Lit::Bool(false))
435
        }
436
        Expr::Flatten(_, _, _) => {
437
            // TODO
438
7800
            None
439
        }
440
47580
        Expr::AllDiff(_, e) => {
441
47580
            let es = (**e).clone().unwrap_list()?;
442
4485
            let mut lits: HashSet<Lit> = HashSet::new();
443
5031
            for expr in es {
444
1989
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
445
4329
                    return None;
446
                };
447
702
                match x {
448
                    Lit::Int(_) | Lit::Bool(_) => {
449
702
                        if lits.contains(&x) {
450
                            return Some(Lit::Bool(false));
451
702
                        } else {
452
702
                            lits.insert(x.clone());
453
702
                        }
454
                    }
455
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
456
                }
457
            }
458
156
            Some(Lit::Bool(true))
459
        }
460
29601
        Expr::FlatAllDiff(_, es) => {
461
29601
            let mut lits: HashSet<Lit> = HashSet::new();
462
29601
            for atom in es {
463
29601
                let Atom::Literal(x) = atom else {
464
29601
                    return None;
465
                };
466

            
467
                match x {
468
                    Lit::Int(_) | Lit::Bool(_) => {
469
                        if lits.contains(x) {
470
                            return Some(Lit::Bool(false));
471
                        } else {
472
                            lits.insert(x.clone());
473
                        }
474
                    }
475
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
476
                }
477
            }
478
            Some(Lit::Bool(true))
479
        }
480
3471
        Expr::FlatWatchedLiteral(_, _, _) => None,
481
67391
        Expr::AuxDeclaration(_, _, _) => None,
482
82993
        Expr::Neg(_, a) => {
483
82993
            let a: &Atom = a.try_into().ok()?;
484
60217
            let a: i32 = a.try_into().ok()?;
485
19072
            Some(Lit::Int(-a))
486
        }
487
437788
        Expr::Minus(_, a, b) => {
488
437788
            let a: &Atom = a.try_into().ok()?;
489
429559
            let a: i32 = a.try_into().ok()?;
490

            
491
11050
            let b: &Atom = b.try_into().ok()?;
492
10348
            let b: i32 = b.try_into().ok()?;
493

            
494
7579
            Some(Lit::Int(a - b))
495
        }
496
2340
        Expr::FlatMinusEq(_, a, b) => {
497
2340
            let a: i32 = a.try_into().ok()?;
498
1872
            let b: i32 = b.try_into().ok()?;
499
            Some(Lit::Bool(a == -b))
500
        }
501
585
        Expr::FlatProductEq(_, a, b, c) => {
502
585
            let a: i32 = a.try_into().ok()?;
503
            let b: i32 = b.try_into().ok()?;
504
            let c: i32 = c.try_into().ok()?;
505
            Some(Lit::Bool(a * b == c))
506
        }
507
79248
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
508
79248
            let cs: Vec<i32> = cs
509
79248
                .iter()
510
161616
                .map(|x| TryInto::<i32>::try_into(x).ok())
511
79248
                .collect::<Option<Vec<i32>>>()?;
512
79248
            let vs: Vec<i32> = vs
513
79248
                .iter()
514
87282
                .map(|x| TryInto::<i32>::try_into(x).ok())
515
79248
                .collect::<Option<Vec<i32>>>()?;
516
            let total: i32 = total.try_into().ok()?;
517

            
518
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
519

            
520
            Some(Lit::Bool(sum <= total))
521
        }
522
73398
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
523
73398
            let cs: Vec<i32> = cs
524
73398
                .iter()
525
149019
                .map(|x| TryInto::<i32>::try_into(x).ok())
526
73398
                .collect::<Option<Vec<i32>>>()?;
527
73398
            let vs: Vec<i32> = vs
528
73398
                .iter()
529
80613
                .map(|x| TryInto::<i32>::try_into(x).ok())
530
73398
                .collect::<Option<Vec<i32>>>()?;
531
            let total: i32 = total.try_into().ok()?;
532

            
533
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
534

            
535
            Some(Lit::Bool(sum >= total))
536
        }
537
819
        Expr::FlatAbsEq(_, x, y) => {
538
819
            let x: i32 = x.try_into().ok()?;
539
39
            let y: i32 = y.try_into().ok()?;
540

            
541
            Some(Lit::Bool(x == y.abs()))
542
        }
543
239343
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
544
250107
            let a: &Atom = a.try_into().ok()?;
545
248820
            let a: i32 = a.try_into().ok()?;
546

            
547
234468
            let b: &Atom = b.try_into().ok()?;
548
234468
            let b: i32 = b.try_into().ok()?;
549

            
550
234468
            if (a != 0 || b != 0) && b >= 0 {
551
234468
                Some(Lit::Int(a.pow(b as u32)))
552
            } else {
553
                None
554
            }
555
        }
556
        Expr::Metavar(_, _) => None,
557
7878
        Expr::MinionElementOne(_, _, _, _) => None,
558
1014
        Expr::ToInt(_, expression) => {
559
1014
            let lit = eval_constant(expression.as_ref())?;
560
            match lit {
561
                Lit::Int(_) => Some(lit),
562
                Lit::Bool(true) => Some(Lit::Int(1)),
563
                Lit::Bool(false) => Some(Lit::Int(0)),
564
                _ => None,
565
            }
566
        }
567
10507
        Expr::SATInt(..) => None,
568
        Expr::PairwiseSum(_, a, b) => {
569
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
570
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
571
                _ => None,
572
            }
573
        }
574
        Expr::PairwiseProduct(_, a, b) => {
575
            match (eval_constant(a.as_ref())?, eval_constant(b.as_ref())?) {
576
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
577
                _ => None,
578
            }
579
        }
580
        Expr::Defined(_, _) => todo!(),
581
        Expr::Range(_, _) => todo!(),
582
        Expr::Image(_, _, _) => todo!(),
583
        Expr::ImageSet(_, _, _) => todo!(),
584
        Expr::PreImage(_, _, _) => todo!(),
585
        Expr::Inverse(_, _, _) => todo!(),
586
        Expr::Restrict(_, _, _) => todo!(),
587
546
        Expr::LexLt(_, a, b) => {
588
546
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
589
156
                pairs
590
156
                    .iter()
591
312
                    .find_map(|(a, b)| match a.cmp(b) {
592
78
                        CmpOrdering::Less => Some(true),     // First difference is <
593
                        CmpOrdering::Greater => Some(false), // First difference is >
594
234
                        CmpOrdering::Equal => None,          // No difference
595
312
                    })
596
156
                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
597
390
            })?;
598
156
            Some(lt.into())
599
        }
600
75933
        Expr::LexLeq(_, a, b) => {
601
75933
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
602
78
                pairs
603
78
                    .iter()
604
156
                    .find_map(|(a, b)| match a.cmp(b) {
605
78
                        CmpOrdering::Less => Some(true),
606
                        CmpOrdering::Greater => Some(false),
607
78
                        CmpOrdering::Equal => None,
608
156
                    })
609
78
                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
610
75855
            })?;
611
78
            Some(lt.into())
612
        }
613
        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
614
        Expr::LexGeq(_, a, b) => {
615
            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
616
        }
617
78
        Expr::FlatLexLt(_, a, b) => {
618
78
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
619
                pairs
620
                    .iter()
621
                    .find_map(|(a, b)| match a.cmp(b) {
622
                        CmpOrdering::Less => Some(true),
623
                        CmpOrdering::Greater => Some(false),
624
                        CmpOrdering::Equal => None,
625
                    })
626
                    .unwrap_or(a_len < b_len)
627
78
            })?;
628
            Some(lt.into())
629
        }
630
156
        Expr::FlatLexLeq(_, a, b) => {
631
156
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
632
                pairs
633
                    .iter()
634
                    .find_map(|(a, b)| match a.cmp(b) {
635
                        CmpOrdering::Less => Some(true),
636
                        CmpOrdering::Greater => Some(false),
637
                        CmpOrdering::Equal => None,
638
                    })
639
                    .unwrap_or(a_len <= b_len)
640
156
            })?;
641
            Some(lt.into())
642
        }
643
    }
644
26973898
}
645

            
646
20748
pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
647
20748
where
648
20748
    T: TryFrom<Lit>,
649
{
650
20748
    let a = unwrap_expr::<T>(a)?;
651
195
    Some(f(a))
652
20748
}
653

            
654
2283101
pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
655
2283101
where
656
2283101
    T: TryFrom<Lit>,
657
{
658
2283101
    let a = unwrap_expr::<T>(a)?;
659
353127
    let b = unwrap_expr::<T>(b)?;
660
241995
    Some(f(a, b))
661
2283101
}
662

            
663
#[allow(dead_code)]
664
pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
665
where
666
    T: TryFrom<Lit>,
667
{
668
    let a = unwrap_expr::<T>(a)?;
669
    let b = unwrap_expr::<T>(b)?;
670
    let c = unwrap_expr::<T>(c)?;
671
    Some(f(a, b, c))
672
}
673

            
674
28818
pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
675
28818
where
676
28818
    T: TryFrom<Lit>,
677
{
678
28818
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
679
12
    Some(f(a))
680
28818
}
681

            
682
2882890
pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
683
2882890
where
684
2882890
    T: TryFrom<Lit>,
685
{
686
    // we don't care about preserving indices here, as we will be getting rid of the vector
687
    // anyways!
688
2882890
    let a = a.clone().unwrap_matrix_unchecked()?.0;
689
2798835
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
690
267865
    Some(f(a))
691
2882890
}
692

            
693
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
694

            
695
/// Calls the given function on each consecutive pair of elements in the list expressions.
696
/// Also passes the length of the two lists.
697
76479
fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
698
76479
where
699
76479
    T: TryFrom<Lit>,
700
{
701
76479
    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
702
702
    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
703
468
    let lens = (a_exprs.len(), b_exprs.len());
704

            
705
468
    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
706
702
        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
707
468
        .collect::<Option<Vec<(T, T)>>>()?;
708
234
    Some(f(lit_pairs, lens))
709
76479
}
710

            
711
/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
712
234
fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
713
234
where
714
234
    T: TryFrom<Atom>,
715
{
716
234
    let lit_pairs = Iterator::zip(a.iter(), b.iter())
717
234
        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
718
234
        .collect::<Option<Vec<(T, T)>>>()?;
719
    Some(f(lit_pairs, (a.len(), b.len())))
720
234
}
721

            
722
pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
723
where
724
    T: TryFrom<Lit>,
725
{
726
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
727
    f(a)
728
}
729

            
730
5109
pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
731
5109
where
732
5109
    T: TryFrom<Lit>,
733
{
734
5109
    let a = a.clone().unwrap_list()?;
735
    // FIXME: deal with explicit matrix domains
736
858
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
737
39
    f(a)
738
5109
}
739

            
740
#[allow(dead_code)]
741
pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
742
where
743
    T: TryFrom<Lit>,
744
{
745
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
746
    let b = unwrap_expr::<T>(b)?;
747
    Some(f(a, b))
748
}
749

            
750
6494475
pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
751
6494475
    let c = eval_constant(expr)?;
752
1845131
    TryInto::<T>::try_into(c).ok()
753
6494475
}