1
use std::collections::HashSet;
2

            
3
use crate::ast::Typeable;
4
use crate::{
5
    ast::{
6
        AbstractLiteral, Atom, DomainPtr, Expression as Expr, GroundDomain, Literal as Lit,
7
        Metadata, Moo, Range, ReturnType,
8
    },
9
    into_matrix_expr,
10
    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
11
};
12
use itertools::iproduct;
13
use uniplate::Uniplate;
14

            
15
/// Normalises integer ranges so equivalent domains compare structurally equal.
16
40416
fn normalise_int_domain(domain: &GroundDomain) -> GroundDomain {
17
40416
    match domain {
18
40416
        GroundDomain::Int(ranges) => GroundDomain::Int(Range::squeeze(
19
40416
            &ranges
20
40416
                .iter()
21
50886
                .map(|range| Range::new(range.low().copied(), range.high().copied()))
22
40416
                .collect::<Vec<_>>(),
23
        )),
24
        _ => domain.clone(),
25
    }
26
40416
}
27

            
28
/// Returns whether `expr` is safe after resolving any referenced expressions.
29
5127186
fn is_semantically_safe(expr: &Expr) -> bool {
30
5134760
    fn helper(expr: &Expr, resolving: &mut HashSet<crate::ast::serde::ObjId>) -> bool {
31
5134760
        if !expr.is_safe() {
32
801308
            return false;
33
4333452
        }
34

            
35
40116006
        for subexpr in expr.universe() {
36
25754066
            let Expr::Atomic(_, Atom::Reference(reference)) = subexpr else {
37
21783306
                continue;
38
            };
39

            
40
18332700
            let Some(resolved) = reference.resolve_expression() else {
41
18325126
                continue;
42
            };
43

            
44
7574
            let id = reference.id();
45
7574
            if !resolving.insert(id.clone()) {
46
                return false;
47
7574
            }
48

            
49
7574
            let is_safe = helper(&resolved, resolving);
50
7574
            resolving.remove(&id);
51

            
52
7574
            if !is_safe {
53
                return false;
54
7574
            }
55
        }
56

            
57
4333452
        true
58
5134760
    }
59

            
60
5127186
    helper(expr, &mut HashSet::new())
61
5127186
}
62

            
63
/// Tries to decide `expr in domain` from resolved domains alone.
64
22100
fn simplify_in_domain(expr: &Expr, domain: &DomainPtr) -> Option<bool> {
65
22100
    if !is_semantically_safe(expr) {
66
106
        return None;
67
21994
    }
68

            
69
21994
    let expr_domain = resolved_ground_domain_of_for_partial_eval(expr)?;
70
21554
    let domain = domain.resolve()?;
71
21554
    let intersection = expr_domain.intersect(&domain).ok()?;
72

            
73
20208
    if normalise_int_domain(&intersection) == normalise_int_domain(expr_domain.as_ref()) {
74
26
        return Some(true);
75
20182
    }
76

            
77
20182
    if let Ok(values_in_domain) = intersection.values_i32()
78
20182
        && values_in_domain.is_empty()
79
    {
80
        return Some(false);
81
20182
    }
82

            
83
20182
    None
84
22100
}
85

            
86
/// Extracts an integer when `expr` is known to be a singleton integer value.
87
956470
fn singleton_int_value(expr: &Expr) -> Option<i32> {
88
956470
    if let Ok(value) = expr.try_into() {
89
660
        return Some(value);
90
955810
    }
91

            
92
955810
    let domain = resolved_ground_domain_of_for_partial_eval(expr)?;
93
955810
    let GroundDomain::Int(ranges) = domain.as_ref() else {
94
        return None;
95
    };
96
955810
    let [range] = ranges.as_slice() else {
97
        return None;
98
    };
99
955810
    let (Some(low), Some(high)) = (range.low(), range.high()) else {
100
        return None;
101
    };
102

            
103
955810
    if low == high { Some(*low) } else { None }
104
956470
}
105

            
106
/// Resolves a matrix literal subject, including constant references to matrix literals.
107
2732746
fn resolve_matrix_subject(subject: &Expr) -> Option<(Vec<Expr>, DomainPtr)> {
108
2732746
    subject.clone().unwrap_matrix_unchecked().or_else(|| {
109
1788306
        let Expr::Atomic(_, Atom::Reference(reference)) = subject else {
110
            return None;
111
        };
112

            
113
12030
        let Lit::AbstractLiteral(AbstractLiteral::Matrix(elems, index_domain)) =
114
1788306
            reference.resolve_constant()?
115
        else {
116
            return None;
117
        };
118

            
119
        Some((
120
12030
            elems
121
12030
                .into_iter()
122
129810
                .map(|elem| Expr::Atomic(Metadata::new(), Atom::Literal(elem)))
123
12030
                .collect(),
124
12030
            index_domain.into(),
125
        ))
126
1788306
    })
127
2732746
}
128

            
129
/// Resolves domains for partial evaluation while avoiding malformed indexing panics.
130
1865326
fn resolved_ground_domain_of_for_partial_eval(expr: &Expr) -> Option<Moo<GroundDomain>> {
131
1865326
    match expr {
132
116892
        Expr::SafeIndex(_, subject, _) => {
133
116892
            let subject_domain = resolved_ground_domain_of_for_partial_eval(subject)?;
134
116892
            let GroundDomain::Matrix(elem_domain, _) = subject_domain.as_ref() else {
135
7680
                return None;
136
            };
137

            
138
109212
            Some(elem_domain.clone())
139
        }
140
        Expr::SafeSlice(_, subject, indices) => {
141
            let subject_domain = resolved_ground_domain_of_for_partial_eval(subject)?;
142
            let GroundDomain::Matrix(elem_domain, index_domains) = subject_domain.as_ref() else {
143
                return None;
144
            };
145
            let sliced_dimension = indices.iter().position(Option::is_none);
146

            
147
            match sliced_dimension {
148
                Some(dimension) => Some(Moo::new(GroundDomain::Matrix(
149
                    elem_domain.clone(),
150
                    vec![index_domains[dimension].clone()],
151
                ))),
152
                None => Some(elem_domain.clone()),
153
            }
154
        }
155
        Expr::UnsafeIndex(_, _, _) | Expr::UnsafeSlice(_, _, _) => None,
156
1748434
        _ => expr.domain_of()?.resolve(),
157
    }
158
1865326
}
159

            
160
/// Tries to decide `expr = lit` and `expr != lit` from the resolved domain of `expr`.
161
932846
fn simplify_comparison_with_literal(expr: &Expr, lit: &Lit) -> Option<(bool, bool)> {
162
932846
    if !is_semantically_safe(expr) {
163
162216
        return None;
164
770630
    }
165

            
166
770630
    let expr_domain = resolved_ground_domain_of_for_partial_eval(expr)?;
167

            
168
752504
    if !expr_domain.contains(lit).ok()? {
169
3272
        return Some((false, true));
170
749232
    }
171

            
172
749232
    match (expr_domain.as_ref(), lit) {
173
571148
        (GroundDomain::Int(ranges), Lit::Int(value)) => {
174
571148
            let [range] = ranges.as_slice() else {
175
4104
                return None;
176
            };
177
567044
            let (Some(low), Some(high)) = (range.low(), range.high()) else {
178
                return None;
179
            };
180

            
181
567044
            if low == high && low == value {
182
490
                Some((true, false))
183
            } else {
184
566554
                None
185
            }
186
        }
187
173924
        (GroundDomain::Bool, Lit::Bool(_)) => None,
188
4160
        _ => None,
189
    }
190
932846
}
191

            
192
/// Tries to decide reflexive equality and inequality when both sides are semantically safe.
193
2377760
fn simplify_reflexive_comparison(x: &Expr, y: &Expr) -> Option<(bool, bool)> {
194
2377760
    if x.identical_atom_to(y) && is_semantically_safe(x) && is_semantically_safe(y) {
195
80
        return Some((true, false));
196
2377680
    }
197

            
198
2377680
    if is_semantically_safe(x) && is_semantically_safe(y) && x == y {
199
24120
        return Some((true, false));
200
2353560
    }
201

            
202
2353560
    None
203
2377760
}
204

            
205
39246952
pub fn run_partial_evaluator(expr: &Expr) -> ApplicationResult {
206
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
207
    // rule infinitely!
208
    // This is why we always check whether we found a constant or not.
209
39246952
    match expr {
210
188
        Expr::Union(_, _, _) => Err(RuleNotApplicable),
211
4634
        Expr::In(_, _, _) => Err(RuleNotApplicable),
212
560
        Expr::Intersect(_, _, _) => Err(RuleNotApplicable),
213
320
        Expr::Supset(_, _, _) => Err(RuleNotApplicable),
214
320
        Expr::SupsetEq(_, _, _) => Err(RuleNotApplicable),
215
800
        Expr::Subset(_, _, _) => Err(RuleNotApplicable),
216
362
        Expr::SubsetEq(_, _, _) => Err(RuleNotApplicable),
217
4542138
        Expr::AbstractLiteral(_, _) => Err(RuleNotApplicable),
218
209780
        Expr::Comprehension(_, _) => Err(RuleNotApplicable),
219
        Expr::AbstractComprehension(_, _) => Err(RuleNotApplicable),
220
        Expr::DominanceRelation(_, _) => Err(RuleNotApplicable),
221
        Expr::FromSolution(_, _) => Err(RuleNotApplicable),
222
        Expr::Metavar(_, _) => Err(RuleNotApplicable),
223
1039636
        Expr::UnsafeIndex(_, _, _) => Err(RuleNotApplicable),
224
61640
        Expr::UnsafeSlice(_, _, _) => Err(RuleNotApplicable),
225
1920
        Expr::Table(_, _, _) => Err(RuleNotApplicable),
226
320
        Expr::NegativeTable(_, _, _) => Err(RuleNotApplicable),
227
2732746
        Expr::SafeIndex(_, subject, indices) => {
228
            // partially evaluate matrix literals indexed by a constant.
229

            
230
            // subject must be a matrix literal
231
2732746
            let (es, index_domain) = resolve_matrix_subject(subject).ok_or(RuleNotApplicable)?;
232

            
233
956470
            if indices.is_empty() {
234
                return Err(RuleNotApplicable);
235
956470
            }
236

            
237
            // the leading index must be fixed to a single value
238
956470
            let index = singleton_int_value(&indices[0]).ok_or(RuleNotApplicable)?;
239

            
240
            // index domain must be a single integer range with a lower bound
241
660
            if let Some(ranges) = index_domain.as_int_ground()
242
660
                && ranges.len() == 1
243
660
                && let Some(from) = ranges[0].low()
244
            {
245
660
                let zero_indexed_index = index - from;
246
660
                let selected = es
247
660
                    .get(zero_indexed_index as usize)
248
660
                    .ok_or(RuleNotApplicable)?
249
660
                    .clone();
250

            
251
660
                if indices.len() == 1 {
252
                    Ok(Reduction::pure(selected))
253
                } else {
254
660
                    Ok(Reduction::pure(Expr::SafeIndex(
255
660
                        Metadata::new(),
256
660
                        Moo::new(selected),
257
660
                        indices[1..].to_vec(),
258
660
                    )))
259
                }
260
            } else {
261
                Err(RuleNotApplicable)
262
            }
263
        }
264
79920
        Expr::SafeSlice(_, _, _) => Err(RuleNotApplicable),
265
22100
        Expr::InDomain(_, x, domain) => {
266
22100
            if let Some(result) = simplify_in_domain(x, domain) {
267
26
                Ok(Reduction::pure(Expr::Atomic(
268
26
                    Metadata::new(),
269
26
                    result.into(),
270
26
                )))
271
22074
            } else if let Expr::Atomic(_, Atom::Reference(decl)) = x.as_ref() {
272
4240
                let decl_domain = decl
273
4240
                    .domain()
274
4240
                    .ok_or(RuleNotApplicable)?
275
4240
                    .resolve()
276
4240
                    .ok_or(RuleNotApplicable)?;
277
4240
                let domain = domain.resolve().ok_or(RuleNotApplicable)?;
278

            
279
4240
                let intersection = decl_domain
280
4240
                    .intersect(&domain)
281
4240
                    .map_err(|_| RuleNotApplicable)?;
282

            
283
                // if the declaration's domain is a subset of domain, expr is always true.
284
4240
                if &intersection == decl_domain.as_ref() {
285
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
286
                }
287
                // if no elements of declaration's domain are in the domain (i.e. they have no
288
                // intersection), expr is always false.
289
                //
290
                // Only check this when the intersection is a finite integer domain, as we
291
                // currently don't have a way to check whether other domain kinds are empty or not.
292
                //
293
                // we should expand this to cover more domain types in the future.
294
4240
                else if let Ok(values_in_domain) = intersection.values_i32()
295
4240
                    && values_in_domain.is_empty()
296
                {
297
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
298
                } else {
299
4240
                    Err(RuleNotApplicable)
300
                }
301
17834
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref() {
302
                if domain
303
                    .resolve()
304
                    .ok_or(RuleNotApplicable)?
305
                    .contains(lit)
306
                    .ok()
307
                    .ok_or(RuleNotApplicable)?
308
                {
309
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
310
                } else {
311
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
312
                }
313
            } else {
314
17834
                Err(RuleNotApplicable)
315
            }
316
        }
317
35694
        Expr::Bubble(_, expr, cond) => {
318
            // definition of bubble is "expr is valid as long as cond is true"
319
            //
320
            // check if cond is true and pop the bubble!
321
35694
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = cond.as_ref() {
322
2418
                Ok(Reduction::pure(Moo::unwrap_or_clone(expr.clone())))
323
            } else {
324
33276
                Err(RuleNotApplicable)
325
            }
326
        }
327
15312648
        Expr::Atomic(_, _) => Err(RuleNotApplicable),
328
17906
        Expr::ToInt(_, expression) => {
329
17906
            if expression.return_type() == ReturnType::Int {
330
                Ok(Reduction::pure(Moo::unwrap_or_clone(expression.clone())))
331
            } else {
332
17906
                Err(RuleNotApplicable)
333
            }
334
        }
335
19240
        Expr::Abs(m, e) => match e.as_ref() {
336
160
            Expr::Neg(_, inner) => Ok(Reduction::pure(Expr::Abs(m.clone(), inner.clone()))),
337
19080
            _ => Err(RuleNotApplicable),
338
        },
339
2045920
        Expr::Sum(m, vec) => {
340
2045920
            let vec = Moo::unwrap_or_clone(vec.clone())
341
2045920
                .unwrap_list()
342
2045920
                .ok_or(RuleNotApplicable)?;
343
1683460
            let mut acc = 0;
344
1683460
            let mut n_consts = 0;
345
1683460
            let mut new_vec: Vec<Expr> = Vec::new();
346
3660968
            for expr in vec {
347
1332394
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
348
1332394
                    acc += x;
349
1332394
                    n_consts += 1;
350
2328574
                } else {
351
2328574
                    new_vec.push(expr);
352
2328574
                }
353
            }
354
1683460
            if acc != 0 {
355
1276878
                new_vec.push(Expr::Atomic(
356
1276878
                    Default::default(),
357
1276878
                    Atom::Literal(Lit::Int(acc)),
358
1276878
                ));
359
1277422
            }
360

            
361
1683460
            if n_consts <= 1 {
362
1675816
                Err(RuleNotApplicable)
363
            } else {
364
7644
                Ok(Reduction::pure(Expr::Sum(
365
7644
                    m.clone(),
366
7644
                    Moo::new(into_matrix_expr![new_vec]),
367
7644
                )))
368
            }
369
        }
370

            
371
666254
        Expr::Product(m, vec) => {
372
666254
            let mut acc = 1;
373
666254
            let mut n_consts = 0;
374
666254
            let mut new_vec: Vec<Expr> = Vec::new();
375
666254
            let vec = Moo::unwrap_or_clone(vec.clone())
376
666254
                .unwrap_list()
377
666254
                .ok_or(RuleNotApplicable)?;
378
1140782
            for expr in vec {
379
449728
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
380
449728
                    acc *= x;
381
449728
                    n_consts += 1;
382
691054
                } else {
383
691054
                    new_vec.push(expr);
384
691054
                }
385
            }
386

            
387
575294
            if n_consts == 0 {
388
125686
                return Err(RuleNotApplicable);
389
449608
            }
390

            
391
449608
            new_vec.push(Expr::Atomic(
392
449608
                Default::default(),
393
449608
                Atom::Literal(Lit::Int(acc)),
394
449608
            ));
395
449608
            let new_product = Expr::Product(m.clone(), Moo::new(into_matrix_expr![new_vec]));
396

            
397
449608
            if acc == 0 {
398
                // if safe, 0 * exprs ~> 0
399
                // otherwise, just return 0* exprs
400
                if is_semantically_safe(&new_product) {
401
                    Ok(Reduction::pure(Expr::Atomic(
402
                        Default::default(),
403
                        Atom::Literal(Lit::Int(0)),
404
                    )))
405
                } else {
406
                    Ok(Reduction::pure(new_product))
407
                }
408
449608
            } else if n_consts == 1 {
409
                // acc !=0, only one constant
410
449528
                Err(RuleNotApplicable)
411
            } else {
412
                // acc !=0, multiple constants found
413
80
                Ok(Reduction::pure(new_product))
414
            }
415
        }
416

            
417
41880
        Expr::Min(m, e) => {
418
41880
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
419
36360
                return Err(RuleNotApplicable);
420
            };
421
5520
            let mut acc: Option<i32> = None;
422
5520
            let mut n_consts = 0;
423
5520
            let mut new_vec: Vec<Expr> = Vec::new();
424
11760
            for expr in vec {
425
480
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
426
480
                    n_consts += 1;
427
480
                    acc = match acc {
428
                        Some(i) => {
429
                            if i > x {
430
                                Some(x)
431
                            } else {
432
                                Some(i)
433
                            }
434
                        }
435
480
                        None => Some(x),
436
                    };
437
11280
                } else {
438
11280
                    new_vec.push(expr);
439
11280
                }
440
            }
