1
#![allow(dead_code)]
2
use crate::ast::{AbstractLiteral, Atom, Expression as Expr, Literal as Lit, Metadata, matrix};
3
use crate::into_matrix;
4
use itertools::{Itertools as _, izip};
5
use std::cmp::Ordering as CmpOrdering;
6
use std::collections::HashSet;
7

            
8
/// Simplify an expression to a constant if possible
9
/// Returns:
10
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
11
/// `Some(Const)` if the expression can be simplified to a constant
12
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
13
    match expr {
14
        Expr::Supset(_, a, b) => match (a.as_ref(), b.as_ref()) {
15
            (
16
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
17
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
18
            ) => {
19
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
20
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
21

            
22
                if a_set.difference(&b_set).count() > 0 {
23
                    Some(Lit::Bool(a_set.is_superset(&b_set)))
24
                } else {
25
                    Some(Lit::Bool(false))
26
                }
27
            }
28
            _ => None,
29
        },
30
        Expr::SupsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
31
            (
32
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
33
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
34
            ) => Some(Lit::Bool(
35
                a.iter()
36
                    .cloned()
37
                    .collect::<HashSet<Lit>>()
38
                    .is_superset(&b.iter().cloned().collect::<HashSet<Lit>>()),
39
            )),
40
            _ => None,
41
        },
42
        Expr::Subset(_, a, b) => match (a.as_ref(), b.as_ref()) {
43
            (
44
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
45
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
46
            ) => {
47
                let a_set: HashSet<Lit> = a.iter().cloned().collect();
48
                let b_set: HashSet<Lit> = b.iter().cloned().collect();
49

            
50
                if b_set.difference(&a_set).count() > 0 {
51
                    Some(Lit::Bool(a_set.is_subset(&b_set)))
52
                } else {
53
                    Some(Lit::Bool(false))
54
                }
55
            }
56
            _ => None,
57
        },
58
        Expr::SubsetEq(_, a, b) => match (a.as_ref(), b.as_ref()) {
59
            (
60
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
61
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
62
            ) => Some(Lit::Bool(
63
                a.iter()
64
                    .cloned()
65
                    .collect::<HashSet<Lit>>()
66
                    .is_subset(&b.iter().cloned().collect::<HashSet<Lit>>()),
67
            )),
68
            _ => None,
69
        },
70
        Expr::Intersect(_, a, b) => match (a.as_ref(), b.as_ref()) {
71
            (
72
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
73
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
74
            ) => {
75
                let mut res: Vec<Lit> = Vec::new();
76
                for lit in a.iter() {
77
                    if b.contains(lit) && !res.contains(lit) {
78
                        res.push(lit.clone());
79
                    }
80
                }
81
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
82
            }
83
            _ => None,
84
        },
85
        Expr::Union(_, a, b) => match (a.as_ref(), b.as_ref()) {
86
            (
87
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(a)))),
88
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(b)))),
89
            ) => {
90
                let mut res: Vec<Lit> = Vec::new();
91
                for lit in a.iter() {
92
                    res.push(lit.clone());
93
                }
94
                for lit in b.iter() {
95
                    if !res.contains(lit) {
96
                        res.push(lit.clone());
97
                    }
98
                }
99
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(res)))
100
            }
101
            _ => None,
102
        },
103
        Expr::In(_, a, b) => {
104
            if let (
105
                Expr::Atomic(_, Atom::Literal(Lit::Int(c))),
106
                Expr::Atomic(_, Atom::Literal(Lit::AbstractLiteral(AbstractLiteral::Set(d)))),
107
            ) = (a.as_ref(), b.as_ref())
108
            {
109
                for lit in d.iter() {
110
                    if let Lit::Int(x) = lit
111
                        && c == x
112
                    {
113
                        return Some(Lit::Bool(true));
114
                    }
115
                }
116
                Some(Lit::Bool(false))
117
            } else {
118
                None
119
            }
120
        }