441

            
442
5520
            if let Some(i) = acc {
443
480
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
444
5040
            }
445

            
446
5520
            if n_consts <= 1 {
447
5520
                Err(RuleNotApplicable)
448
            } else {
449
                Ok(Reduction::pure(Expr::Min(
450
                    m.clone(),
451
                    Moo::new(into_matrix_expr![new_vec]),
452
                )))
453
            }
454
        }
455

            
456
44000
        Expr::Max(m, e) => {
457
44000
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
458
38960
                return Err(RuleNotApplicable);
459
            };
460

            
461
5040
            let mut acc: Option<i32> = None;
462
5040
            let mut n_consts = 0;
463
5040
            let mut new_vec: Vec<Expr> = Vec::new();
464
10320
            for expr in vec {
465
240
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
466
240
                    n_consts += 1;
467
240
                    acc = match acc {
468
                        Some(i) => {
469
                            if i < x {
470
                                Some(x)
471
                            } else {
472
                                Some(i)
473
                            }
474
                        }
475
240
                        None => Some(x),
476
                    };
477
10080
                } else {
478
10080
                    new_vec.push(expr);
479
10080
                }
480
            }
481

            
482
5040
            if let Some(i) = acc {
483
240
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
484
4800
            }
485

            
486
5040
            if n_consts <= 1 {
487
5040
                Err(RuleNotApplicable)
488
            } else {
489
                Ok(Reduction::pure(Expr::Max(
490
                    m.clone(),
491
                    Moo::new(into_matrix_expr![new_vec]),
492
                )))
493
            }