121
        Expr::FromSolution(_, _) => None,
122
        Expr::DominanceRelation(_, _) => None,
123
        Expr::InDomain(_, e, domain) => {
124
            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
125
                return None;
126
            };
127

            
128
            domain.contains(lit).ok().map(Into::into)
129
        }
130
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
131
        Expr::Atomic(_, Atom::Reference(_)) => None,
132
        Expr::AbstractLiteral(_, a) => {
133
            if let AbstractLiteral::Set(s) = a {
134
                let mut copy = Vec::new();
135
                for expr in s.iter() {
136
                    if let Expr::Atomic(_, Atom::Literal(lit)) = expr {
137
                        copy.push(lit.clone());
138
                    } else {
139
                        return None;
140
                    }
141
                }
142
                Some(Lit::AbstractLiteral(AbstractLiteral::Set(copy)))
143
            } else {
144
                None
145
            }
146
        }
147
        Expr::Comprehension(_, _) => None,
148
        Expr::UnsafeIndex(_, subject, indices) | Expr::SafeIndex(_, subject, indices) => {
149
            let subject: Lit = subject.as_ref().clone().into_literal()?;
150
            let indices: Vec<Lit> = indices
151
                .iter()
152
                .cloned()
153
                .map(|x| x.into_literal())
154
                .collect::<Option<Vec<Lit>>>()?;
155

            
156
            match subject {
157
                Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) => {
158
                    matrix::flatten_enumerate(subject)
159
                        .find(|(i, _)| i == &indices)
160
                        .map(|(_, x)| x)
161
                }
162
                Lit::AbstractLiteral(subject @ AbstractLiteral::Tuple(_)) => {
163
                    let AbstractLiteral::Tuple(elems) = subject else {
164
                        return None;
165
                    };
166

            
167
                    assert!(indices.len() == 1, "nested tuples not supported yet");
168

            
169
                    let Lit::Int(index) = indices[0].clone() else {
170
                        return None;
171
                    };
172

            
173
                    if elems.len() < index as usize || index < 1 {
174
                        return None;
175
                    }
176

            
177
                    // -1 for 0-indexing vs 1-indexing
178
                    let item = elems[index as usize - 1].clone();
179

            
180
                    Some(item)
181
                }
182
                Lit::AbstractLiteral(subject @ AbstractLiteral::Record(_)) => {
183
                    let AbstractLiteral::Record(elems) = subject else {
184
                        return None;
185
                    };
186

            
187
                    assert!(indices.len() == 1, "nested record not supported yet");
188

            
189
                    let Lit::Int(index) = indices[0].clone() else {
190
                        return None;
191
                    };
192

            
193
                    if elems.len() < index as usize || index < 1 {
194
                        return None;
195
                    }
196

            
197
                    // -1 for 0-indexing vs 1-indexing
198
                    let item = elems[index as usize - 1].clone();
199
                    Some(item.value)
200
                }
201
                _ => None,
202
            }
203
        }
204
        Expr::UnsafeSlice(_, subject, indices) | Expr::SafeSlice(_, subject, indices) => {
205
            let subject: Lit = subject.as_ref().clone().into_literal()?;
206
            let Lit::AbstractLiteral(subject @ AbstractLiteral::Matrix(_, _)) = subject else {
207
                return None;
208
            };
209

            
210
            let hole_dim = indices
211
                .iter()
212
                .cloned()
213
                .position(|x| x.is_none())
214
                .expect("slice expression should have a hole dimension");
215

            
216
            let missing_domain = matrix::index_domains(subject.clone())[hole_dim].clone();
217

            
218
            let indices: Vec<Option<Lit>> = indices
219
                .iter()
220
                .cloned()
221
                .map(|x| {
222
                    // the outer option represents success of this iterator, the inner the index
223
                    // slice.
224
                    match x {
225
                        Some(x) => x.into_literal().map(Some),
226
                        None => Some(None),
227
                    }
228
                })
229
                .collect::<Option<Vec<Option<Lit>>>>()?;
230

            
231
            let indices_in_slice: Vec<Vec<Lit>> = missing_domain
232
                .values()
233
                .ok()?
234
                .map(|i| {
235
                    let mut indices = indices.clone();
236
                    indices[hole_dim] = Some(i);
237
                    // These unwraps will only fail if we have multiple holes.
238
                    // As this is invalid, panicking is fine.
239
                    indices.into_iter().map(|x| x.unwrap()).collect_vec()
240
                })
241
                .collect_vec();
242

            
243
            // Note: indices_in_slice is not necessarily sorted, so this is the best way.
244
            let elems = matrix::flatten_enumerate(subject)
245
                .filter(|(i, _)| indices_in_slice.contains(i))
246
                .map(|(_, elem)| elem)
247
                .collect();
248

            
249
            Some(Lit::AbstractLiteral(into_matrix![elems]))
250
        }
251
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
252
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
253
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
254
            .map(Lit::Bool),
255
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
256
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
257
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
258
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
259
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
260
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
261
        Expr::And(_, e) => {
262
            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
263
        }
264
        Expr::Root(_, _) => None,
265
        Expr::Or(_, es) => {
266
            // possibly cheating; definitely should be in partial eval instead
267
            for e in (**es).clone().unwrap_list()? {
268
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = e {
269
                    return Some(Lit::Bool(true));
270
                };
271
            }
272

            
273
            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), es.as_ref()).map(Lit::Bool)
274
        }
275
        Expr::Imply(_, box1, box2) => {
276
            let a: &Atom = (&**box1).try_into().ok()?;
277
            let b: &Atom = (&**box2).try_into().ok()?;
278

            
279
            let a: bool = a.try_into().ok()?;
280
            let b: bool = b.try_into().ok()?;
281

            
282
            if a {
283
                // true -> b ~> b
284
                Some(Lit::Bool(b))
285
            } else {
286
                // false -> b ~> true
287
                Some(Lit::Bool(true))
288
            }
289
        }
290
        Expr::Iff(_, box1, box2) => {
291
            let a: &Atom = (&**box1).try_into().ok()?;
292
            let b: &Atom = (&**box2).try_into().ok()?;
293

            
294
            let a: bool = a.try_into().ok()?;
295
            let b: bool = b.try_into().ok()?;
296

            
297
            Some(Lit::Bool(a == b))
298
        }
299
        Expr::Sum(_, exprs) => vec_lit_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
300
        Expr::Product(_, exprs) => {
301
            vec_lit_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int)
302
        }
303
        Expr::FlatIneq(_, a, b, c) => {
304
            let a: i32 = a.try_into().ok()?;
305
            let b: i32 = b.try_into().ok()?;
306
            let c: i32 = c.try_into().ok()?;
307

            
308
            Some(Lit::Bool(a <= b + c))
309
        }
310
        Expr::FlatSumGeq(_, exprs, a) => {
311
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
312
                let n: i32 = atom.try_into().ok()?;
313
                let acc = acc + n;
314
                Some(acc)
315
            })?;
316

            
317
            Some(Lit::Bool(sum >= a.try_into().ok()?))
318
        }
319
        Expr::FlatSumLeq(_, exprs, a) => {
320
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
321
                let n: i32 = atom.try_into().ok()?;
322
                let acc = acc + n;
323
                Some(acc)
324
            })?;
325

            
326
            Some(Lit::Bool(sum >= a.try_into().ok()?))
327
        }
328
        Expr::Min(_, e) => {
329
            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
330
        }
331
        Expr::Max(_, e) => {
332
            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
333
        }
334
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
335
            if unwrap_expr::<i32>(b)? == 0 {
336
                return None;
337
            }
338
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
339
        }