494
        }
495
244264
        Expr::Not(_, e1) => {
496
244264
            let Expr::Imply(_, p, q) = e1.as_ref() else {
497
244024
                return Err(RuleNotApplicable);
498
            };
499

            
500
240
            if !is_semantically_safe(e1) {
501
                return Err(RuleNotApplicable);
502
240
            }
503

            
504
240
            match (p.as_ref(), q.as_ref()) {
505
                (_, Expr::Atomic(_, Atom::Literal(Lit::Bool(true)))) => {
506
                    Ok(Reduction::pure(Expr::from(false)))
507
                }
508
                (_, Expr::Atomic(_, Atom::Literal(Lit::Bool(false)))) => {
509
                    Ok(Reduction::pure(Moo::unwrap_or_clone(p.clone())))
510
                }
511
                (Expr::Atomic(_, Atom::Literal(Lit::Bool(true))), _) => {
512
                    Ok(Reduction::pure(Expr::Not(Metadata::new(), q.clone())))
513
                }
514
                (Expr::Atomic(_, Atom::Literal(Lit::Bool(false))), _) => {
515
                    Ok(Reduction::pure(Expr::from(false)))
516
                }
517
240
                _ => Err(RuleNotApplicable),
518
            }
519
        }
520
932124
        Expr::Or(m, e) => {
521
932124
            let Some(terms) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
522
211410
                return Err(RuleNotApplicable);
523
            };
524

            
525
720714
            let mut has_changed = false;
526

            
527
            // 2. boolean literals
528
720714
            let mut new_terms = vec![];
529
1507146
            for expr in terms {
530
580
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
531
580
                    has_changed = true;
532

            
533
                    // true ~~> entire or is true
534
                    // false ~~> remove false from the or
535
580
                    if x {
536
                        return Ok(Reduction::pure(true.into()));
537
580
                    }
538
1506566
                } else {
539
1506566
                    new_terms.push(expr);
540
1506566
                }
541
            }
542

            
543
            // 2. check pairwise tautologies.
544
720714
            if check_pairwise_or_tautologies(&new_terms) {
545
80
                return Ok(Reduction::pure(true.into()));
546
720634
            }
547

            
548
            // 3. empty or ~~> false
549
720634
            if new_terms.is_empty() {
550
                return Ok(Reduction::pure(false.into()));
551
720634
            }
552

            
553
720634
            if !has_changed {
554
720062
                return Err(RuleNotApplicable);
555
572
            }
556

            
557
572
            Ok(Reduction::pure(Expr::Or(
558
572
                m.clone(),
559
572
                Moo::new(into_matrix_expr![new_terms]),
560
572
            )))
561
        }
562
732974
        Expr::And(_, e) => {
563
732974
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
564
297398
                return Err(RuleNotApplicable);
565
            };