340
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
341
            if unwrap_expr::<i32>(b)? == 0 {
342
                return None;
343
            }
344
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
345
                .map(Lit::Int)
346
        }
347
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
348
            // div always rounds down
349
            let a: i32 = a.try_into().ok()?;
350
            let b: i32 = b.try_into().ok()?;
351
            let c: i32 = c.try_into().ok()?;
352

            
353
            if b == 0 {
354
                return None;
355
            }
356

            
357
            let a = a as f32;
358
            let b = b as f32;
359
            let div: i32 = (a / b).floor() as i32;
360
            Some(Lit::Bool(div == c))
361
        }
362
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
363
        Expr::MinionReify(_, a, b) => {
364
            let result = eval_constant(a)?;
365

            
366
            let result: bool = result.try_into().ok()?;
367
            let b: bool = b.try_into().ok()?;
368

            
369
            Some(Lit::Bool(b == result))
370
        }
371
        Expr::MinionReifyImply(_, a, b) => {
372
            let result = eval_constant(a)?;
373

            
374
            let result: bool = result.try_into().ok()?;
375
            let b: bool = b.try_into().ok()?;
376

            
377
            if b {
378
                Some(Lit::Bool(result))
379
            } else {
380
                Some(Lit::Bool(true))
381
            }
382
        }
383
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
384
            // From Savile Row. Same semantics as division.
385
            //
386
            //   a - (b * floor(a/b))
387
            //
388
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
389
            // down instead, not towards zero.
390

            
391
            let a: i32 = a.try_into().ok()?;
392
            let b: i32 = b.try_into().ok()?;
393
            let c: i32 = c.try_into().ok()?;
394

            
395
            if b == 0 {
396
                return None;
397
            }
398

            
399
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
400
            Some(Lit::Bool(modulo == c))
401
        }
402
        Expr::MinionPow(_, a, b, c) => {
403
            // only available for positive a b c
404

            
405
            let a: i32 = a.try_into().ok()?;
406
            let b: i32 = b.try_into().ok()?;
407
            let c: i32 = c.try_into().ok()?;
408

            
409
            if a <= 0 {
410
                return None;
411
            }
412

            
413
            if b <= 0 {
414
                return None;
415
            }
416

            
417
            if c <= 0 {
418
                return None;
419
            }
420

            
421
            Some(Lit::Bool(a ^ b == c))
422
        }
423
        Expr::MinionWInSet(_, _, _) => None,
424
        Expr::MinionWInIntervalSet(_, x, intervals) => {
425
            let x_lit: &Lit = x.try_into().ok()?;
426

            
427
            let x_lit = match x_lit.clone() {
428
                Lit::Int(i) => Some(i),
429
                Lit::Bool(true) => Some(1),
430
                Lit::Bool(false) => Some(0),
431
                _ => None,
432
            }?;
433

            
434
            let mut intervals = intervals.iter();
435
            loop {
436
                let Some(lower) = intervals.next() else {
437
                    break;
438
                };
439

            
440
                let Some(upper) = intervals.next() else {
441
                    break;
442
                };
443
                if &x_lit >= lower && &x_lit <= upper {
444
                    return Some(Lit::Bool(true));
445
                }
446
            }
447

            
448
            Some(Lit::Bool(false))
449
        }
450
        Expr::Flatten(_, _, _) => {
451
            // TODO
452
            None
453
        }
454
        Expr::AllDiff(_, e) => {
455
            let es = (**e).clone().unwrap_list()?;
456
            let mut lits: HashSet<Lit> = HashSet::new();
457
            for expr in es {
458
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
459
                    return None;
460
                };
461
                match x {
462
                    Lit::Int(_) | Lit::Bool(_) => {
463
                        if lits.contains(&x) {
464
                            return Some(Lit::Bool(false));
465
                        } else {
466
                            lits.insert(x.clone());
467
                        }
468
                    }
469
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
470
                }
471
            }