566
435576
            let mut new_vec: Vec<Expr> = Vec::new();
567
435576
            let mut has_const: bool = false;
568
1137820
            for expr in vec {
569
164200
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
570
164200
                    has_const = true;
571
164200
                    if !x {
572
1016
                        return Ok(Reduction::pure(Expr::Atomic(
573
1016
                            Default::default(),
574
1016
                            Atom::Literal(Lit::Bool(false)),
575
1016
                        )));
576
163184
                    }
577
973620
                } else {
578
973620
                    new_vec.push(expr);
579
973620
                }
580
            }
581

            
582
434560
            if !has_const {
583
432532
                Err(RuleNotApplicable)
584
            } else {
585
2028
                Ok(Reduction::pure(Expr::And(
586
2028
                    Metadata::new(),
587
2028
                    Moo::new(into_matrix_expr![new_vec]),
588
2028
                )))
589
            }
590
        }
591

            
592
        // similar to And, but booleans are returned wrapped in Root.
593
916024
        Expr::Root(_, es) => {
594
916024
            match es.as_slice() {
595
916024
                [] => Err(RuleNotApplicable),
596
                // want to unwrap nested ands
597
911164
                [Expr::And(_, _)] => Ok(()),
598
                // root([true]) / root([false]) are already evaluated
599
276264
                [_] => Err(RuleNotApplicable),
600
613470
                [_, _, ..] => Ok(()),
601
281124
            }?;
602

            
603
634900
            let mut new_vec: Vec<Expr> = Vec::new();
604
634900
            let mut has_changed: bool = false;
605
4117918
            for expr in es {
606
3504
                match expr {
607
3504
                    Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) => {
608
3504
                        has_changed = true;
609
3504
                        if !x {
610
                            // false
611
308
                            return Ok(Reduction::pure(Expr::Root(
612
308
                                Metadata::new(),
613
308
                                vec![Expr::Atomic(
614
308
                                    Default::default(),
615
308
                                    Atom::Literal(Lit::Bool(false)),
616
308
                                )],
617
308
                            )));
618
3196
                        }
619
                        // remove trues
620
                    }
621

            
622
                    // flatten ands in root
623
201320
                    Expr::And(_, vecs) => match Moo::unwrap_or_clone(vecs.clone()).unwrap_list() {
624
51120
                        Some(mut list) => {
625
51120
                            has_changed = true;
626
51120
                            new_vec.append(&mut list);
627
51120
                        }
628
150200
                        None => new_vec.push(expr.clone()),
629
                    },
630
3913094
                    _ => new_vec.push(expr.clone()),
631
                }
632
            }
633

            
634
634592
            if !has_changed {
635
583366
                Err(RuleNotApplicable)
636
            } else {
637
51226
                if new_vec.is_empty() {
638
200
                    new_vec.push(true.into());
639
51026
                }
640
51226
                Ok(Reduction::pure(Expr::Root(Metadata::new(), new_vec)))
641
            }
642
        }
643
293076
        Expr::Imply(_m, x, y) => {
644
293076
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
645
161968
                if *x {
646
                    // (true) -> y ~~> y
647
1568
                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
648
                } else {
649
                    // (false) -> y ~~> true
650
160400
                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
651
                }
652
131108
            };
653

            
654
131108
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(y))) = y.as_ref() {
655
116
                if *y {
656
                    // x -> (true) ~~> true
657
80
                    return Ok(Reduction::pure(Expr::from(true)));
658
                } else {
659
                    // x -> (false) ~~> !x
660
36
                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), x.clone())));
661
                }
662
130992
            };
663

            
664
            // reflexivity: p -> p ~> true
665

            
666
            // instead of checking syntactic equivalence of a possibly deep expression,
667
            // let identical-CSE turn them into identical variables first. Then, check if they are
668
            // identical variables.
669

            
670
130992
            if x.identical_atom_to(y.as_ref()) && is_semantically_safe(x) && is_semantically_safe(y)
671
            {
672
82
                return Ok(Reduction::pure(true.into()));
673
130910
            }
674

            
675
130910
            Err(RuleNotApplicable)
676
        }
677
12204
        Expr::Iff(_m, x, y) => {
678
12204
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
679
80
                if *x {
680
                    // (true) <-> y ~~> y
681
                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
682
                } else {
683
                    // (false) <-> y ~~> !y
684
80
                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), y.clone())));
685
                }
686
12124
            };
687
12124
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(y))) = y.as_ref() {
688
                if *y {
689
                    // x <-> (true) ~~> x
690
                    return Ok(Reduction::pure(Moo::unwrap_or_clone(x.clone())));
691
                } else {
692
                    // x <-> (false) ~~> !x
693
                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), x.clone())));
694
                }
695
12124
            };