472
            Some(Lit::Bool(true))
473
        }
474
        Expr::FlatAllDiff(_, es) => {
475
            let mut lits: HashSet<Lit> = HashSet::new();
476
            for atom in es {
477
                let Atom::Literal(x) = atom else {
478
                    return None;
479
                };
480

            
481
                match x {
482
                    Lit::Int(_) | Lit::Bool(_) => {
483
                        if lits.contains(x) {
484
                            return Some(Lit::Bool(false));
485
                        } else {
486
                            lits.insert(x.clone());
487
                        }
488
                    }
489
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
490
                }
491
            }
492
            Some(Lit::Bool(true))
493
        }
494
        Expr::FlatWatchedLiteral(_, _, _) => None,
495
        Expr::AuxDeclaration(_, _, _) => None,
496
        Expr::Neg(_, a) => {
497
            let a: &Atom = a.try_into().ok()?;
498
            let a: i32 = a.try_into().ok()?;
499
            Some(Lit::Int(-a))
500
        }
501
        Expr::Minus(_, a, b) => {
502
            let a: &Atom = a.try_into().ok()?;
503
            let a: i32 = a.try_into().ok()?;
504

            
505
            let b: &Atom = b.try_into().ok()?;
506
            let b: i32 = b.try_into().ok()?;
507

            
508
            Some(Lit::Int(a - b))
509
        }
510
        Expr::FlatMinusEq(_, a, b) => {
511
            let a: i32 = a.try_into().ok()?;
512
            let b: i32 = b.try_into().ok()?;
513
            Some(Lit::Bool(a == -b))
514
        }
515
        Expr::FlatProductEq(_, a, b, c) => {
516
            let a: i32 = a.try_into().ok()?;
517
            let b: i32 = b.try_into().ok()?;
518
            let c: i32 = c.try_into().ok()?;
519
            Some(Lit::Bool(a * b == c))
520
        }
521
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
522
            let cs: Vec<i32> = cs
523
                .iter()
524
                .map(|x| TryInto::<i32>::try_into(x).ok())
525
                .collect::<Option<Vec<i32>>>()?;
526
            let vs: Vec<i32> = vs
527
                .iter()
528
                .map(|x| TryInto::<i32>::try_into(x).ok())
529
                .collect::<Option<Vec<i32>>>()?;
530
            let total: i32 = total.try_into().ok()?;
531

            
532
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
533

            
534
            Some(Lit::Bool(sum <= total))
535
        }
536
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
537
            let cs: Vec<i32> = cs
538
                .iter()
539
                .map(|x| TryInto::<i32>::try_into(x).ok())
540
                .collect::<Option<Vec<i32>>>()?;
541
            let vs: Vec<i32> = vs
542
                .iter()
543
                .map(|x| TryInto::<i32>::try_into(x).ok())
544
                .collect::<Option<Vec<i32>>>()?;
545
            let total: i32 = total.try_into().ok()?;
546

            
547
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
548

            
549
            Some(Lit::Bool(sum >= total))
550
        }
551
        Expr::FlatAbsEq(_, x, y) => {
552
            let x: i32 = x.try_into().ok()?;
553
            let y: i32 = y.try_into().ok()?;
554

            
555
            Some(Lit::Bool(x == y.abs()))
556
        }
557
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
558
            let a: &Atom = a.try_into().ok()?;
559
            let a: i32 = a.try_into().ok()?;
560

            
561
            let b: &Atom = b.try_into().ok()?;
562
            let b: i32 = b.try_into().ok()?;
563

            
564
            if (a != 0 || b != 0) && b >= 0 {
565
                Some(Lit::Int(a.pow(b as u32)))
566
            } else {
567
                None
568
            }
569
        }
570
        Expr::Scope(_, _) => None,
571
        Expr::Metavar(_, _) => None,