696

            
697
            // reflexivity: p <-> p ~> true
698

            
699
            // instead of checking syntactic equivalence of a possibly deep expression,
700
            // let identical-CSE turn them into identical variables first. Then, check if they are
701
            // identical variables.
702

            
703
12124
            if x.identical_atom_to(y.as_ref()) && is_semantically_safe(x) && is_semantically_safe(y)
704
            {
705
80
                return Ok(Reduction::pure(true.into()));
706
12044
            }
707

            
708
12044
            Err(RuleNotApplicable)
709
        }
710
1474646
        Expr::Eq(_, x, y) => {
711
1474646
            if let Some((eq_result, _)) = simplify_reflexive_comparison(x, y) {
712
200
                Ok(Reduction::pure(Expr::Atomic(
713
200
                    Metadata::new(),
714
200
                    Atom::Literal(Lit::Bool(eq_result)),
715
200
                )))
716
1474446
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref()
717
21734
                && let Some((eq_result, _)) = simplify_comparison_with_literal(y, lit)
718
            {
719
160
                Ok(Reduction::pure(Expr::Atomic(
720
160
                    Metadata::new(),
721
160
                    Atom::Literal(Lit::Bool(eq_result)),
722
160
                )))
723
1474286
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = y.as_ref()
724
619546
                && let Some((eq_result, _)) = simplify_comparison_with_literal(x, lit)
725
            {
726
916
                Ok(Reduction::pure(Expr::Atomic(
727
916
                    Metadata::new(),
728
916
                    Atom::Literal(Lit::Bool(eq_result)),
729
916
                )))
730
            } else {
731
1473370
                Err(RuleNotApplicable)
732
            }
733
        }
734
903114
        Expr::Neq(_, x, y) => {
735
903114
            if let Some((_, neq_result)) = simplify_reflexive_comparison(x, y) {
736
24000
                Ok(Reduction::pure(Expr::Atomic(
737
24000
                    Metadata::new(),
738
24000
                    Atom::Literal(Lit::Bool(neq_result)),
739
24000
                )))
740
879114
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref()
741
226662
                && let Some((_, neq_result)) = simplify_comparison_with_literal(y, lit)
742
            {
743
482
                Ok(Reduction::pure(Expr::Atomic(
744
482
                    Metadata::new(),
745
482
                    Atom::Literal(Lit::Bool(neq_result)),
746
482
                )))
747
878632
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = y.as_ref()
748
64904
                && let Some((_, neq_result)) = simplify_comparison_with_literal(x, lit)
749
            {
750
2204
                Ok(Reduction::pure(Expr::Atomic(
751
2204
                    Metadata::new(),
752
2204
                    Atom::Literal(Lit::Bool(neq_result)),
753
2204
                )))
754
            } else {
755
876428
                Err(RuleNotApplicable)
756
            }
757
        }
758
357608
        Expr::Geq(_, _, _) => Err(RuleNotApplicable),
759
869562
        Expr::Leq(_, _, _) => Err(RuleNotApplicable),
760
16858
        Expr::Gt(_, _, _) => Err(RuleNotApplicable),
761
149642
        Expr::Lt(_, _, _) => Err(RuleNotApplicable),
762
56600
        Expr::SafeDiv(_, _, _) => Err(RuleNotApplicable),
763
60120
        Expr::UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
764
13538
        Expr::Flatten(_, _, _) => Err(RuleNotApplicable), // TODO: check if anything can be done here
765
140670
        Expr::AllDiff(m, e) => {
766
140670
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
767
125132
                return Err(RuleNotApplicable);
768
            };
769

            
770
15538
            let mut consts: HashSet<i32> = HashSet::new();
771

            
772
            // check for duplicate constant values which would fail the constraint
773
57264
            for expr in vec {
774
2760
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr
775
2760
                    && !consts.insert(x)
776
                {
777
                    return Ok(Reduction::pure(Expr::Atomic(
778
                        m.clone(),
779
                        Atom::Literal(Lit::Bool(false)),
780
                    )));
781
57264
                }
782
            }
783

            
784
            // nothing has changed
785
15538
            Err(RuleNotApplicable)
786
        }
787
98690
        Expr::Neg(_, _) => Err(RuleNotApplicable),
788
        Expr::Factorial(_, _) => Err(RuleNotApplicable),
789
287770
        Expr::AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
790
6400
        Expr::UnsafeMod(_, _, _) => Err(RuleNotApplicable),
791
17520
        Expr::SafeMod(_, _, _) => Err(RuleNotApplicable),
792
11870
        Expr::UnsafePow(_, _, _) => Err(RuleNotApplicable),
793
20724
        Expr::SafePow(_, _, _) => Err(RuleNotApplicable),
794
408706
        Expr::Minus(_, _, _) => Err(RuleNotApplicable),
795
        Expr::Card(_, _) => todo!(),
796

            
797
        // As these are in a low level solver form, I'm assuming that these have already been
798
        // simplified and partially evaluated.
799
68640
        Expr::FlatAllDiff(_, _) => Err(RuleNotApplicable),
800
5600
        Expr::FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
801
209132
        Expr::FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
802
2160
        Expr::FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
803
9480
        Expr::FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
804
532566
        Expr::FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
805
552516
        Expr::FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
806
45470
        Expr::FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
807
213920
        Expr::FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
808
214600
        Expr::FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
809
16040
        Expr::MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
810
3600
        Expr::MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
811
5622
        Expr::MinionPow(_, _, _, _) => Err(RuleNotApplicable),
812
216944
        Expr::MinionReify(_, _, _) => Err(RuleNotApplicable),
813
76560
        Expr::MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
814
2920
        Expr::MinionWInIntervalSet(_, _, _) => Err(RuleNotApplicable),
815
960
        Expr::MinionWInSet(_, _, _) => Err(RuleNotApplicable),
816
295610
        Expr::MinionElementOne(_, _, _, _) => Err(RuleNotApplicable),
817
1824020
        Expr::SATInt(_, _, _, _) => Err(RuleNotApplicable),
818
        Expr::PairwiseSum(_, _, _) => Err(RuleNotApplicable),
819
        Expr::PairwiseProduct(_, _, _) => Err(RuleNotApplicable),
820
        Expr::Active(_, _, _) => todo!(),
821
        Expr::Defined(_, _) => todo!(),
822
        Expr::Range(_, _) => todo!(),
823
        Expr::Image(_, _, _) => todo!(),
824
        Expr::ImageSet(_, _, _) => todo!(),
825
        Expr::PreImage(_, _, _) => todo!(),
826
        Expr::Inverse(_, _, _) => todo!(),
827
        Expr::Restrict(_, _, _) => todo!(),
828
        Expr::ToSet(_, _) => todo!(),
829
        Expr::ToMSet(_, _) => todo!(),
830
        Expr::ToRelation(_, _) => todo!(),
831
        Expr::RelationProj(_, _, _) => todo!(),
832
        Expr::Apart(_, _, _) => todo!(),
833
        Expr::Together(_, _, _) => todo!(),
834
        Expr::Participants(_, _) => todo!(),
835
        Expr::Party(_, _, _) => todo!(),
836
        Expr::Parts(_, _) => todo!(),
837
        Expr::Subsequence(_, _, _) => todo!(),
838
        Expr::Substring(_, _, _) => todo!(),
839
2402
        Expr::LexLt(_, _, _) => Err(RuleNotApplicable),
840
41000
        Expr::LexLeq(_, _, _) => Err(RuleNotApplicable),
841
120
        Expr::LexGt(_, _, _) => Err(RuleNotApplicable),
842
240
        Expr::LexGeq(_, _, _) => Err(RuleNotApplicable),
843
480
        Expr::FlatLexLt(_, _, _) => Err(RuleNotApplicable),
844
720
        Expr::FlatLexLeq(_, _, _) => Err(RuleNotApplicable),
845
    }
846
39246952
}
847

            
848
/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
849
///
850
/// This applies the following rules:
851
///
852
/// ```text
853
/// (p->q) \/ (q->p) ~> true    [totality of implication]
854
/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
855
/// ```
856
///
857
720714
fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
858
    // Collect terms that are structurally identical to the rule input.
859
    // Then, try the rules on these terms, also checking the other conditions of the rules.
860

            
861
    // stores (p,q) in p -> q
862
720714
    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
863

            
864
    // stores (p,q) in p -> !q
865
720714
    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
866

            
867
1506566
    for term in or_terms.iter() {
868
1506566
        if let Expr::Imply(_, p, q) = term {
869
            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
870
            //
871
            // in general however, p -> !q would be in p_implies_q as (p,!q)
872
1000
            if let Expr::Not(_, q_1) = q.as_ref() {
873
40
                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
874
960
            } else {
875
960
                p_implies_q.push((p.as_ref(), q.as_ref()));
876
960
            }
877
1505566
        }
878
    }
879

            
880
    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
881
720714
    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
882
1440
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
883
40
            return true;
884
1400
        }
885
    }
886

            
887
    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
888
720674
    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
889
40
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
890
40
            return true;
891
        }
892
    }
893

            
894
720634
    false
895
720714
}