572
        Expr::MinionElementOne(_, _, _, _) => None,
573
        Expr::ToInt(_, expression) => {
574
            let lit = (**expression).clone().into_literal()?;
575
            match lit {
576
                Lit::Int(_) => Some(lit),
577
                Lit::Bool(true) => Some(Lit::Int(1)),
578
                Lit::Bool(false) => Some(Lit::Int(0)),
579
                _ => None,
580
            }
581
        }
582
        Expr::SATInt(_, _) => None,
583
        Expr::PairwiseSum(_, a, b) => {
584
            match ((**a).clone().into_literal()?, (**b).clone().into_literal()?) {
585
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int + b_int)),
586
                _ => None,
587
            }
588
        }
589
        Expr::PairwiseProduct(_, a, b) => {
590
            match ((**a).clone().into_literal()?, (**b).clone().into_literal()?) {
591
                (Lit::Int(a_int), Lit::Int(b_int)) => Some(Lit::Int(a_int * b_int)),
592
                _ => None,
593
            }
594
        }
595
        Expr::Defined(_, _) => todo!(),
596
        Expr::Range(_, _) => todo!(),
597
        Expr::Image(_, _, _) => todo!(),
598
        Expr::ImageSet(_, _, _) => todo!(),
599
        Expr::PreImage(_, _, _) => todo!(),
600
        Expr::Inverse(_, _, _) => todo!(),
601
        Expr::Restrict(_, _, _) => todo!(),
602
        Expr::LexLt(_, a, b) => {
603
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
604
                pairs
605
                    .iter()
606
                    .find_map(|(a, b)| match a.cmp(b) {
607
                        CmpOrdering::Less => Some(true),     // First difference is <
608
                        CmpOrdering::Greater => Some(false), // First difference is >
609
                        CmpOrdering::Equal => None,          // No difference
610
                    })
611
                    .unwrap_or(a_len < b_len) // [1,1] <lex [1,1,x]
612
            })?;
613
            Some(lt.into())
614
        }
615
        Expr::LexLeq(_, a, b) => {
616
            let lt = vec_expr_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
617
                pairs
618
                    .iter()
619
                    .find_map(|(a, b)| match a.cmp(b) {
620
                        CmpOrdering::Less => Some(true),
621
                        CmpOrdering::Greater => Some(false),
622
                        CmpOrdering::Equal => None,
623
                    })
624
                    .unwrap_or(a_len <= b_len) // [1,1] <=lex [1,1,x]
625
            })?;
626
            Some(lt.into())
627
        }
628
        Expr::LexGt(_, a, b) => eval_constant(&Expr::LexLt(Metadata::new(), b.clone(), a.clone())),
629
        Expr::LexGeq(_, a, b) => {
630
            eval_constant(&Expr::LexLeq(Metadata::new(), b.clone(), a.clone()))
631
        }
632
        Expr::FlatLexLt(_, a, b) => {
633
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
634
                pairs
635
                    .iter()
636
                    .find_map(|(a, b)| match a.cmp(b) {
637
                        CmpOrdering::Less => Some(true),
638
                        CmpOrdering::Greater => Some(false),
639
                        CmpOrdering::Equal => None,
640
                    })
641
                    .unwrap_or(a_len < b_len)
642
            })?;
643
            Some(lt.into())
644
        }
645
        Expr::FlatLexLeq(_, a, b) => {
646
            let lt = atoms_pairs_op::<i32, _>(a, b, |pairs, (a_len, b_len)| {
647
                pairs
648
                    .iter()
649
                    .find_map(|(a, b)| match a.cmp(b) {
650
                        CmpOrdering::Less => Some(true),
651
                        CmpOrdering::Greater => Some(false),
652
                        CmpOrdering::Equal => None,
653
                    })
654
                    .unwrap_or(a_len <= b_len)
655
            })?;
656
            Some(lt.into())
657
        }
658
    }
659
}
660

            
661
pub fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
662
where
663
    T: TryFrom<Lit>,
664
{
665
    let a = unwrap_expr::<T>(a)?;
666
    Some(f(a))
667
}
668

            
669
pub fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
670
where
671
    T: TryFrom<Lit>,
672
{
673
    let a = unwrap_expr::<T>(a)?;
674
    let b = unwrap_expr::<T>(b)?;
675
    Some(f(a, b))
676
}
677

            
678
#[allow(dead_code)]
679
pub fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
680
where
681
    T: TryFrom<Lit>,
682
{
683
    let a = unwrap_expr::<T>(a)?;
684
    let b = unwrap_expr::<T>(b)?;
685
    let c = unwrap_expr::<T>(c)?;
686
    Some(f(a, b, c))
687
}
688

            
689
pub fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
690
where
691
    T: TryFrom<Lit>,
692
{
693
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
694
    Some(f(a))
695
}
696

            
697
pub fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
698
where
699
    T: TryFrom<Lit>,
700
{
701
    // we don't care about preserving indices here, as we will be getting rid of the vector
702
    // anyways!
703
    let a = a.clone().unwrap_matrix_unchecked()?.0;
704
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
705
    Some(f(a))
706
}
707

            
708
/// Callback used by [`vec_expr_pairs_op`] and [`atoms_pairs_op`].
///
/// Receives the position-wise pairs of two lists (only as many pairs as the shorter list has
/// elements) and the original lengths `(a_len, b_len)` of both lists.
type PairsCallback<T, A> = fn(Vec<(T, T)>, (usize, usize)) -> A;
709

            
710
/// Calls the given function on each consecutive pair of elements in the list expressions.
711
/// Also passes the length of the two lists.
712
fn vec_expr_pairs_op<T, A>(a: &Expr, b: &Expr, f: PairsCallback<T, A>) -> Option<A>
713
where
714
    T: TryFrom<Lit>,
715
{
716
    let a_exprs = a.clone().unwrap_matrix_unchecked()?.0;
717
    let b_exprs = b.clone().unwrap_matrix_unchecked()?.0;
718
    let lens = (a_exprs.len(), b_exprs.len());
719

            
720
    let lit_pairs = std::iter::zip(a_exprs, b_exprs)
721
        .map(|(a, b)| Some((unwrap_expr(&a)?, unwrap_expr(&b)?)))
722
        .collect::<Option<Vec<(T, T)>>>()?;
723
    Some(f(lit_pairs, lens))
724
}
725

            
726
/// Same as [`vec_expr_pairs_op`], but over slices of atoms.
727
fn atoms_pairs_op<T, A>(a: &[Atom], b: &[Atom], f: PairsCallback<T, A>) -> Option<A>
728
where
729
    T: TryFrom<Atom>,
730
{
731
    let lit_pairs = Iterator::zip(a.iter(), b.iter())
732
        .map(|(a, b)| Some((a.clone().try_into().ok()?, b.clone().try_into().ok()?)))
733
        .collect::<Option<Vec<(T, T)>>>()?;
734
    Some(f(lit_pairs, (a.len(), b.len())))
735
}
736

            
737
pub fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
738
where
739
    T: TryFrom<Lit>,
740
{
741
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
742
    f(a)
743
}
744

            
745
pub fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
746
where
747
    T: TryFrom<Lit>,
748
{
749
    let a = a.clone().unwrap_list()?;
750
    // FIXME: deal with explicit matrix domains
751
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
752
    f(a)
753
}
754

            
755
#[allow(dead_code)]
756
pub fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
757
where
758
    T: TryFrom<Lit>,
759
{
760
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
761
    let b = unwrap_expr::<T>(b)?;
762
    Some(f(a, b))
763
}
764

            
765
pub fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
766
    let c = eval_constant(expr)?;
767
    TryInto::<T>::try_into(c).ok()
768
}