1
use std::collections::HashSet;
2

            
3
use crate::ast::Typeable;
4
use crate::{
5
    ast::{
6
        AbstractLiteral, Atom, DomainPtr, Expression as Expr, GroundDomain, Literal as Lit,
7
        Metadata, Moo, Range, ReturnType,
8
    },
9
    into_matrix_expr,
10
    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
11
};
12
use itertools::iproduct;
13
use uniplate::Uniplate;
14

            
15
/// Normalises integer ranges so equivalent domains compare structurally equal.
16
41100
fn normalise_int_domain(domain: &GroundDomain) -> GroundDomain {
17
41100
    match domain {
18
41100
        GroundDomain::Int(ranges) => GroundDomain::Int(Range::squeeze(
19
41100
            &ranges
20
41100
                .iter()
21
52020
                .map(|range| Range::new(range.low().copied(), range.high().copied()))
22
41100
                .collect::<Vec<_>>(),
23
        )),
24
        _ => domain.clone(),
25
    }
26
41100
}
27

            
28
/// Returns whether `expr` is safe after resolving any referenced expressions.
29
5115348
fn is_semantically_safe(expr: &Expr) -> bool {
30
5122922
    fn helper(expr: &Expr, resolving: &mut HashSet<crate::ast::serde::ObjId>) -> bool {
31
5122922
        if !expr.is_safe() {
32
801308
            return false;
33
4321614
        }
34

            
35
40060694
        for subexpr in expr.universe() {
36
25718940
            let Expr::Atomic(_, Atom::Reference(reference)) = subexpr else {
37
21757284
                continue;
38
            };
39

            
40
18303410
            let Some(resolved) = reference.resolve_expression() else {
41
18295836
                continue;
42
            };
43

            
44
7574
            let id = reference.id();
45
7574
            if !resolving.insert(id.clone()) {
46
                return false;
47
7574
            }
48

            
49
7574
            let is_safe = helper(&resolved, resolving);
50
7574
            resolving.remove(&id);
51

            
52
7574
            if !is_safe {
53
                return false;
54
7574
            }
55
        }
56

            
57
4321614
        true
58
5122922
    }
59

            
60
5115348
    helper(expr, &mut HashSet::new())
61
5115348
}
62

            
63
/// Tries to decide `expr in domain` from resolved domains alone.
64
22442
fn simplify_in_domain(expr: &Expr, domain: &DomainPtr) -> Option<bool> {
65
22442
    if !is_semantically_safe(expr) {
66
106
        return None;
67
22336
    }
68

            
69
22336
    let expr_domain = resolved_ground_domain_of_for_partial_eval(expr)?;
70
21896
    let domain = domain.resolve()?;
71
21896
    let intersection = expr_domain.intersect(&domain).ok()?;
72

            
73
20550
    if normalise_int_domain(&intersection) == normalise_int_domain(expr_domain.as_ref()) {
74
26
        return Some(true);
75
20524
    }
76

            
77
20524
    if let Ok(values_in_domain) = intersection.values_i32()
78
20524
        && values_in_domain.is_empty()
79
    {
80
        return Some(false);
81
20524
    }
82

            
83
20524
    None
84
22442
}
85

            
86
/// Extracts an integer when `expr` is known to be a singleton integer value.
87
957154
fn singleton_int_value(expr: &Expr) -> Option<i32> {
88
957154
    if let Ok(value) = expr.try_into() {
89
660
        return Some(value);
90
956494
    }
91

            
92
956494
    let domain = resolved_ground_domain_of_for_partial_eval(expr)?;
93
956494
    let GroundDomain::Int(ranges) = domain.as_ref() else {
94
        return None;
95
    };
96
956494
    let [range] = ranges.as_slice() else {
97
        return None;
98
    };
99
956494
    let (Some(low), Some(high)) = (range.low(), range.high()) else {
100
        return None;
101
    };
102

            
103
956494
    if low == high { Some(*low) } else { None }
104
957154
}
105

            
106
/// Resolves a matrix literal subject, including constant references to matrix literals.
107
2732520
fn resolve_matrix_subject(subject: &Expr) -> Option<(Vec<Expr>, DomainPtr)> {
108
2732520
    subject.clone().unwrap_matrix_unchecked().or_else(|| {
109
1787396
        let Expr::Atomic(_, Atom::Reference(reference)) = subject else {
110
            return None;
111
        };
112

            
113
12030
        let Lit::AbstractLiteral(AbstractLiteral::Matrix(elems, index_domain)) =
114
1787396
            reference.resolve_constant()?
115
        else {
116
            return None;
117
        };
118

            
119
        Some((
120
12030
            elems
121
12030
                .into_iter()
122
129810
                .map(|elem| Expr::Atomic(Metadata::new(), Atom::Literal(elem)))
123
12030
                .collect(),
124
12030
            index_domain.into(),
125
        ))
126
1787396
    })
127
2732520
}
128

            
129
/// Resolves domains for partial evaluation while avoiding malformed indexing panics.
130
1865162
fn resolved_ground_domain_of_for_partial_eval(expr: &Expr) -> Option<Moo<GroundDomain>> {
131
1865162
    match expr {
132
117234
        Expr::SafeIndex(_, subject, _) => {
133
117234
            let subject_domain = resolved_ground_domain_of_for_partial_eval(subject)?;
134
117234
            let GroundDomain::Matrix(elem_domain, _) = subject_domain.as_ref() else {
135
7680
                return None;
136
            };
137

            
138
109554
            Some(elem_domain.clone())
139
        }
140
        Expr::SafeSlice(_, subject, indices) => {
141
            let subject_domain = resolved_ground_domain_of_for_partial_eval(subject)?;
142
            let GroundDomain::Matrix(elem_domain, index_domains) = subject_domain.as_ref() else {
143
                return None;
144
            };
145
            let sliced_dimension = indices.iter().position(Option::is_none);
146

            
147
            match sliced_dimension {
148
                Some(dimension) => Some(Moo::new(GroundDomain::Matrix(
149
                    elem_domain.clone(),
150
                    vec![index_domains[dimension].clone()],
151
                ))),
152
                None => Some(elem_domain.clone()),
153
            }
154
        }
155
        Expr::UnsafeIndex(_, _, _) | Expr::UnsafeSlice(_, _, _) => None,
156
1747928
        _ => expr.domain_of()?.resolve(),
157
    }
158
1865162
}
159

            
160
/// Tries to decide `expr = lit` and `expr != lit` from the resolved domain of `expr`.
161
931314
fn simplify_comparison_with_literal(expr: &Expr, lit: &Lit) -> Option<(bool, bool)> {
162
931314
    if !is_semantically_safe(expr) {
163
162216
        return None;
164
769098
    }
165

            
166
769098
    let expr_domain = resolved_ground_domain_of_for_partial_eval(expr)?;
167

            
168
750984
    if !expr_domain.contains(lit).ok()? {
169
3272
        return Some((false, true));
170
747712
    }
171

            
172
747712
    match (expr_domain.as_ref(), lit) {
173
569628
        (GroundDomain::Int(ranges), Lit::Int(value)) => {
174
569628
            let [range] = ranges.as_slice() else {
175
4104
                return None;
176
            };
177
565524
            let (Some(low), Some(high)) = (range.low(), range.high()) else {
178
                return None;
179
            };
180

            
181
565524
            if low == high && low == value {
182
490
                Some((true, false))
183
            } else {
184
565034
                None
185
            }
186
        }
187
173924
        (GroundDomain::Bool, Lit::Bool(_)) => None,
188
4160
        _ => None,
189
    }
190
931314
}
191

            
192
/// Tries to decide reflexive equality and inequality when both sides are semantically safe.
193
2372436
fn simplify_reflexive_comparison(x: &Expr, y: &Expr) -> Option<(bool, bool)> {
194
2372436
    if x.identical_atom_to(y) && is_semantically_safe(x) && is_semantically_safe(y) {
195
80
        return Some((true, false));
196
2372356
    }
197

            
198
2372356
    if is_semantically_safe(x) && is_semantically_safe(y) && x == y {
199
24120
        return Some((true, false));
200
2348236
    }
201

            
202
2348236
    None
203
2372436
}
204

            
205
39044754
pub fn run_partial_evaluator(expr: &Expr) -> ApplicationResult {
206
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
207
    // rule infinitely!
208
    // This is why we always check whether we found a constant or not.
209
39044754
    match expr {
210
188
        Expr::Union(_, _, _) => Err(RuleNotApplicable),
211
4632
        Expr::In(_, _, _) => Err(RuleNotApplicable),
212
560
        Expr::Intersect(_, _, _) => Err(RuleNotApplicable),
213
320
        Expr::Supset(_, _, _) => Err(RuleNotApplicable),
214
320
        Expr::SupsetEq(_, _, _) => Err(RuleNotApplicable),
215
800
        Expr::Subset(_, _, _) => Err(RuleNotApplicable),
216
362
        Expr::SubsetEq(_, _, _) => Err(RuleNotApplicable),
217
4517296
        Expr::AbstractLiteral(_, _) => Err(RuleNotApplicable),
218
210366
        Expr::Comprehension(_, _) => Err(RuleNotApplicable),
219
        Expr::AbstractComprehension(_, _) => Err(RuleNotApplicable),
220
        Expr::DominanceRelation(_, _) => Err(RuleNotApplicable),
221
        Expr::FromSolution(_, _) => Err(RuleNotApplicable),
222
        Expr::Metavar(_, _) => Err(RuleNotApplicable),
223
1039636
        Expr::UnsafeIndex(_, _, _) => Err(RuleNotApplicable),
224
61640
        Expr::UnsafeSlice(_, _, _) => Err(RuleNotApplicable),
225
1920
        Expr::Table(_, _, _) => Err(RuleNotApplicable),
226
320
        Expr::NegativeTable(_, _, _) => Err(RuleNotApplicable),
227
2732520
        Expr::SafeIndex(_, subject, indices) => {
228
            // partially evaluate matrix literals indexed by a constant.
229

            
230
            // subject must be a matrix literal
231
2732520
            let (es, index_domain) = resolve_matrix_subject(subject).ok_or(RuleNotApplicable)?;
232

            
233
957154
            if indices.is_empty() {
234
                return Err(RuleNotApplicable);
235
957154
            }
236

            
237
            // the leading index must be fixed to a single value
238
957154
            let index = singleton_int_value(&indices[0]).ok_or(RuleNotApplicable)?;
239

            
240
            // index domain must be a single integer range with a lower bound
241
660
            if let Some(ranges) = index_domain.as_int_ground()
242
660
                && ranges.len() == 1
243
660
                && let Some(from) = ranges[0].low()
244
            {
245
660
                let zero_indexed_index = index - from;
246
660
                let selected = es
247
660
                    .get(zero_indexed_index as usize)
248
660
                    .ok_or(RuleNotApplicable)?
249
660
                    .clone();
250

            
251
660
                if indices.len() == 1 {
252
                    Ok(Reduction::pure(selected))
253
                } else {
254
660
                    Ok(Reduction::pure(Expr::SafeIndex(
255
660
                        Metadata::new(),
256
660
                        Moo::new(selected),
257
660
                        indices[1..].to_vec(),
258
660
                    )))
259
                }
260
            } else {
261
                Err(RuleNotApplicable)
262
            }
263
        }
264
79920
        Expr::SafeSlice(_, _, _) => Err(RuleNotApplicable),
265
22442
        Expr::InDomain(_, x, domain) => {
266
22442
            if let Some(result) = simplify_in_domain(x, domain) {
267
26
                Ok(Reduction::pure(Expr::Atomic(
268
26
                    Metadata::new(),
269
26
                    result.into(),
270
26
                )))
271
22416
            } else if let Expr::Atomic(_, Atom::Reference(decl)) = x.as_ref() {
272
4240
                let decl_domain = decl
273
4240
                    .domain()
274
4240
                    .ok_or(RuleNotApplicable)?
275
4240
                    .resolve()
276
4240
                    .ok_or(RuleNotApplicable)?;
277
4240
                let domain = domain.resolve().ok_or(RuleNotApplicable)?;
278

            
279
4240
                let intersection = decl_domain
280
4240
                    .intersect(&domain)
281
4240
                    .map_err(|_| RuleNotApplicable)?;
282

            
283
                // if the declaration's domain is a subset of domain, expr is always true.
284
4240
                if &intersection == decl_domain.as_ref() {
285
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
286
                }
287
                // if no elements of declaration's domain are in the domain (i.e. they have no
288
                // intersection), expr is always false.
289
                //
290
                // Only check this when the intersection is a finite integer domain, as we
291
                // currently don't have a way to check whether other domain kinds are empty or not.
292
                //
293
                // we should expand this to cover more domain types in the future.
294
4240
                else if let Ok(values_in_domain) = intersection.values_i32()
295
4240
                    && values_in_domain.is_empty()
296
                {
297
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
298
                } else {
299
4240
                    Err(RuleNotApplicable)
300
                }
301
18176
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref() {
302
                if domain
303
                    .resolve()
304
                    .ok_or(RuleNotApplicable)?
305
                    .contains(lit)
306
                    .ok()
307
                    .ok_or(RuleNotApplicable)?
308
                {
309
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())))
310
                } else {
311
                    Ok(Reduction::pure(Expr::Atomic(Metadata::new(), false.into())))
312
                }
313
            } else {
314
18176
                Err(RuleNotApplicable)
315
            }
316
        }
317
35694
        Expr::Bubble(_, expr, cond) => {
318
            // definition of bubble is "expr is valid as long as cond is true"
319
            //
320
            // check if cond is true and pop the bubble!
321
35694
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(true))) = cond.as_ref() {
322
2418
                Ok(Reduction::pure(Moo::unwrap_or_clone(expr.clone())))
323
            } else {
324
33276
                Err(RuleNotApplicable)
325
            }
326
        }
327
15229614
        Expr::Atomic(_, _) => Err(RuleNotApplicable),
328
17524
        Expr::ToInt(_, expression) => {
329
17524
            if expression.return_type() == ReturnType::Int {
330
                Ok(Reduction::pure(Moo::unwrap_or_clone(expression.clone())))
331
            } else {
332
17524
                Err(RuleNotApplicable)
333
            }
334
        }
335
15960
        Expr::Abs(m, e) => match e.as_ref() {
336
160
            Expr::Neg(_, inner) => Ok(Reduction::pure(Expr::Abs(m.clone(), inner.clone()))),
337
15800
            _ => Err(RuleNotApplicable),
338
        },
339
2045934
        Expr::Sum(m, vec) => {
340
2045934
            let vec = Moo::unwrap_or_clone(vec.clone())
341
2045934
                .unwrap_list()
342
2045934
                .ok_or(RuleNotApplicable)?;
343
1683486
            let mut acc = 0;
344
1683486
            let mut n_consts = 0;
345
1683486
            let mut new_vec: Vec<Expr> = Vec::new();
346
3660316
            for expr in vec {
347
1332540
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
348
1332540
                    acc += x;
349
1332540
                    n_consts += 1;
350
2327776
                } else {
351
2327776
                    new_vec.push(expr);
352
2327776
                }
353
            }
354
1683486
            if acc != 0 {
355
1276940
                new_vec.push(Expr::Atomic(
356
1276940
                    Default::default(),
357
1276940
                    Atom::Literal(Lit::Int(acc)),
358
1276940
                ));
359
1277386
            }
360

            
361
1683486
            if n_consts <= 1 {
362
1675832
                Err(RuleNotApplicable)
363
            } else {
364
7654
                Ok(Reduction::pure(Expr::Sum(
365
7654
                    m.clone(),
366
7654
                    Moo::new(into_matrix_expr![new_vec]),
367
7654
                )))
368
            }
369
        }
370

            
371
666254
        Expr::Product(m, vec) => {
372
666254
            let mut acc = 1;
373
666254
            let mut n_consts = 0;
374
666254
            let mut new_vec: Vec<Expr> = Vec::new();
375
666254
            let vec = Moo::unwrap_or_clone(vec.clone())
376
666254
                .unwrap_list()
377
666254
                .ok_or(RuleNotApplicable)?;
378
1140782
            for expr in vec {
379
449728
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
380
449728
                    acc *= x;
381
449728
                    n_consts += 1;
382
691054
                } else {
383
691054
                    new_vec.push(expr);
384
691054
                }
385
            }
386

            
387
575294
            if n_consts == 0 {
388
125686
                return Err(RuleNotApplicable);
389
449608
            }
390

            
391
449608
            new_vec.push(Expr::Atomic(
392
449608
                Default::default(),
393
449608
                Atom::Literal(Lit::Int(acc)),
394
449608
            ));
395
449608
            let new_product = Expr::Product(m.clone(), Moo::new(into_matrix_expr![new_vec]));
396

            
397
449608
            if acc == 0 {
398
                // if safe, 0 * exprs ~> 0
399
                // otherwise, just return 0* exprs
400
                if is_semantically_safe(&new_product) {
401
                    Ok(Reduction::pure(Expr::Atomic(
402
                        Default::default(),
403
                        Atom::Literal(Lit::Int(0)),
404
                    )))
405
                } else {
406
                    Ok(Reduction::pure(new_product))
407
                }
408
449608
            } else if n_consts == 1 {
409
                // acc !=0, only one constant
410
449528
                Err(RuleNotApplicable)
411
            } else {
412
                // acc !=0, multiple constants found
413
80
                Ok(Reduction::pure(new_product))
414
            }
415
        }
416

            
417
41880
        Expr::Min(m, e) => {
418
41880
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
419
36360
                return Err(RuleNotApplicable);
420
            };
421
5520
            let mut acc: Option<i32> = None;
422
5520
            let mut n_consts = 0;
423
5520
            let mut new_vec: Vec<Expr> = Vec::new();
424
11760
            for expr in vec {
425
480
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
426
480
                    n_consts += 1;
427
480
                    acc = match acc {
428
                        Some(i) => {
429
                            if i > x {
430
                                Some(x)
431
                            } else {
432
                                Some(i)
433
                            }
434
                        }
435
480
                        None => Some(x),
436
                    };
437
11280
                } else {
438
11280
                    new_vec.push(expr);
439
11280
                }
440
            }
441

            
442
5520
            if let Some(i) = acc {
443
480
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
444
5040
            }
445

            
446
5520
            if n_consts <= 1 {
447
5520
                Err(RuleNotApplicable)
448
            } else {
449
                Ok(Reduction::pure(Expr::Min(
450
                    m.clone(),
451
                    Moo::new(into_matrix_expr![new_vec]),
452
                )))
453
            }
454
        }
455

            
456
44000
        Expr::Max(m, e) => {
457
44000
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
458
38960
                return Err(RuleNotApplicable);
459
            };
460

            
461
5040
            let mut acc: Option<i32> = None;
462
5040
            let mut n_consts = 0;
463
5040
            let mut new_vec: Vec<Expr> = Vec::new();
464
10320
            for expr in vec {
465
240
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr {
466
240
                    n_consts += 1;
467
240
                    acc = match acc {
468
                        Some(i) => {
469
                            if i < x {
470
                                Some(x)
471
                            } else {
472
                                Some(i)
473
                            }
474
                        }
475
240
                        None => Some(x),
476
                    };
477
10080
                } else {
478
10080
                    new_vec.push(expr);
479
10080
                }
480
            }
481

            
482
5040
            if let Some(i) = acc {
483
240
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Lit::Int(i))));
484
4800
            }
485

            
486
5040
            if n_consts <= 1 {
487
5040
                Err(RuleNotApplicable)
488
            } else {
489
                Ok(Reduction::pure(Expr::Max(
490
                    m.clone(),
491
                    Moo::new(into_matrix_expr![new_vec]),
492
                )))
493
            }
494
        }
495
244262
        Expr::Not(_, e1) => {
496
244262
            let Expr::Imply(_, p, q) = e1.as_ref() else {
497
244022
                return Err(RuleNotApplicable);
498
            };
499

            
500
240
            if !is_semantically_safe(e1) {
501
                return Err(RuleNotApplicable);
502
240
            }
503

            
504
240
            match (p.as_ref(), q.as_ref()) {
505
                (_, Expr::Atomic(_, Atom::Literal(Lit::Bool(true)))) => {
506
                    Ok(Reduction::pure(Expr::from(false)))
507
                }
508
                (_, Expr::Atomic(_, Atom::Literal(Lit::Bool(false)))) => {
509
                    Ok(Reduction::pure(Moo::unwrap_or_clone(p.clone())))
510
                }
511
                (Expr::Atomic(_, Atom::Literal(Lit::Bool(true))), _) => {
512
                    Ok(Reduction::pure(Expr::Not(Metadata::new(), q.clone())))
513
                }
514
                (Expr::Atomic(_, Atom::Literal(Lit::Bool(false))), _) => {
515
                    Ok(Reduction::pure(Expr::from(false)))
516
                }
517
240
                _ => Err(RuleNotApplicable),
518
            }
519
        }
520
928546
        Expr::Or(m, e) => {
521
928546
            let Some(terms) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
522
211410
                return Err(RuleNotApplicable);
523
            };
524

            
525
717136
            let mut has_changed = false;
526

            
527
            // 2. boolean literals
528
717136
            let mut new_terms = vec![];
529
1503510
            for expr in terms {
530
580
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
531
580
                    has_changed = true;
532

            
533
                    // true ~~> entire or is true
534
                    // false ~~> remove false from the or
535
580
                    if x {
536
                        return Ok(Reduction::pure(true.into()));
537
580
                    }
538
1502930
                } else {
539
1502930
                    new_terms.push(expr);
540
1502930
                }
541
            }
542

            
543
            // 2. check pairwise tautologies.
544
717136
            if check_pairwise_or_tautologies(&new_terms) {
545
80
                return Ok(Reduction::pure(true.into()));
546
717056
            }
547

            
548
            // 3. empty or ~~> false
549
717056
            if new_terms.is_empty() {
550
                return Ok(Reduction::pure(false.into()));
551
717056
            }
552

            
553
717056
            if !has_changed {
554
716484
                return Err(RuleNotApplicable);
555
572
            }
556

            
557
572
            Ok(Reduction::pure(Expr::Or(
558
572
                m.clone(),
559
572
                Moo::new(into_matrix_expr![new_terms]),
560
572
            )))
561
        }
562
729266
        Expr::And(_, e) => {
563
729266
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
564
297996
                return Err(RuleNotApplicable);
565
            };
566
431270
            let mut new_vec: Vec<Expr> = Vec::new();
567
431270
            let mut has_const: bool = false;
568
1129004
            for expr in vec {
569
164300
                if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = expr {
570
164300
                    has_const = true;
571
164300
                    if !x {
572
1068
                        return Ok(Reduction::pure(Expr::Atomic(
573
1068
                            Default::default(),
574
1068
                            Atom::Literal(Lit::Bool(false)),
575
1068
                        )));
576
163232
                    }
577
964704
                } else {
578
964704
                    new_vec.push(expr);
579
964704
                }
580
            }
581

            
582
430202
            if !has_const {
583
428126
                Err(RuleNotApplicable)
584
            } else {
585
2076
                Ok(Reduction::pure(Expr::And(
586
2076
                    Metadata::new(),
587
2076
                    Moo::new(into_matrix_expr![new_vec]),
588
2076
                )))
589
            }
590
        }
591

            
592
        // similar to And, but booleans are returned wrapped in Root.
593
906800
        Expr::Root(_, es) => {
594
906800
            match es.as_slice() {
595
906800
                [] => Err(RuleNotApplicable),
596
                // want to unwrap nested ands
597
901940
                [Expr::And(_, _)] => Ok(()),
598
                // root([true]) / root([false]) are already evaluated
599
274496
                [_] => Err(RuleNotApplicable),
600
606014
                [_, _, ..] => Ok(()),
601
279356
            }?;
602

            
603
627444
            let mut new_vec: Vec<Expr> = Vec::new();
604
627444
            let mut has_changed: bool = false;
605
4085758
            for expr in es {
606
3504
                match expr {
607
3504
                    Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) => {
608
3504
                        has_changed = true;
609
3504
                        if !x {
610
                            // false
611
308
                            return Ok(Reduction::pure(Expr::Root(
612
308
                                Metadata::new(),
613
308
                                vec![Expr::Atomic(
614
308
                                    Default::default(),
615
308
                                    Atom::Literal(Lit::Bool(false)),
616
308
                                )],
617
308
                            )));
618
3196
                        }
619
                        // remove trues
620
                    }
621

            
622
                    // flatten ands in root
623
200792
                    Expr::And(_, vecs) => match Moo::unwrap_or_clone(vecs.clone()).unwrap_list() {
624
50480
                        Some(mut list) => {
625
50480
                            has_changed = true;
626
50480
                            new_vec.append(&mut list);
627
50480
                        }
628
150312
                        None => new_vec.push(expr.clone()),
629
                    },
630
3881462
                    _ => new_vec.push(expr.clone()),
631
                }
632
            }
633

            
634
627136
            if !has_changed {
635
576550
                Err(RuleNotApplicable)
636
            } else {
637
50586
                if new_vec.is_empty() {
638
200
                    new_vec.push(true.into());
639
50386
                }
640
50586
                Ok(Reduction::pure(Expr::Root(Metadata::new(), new_vec)))
641
            }
642
        }
643
293200
        Expr::Imply(_m, x, y) => {
644
293200
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
645
161968
                if *x {
646
                    // (true) -> y ~~> y
647
1568
                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
648
                } else {
649
                    // (false) -> y ~~> true
650
160400
                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
651
                }
652
131232
            };
653

            
654
131232
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(y))) = y.as_ref() {
655
116
                if *y {
656
                    // x -> (true) ~~> true
657
80
                    return Ok(Reduction::pure(Expr::from(true)));
658
                } else {
659
                    // x -> (false) ~~> !x
660
36
                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), x.clone())));
661
                }
662
131116
            };
663

            
664
            // reflexivity: p -> p ~> true
665

            
666
            // instead of checking syntactic equivalence of a possibly deep expression,
667
            // let identical-CSE turn them into identical variables first. Then, check if they are
668
            // identical variables.
669

            
670
131116
            if x.identical_atom_to(y.as_ref()) && is_semantically_safe(x) && is_semantically_safe(y)
671
            {
672
82
                return Ok(Reduction::pure(true.into()));
673
131034
            }
674

            
675
131034
            Err(RuleNotApplicable)
676
        }
677
12204
        Expr::Iff(_m, x, y) => {
678
12204
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = x.as_ref() {
679
80
                if *x {
680
                    // (true) <-> y ~~> y
681
                    return Ok(Reduction::pure(Moo::unwrap_or_clone(y.clone())));
682
                } else {
683
                    // (false) <-> y ~~> !y
684
80
                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), y.clone())));
685
                }
686
12124
            };
687
12124
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(y))) = y.as_ref() {
688
                if *y {
689
                    // x <-> (true) ~~> x
690
                    return Ok(Reduction::pure(Moo::unwrap_or_clone(x.clone())));
691
                } else {
692
                    // x <-> (false) ~~> !x
693
                    return Ok(Reduction::pure(Expr::Not(Metadata::new(), x.clone())));
694
                }
695
12124
            };
696

            
697
            // reflexivity: p <-> p ~> true
698

            
699
            // instead of checking syntactic equivalence of a possibly deep expression,
700
            // let identical-CSE turn them into identical variables first. Then, check if they are
701
            // identical variables.
702

            
703
12124
            if x.identical_atom_to(y.as_ref()) && is_semantically_safe(x) && is_semantically_safe(y)
704
            {
705
80
                return Ok(Reduction::pure(true.into()));
706
12044
            }
707

            
708
12044
            Err(RuleNotApplicable)
709
        }
710
1469322
        Expr::Eq(_, x, y) => {
711
1469322
            if let Some((eq_result, _)) = simplify_reflexive_comparison(x, y) {
712
200
                Ok(Reduction::pure(Expr::Atomic(
713
200
                    Metadata::new(),
714
200
                    Atom::Literal(Lit::Bool(eq_result)),
715
200
                )))
716
1469122
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref()
717
21734
                && let Some((eq_result, _)) = simplify_comparison_with_literal(y, lit)
718
            {
719
160
                Ok(Reduction::pure(Expr::Atomic(
720
160
                    Metadata::new(),
721
160
                    Atom::Literal(Lit::Bool(eq_result)),
722
160
                )))
723
1468962
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = y.as_ref()
724
618014
                && let Some((eq_result, _)) = simplify_comparison_with_literal(x, lit)
725
            {
726
916
                Ok(Reduction::pure(Expr::Atomic(
727
916
                    Metadata::new(),
728
916
                    Atom::Literal(Lit::Bool(eq_result)),
729
916
                )))
730
            } else {
731
1468046
                Err(RuleNotApplicable)
732
            }
733
        }
734
903114
        Expr::Neq(_, x, y) => {
735
903114
            if let Some((_, neq_result)) = simplify_reflexive_comparison(x, y) {
736
24000
                Ok(Reduction::pure(Expr::Atomic(
737
24000
                    Metadata::new(),
738
24000
                    Atom::Literal(Lit::Bool(neq_result)),
739
24000
                )))
740
879114
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = x.as_ref()
741
226662
                && let Some((_, neq_result)) = simplify_comparison_with_literal(y, lit)
742
            {
743
482
                Ok(Reduction::pure(Expr::Atomic(
744
482
                    Metadata::new(),
745
482
                    Atom::Literal(Lit::Bool(neq_result)),
746
482
                )))
747
878632
            } else if let Expr::Atomic(_, Atom::Literal(lit)) = y.as_ref()
748
64904
                && let Some((_, neq_result)) = simplify_comparison_with_literal(x, lit)
749
            {
750
2204
                Ok(Reduction::pure(Expr::Atomic(
751
2204
                    Metadata::new(),
752
2204
                    Atom::Literal(Lit::Bool(neq_result)),
753
2204
                )))
754
            } else {
755
876428
                Err(RuleNotApplicable)
756
            }
757
        }
758
348248
        Expr::Geq(_, _, _) => Err(RuleNotApplicable),
759
858282
        Expr::Leq(_, _, _) => Err(RuleNotApplicable),
760
16858
        Expr::Gt(_, _, _) => Err(RuleNotApplicable),
761
149642
        Expr::Lt(_, _, _) => Err(RuleNotApplicable),
762
56600
        Expr::SafeDiv(_, _, _) => Err(RuleNotApplicable),
763
60120
        Expr::UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
764
13538
        Expr::Flatten(_, _, _) => Err(RuleNotApplicable), // TODO: check if anything can be done here
765
140670
        Expr::AllDiff(m, e) => {
766
140670
            let Some(vec) = Moo::unwrap_or_clone(e.clone()).unwrap_list() else {
767
125132
                return Err(RuleNotApplicable);
768
            };
769

            
770
15538
            let mut consts: HashSet<i32> = HashSet::new();
771

            
772
            // check for duplicate constant values which would fail the constraint
773
57264
            for expr in vec {
774
2760
                if let Expr::Atomic(_, Atom::Literal(Lit::Int(x))) = expr
775
2760
                    && !consts.insert(x)
776
                {
777
                    return Ok(Reduction::pure(Expr::Atomic(
778
                        m.clone(),
779
                        Atom::Literal(Lit::Bool(false)),
780
                    )));
781
57264
                }
782
            }
783

            
784
            // nothing has changed
785
15538
            Err(RuleNotApplicable)
786
        }
787
98152
        Expr::Neg(_, _) => Err(RuleNotApplicable),
788
        Expr::Factorial(_, _) => Err(RuleNotApplicable),
789
287770
        Expr::AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
790
6400
        Expr::UnsafeMod(_, _, _) => Err(RuleNotApplicable),
791
17520
        Expr::SafeMod(_, _, _) => Err(RuleNotApplicable),
792
11870
        Expr::UnsafePow(_, _, _) => Err(RuleNotApplicable),
793
20724
        Expr::SafePow(_, _, _) => Err(RuleNotApplicable),
794
408702
        Expr::Minus(_, _, _) => Err(RuleNotApplicable),
795

            
796
        // As these are in a low level solver form, I'm assuming that these have already been
797
        // simplified and partially evaluated.
798
68640
        Expr::FlatAllDiff(_, _) => Err(RuleNotApplicable),
799
5120
        Expr::FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
800
209132
        Expr::FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
801
2160
        Expr::FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
802
9480
        Expr::FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
803
532566
        Expr::FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
804
552516
        Expr::FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
805
45470
        Expr::FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
806
213920
        Expr::FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
807
214600
        Expr::FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
808
16040
        Expr::MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
809
3600
        Expr::MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
810
5622
        Expr::MinionPow(_, _, _, _) => Err(RuleNotApplicable),
811
216944
        Expr::MinionReify(_, _, _) => Err(RuleNotApplicable),
812
76560
        Expr::MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
813
2920
        Expr::MinionWInIntervalSet(_, _, _) => Err(RuleNotApplicable),
814
960
        Expr::MinionWInSet(_, _, _) => Err(RuleNotApplicable),
815
295610
        Expr::MinionElementOne(_, _, _, _) => Err(RuleNotApplicable),
816
1776020
        Expr::SATInt(_, _, _, _) => Err(RuleNotApplicable),
817
        Expr::PairwiseSum(_, _, _) => Err(RuleNotApplicable),
818
        Expr::PairwiseProduct(_, _, _) => Err(RuleNotApplicable),
819
        Expr::Defined(_, _) => todo!(),
820
        Expr::Range(_, _) => todo!(),
821
        Expr::Image(_, _, _) => todo!(),
822
        Expr::ImageSet(_, _, _) => todo!(),
823
        Expr::PreImage(_, _, _) => todo!(),
824
        Expr::Inverse(_, _, _) => todo!(),
825
        Expr::Restrict(_, _, _) => todo!(),
826
2402
        Expr::LexLt(_, _, _) => Err(RuleNotApplicable),
827
41000
        Expr::LexLeq(_, _, _) => Err(RuleNotApplicable),
828
120
        Expr::LexGt(_, _, _) => Err(RuleNotApplicable),
829
240
        Expr::LexGeq(_, _, _) => Err(RuleNotApplicable),
830
480
        Expr::FlatLexLt(_, _, _) => Err(RuleNotApplicable),
831
720
        Expr::FlatLexLeq(_, _, _) => Err(RuleNotApplicable),
832
    }
833
39044754
}
834

            
835
/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
836
///
837
/// This applies the following rules:
838
///
839
/// ```text
840
/// (p->q) \/ (q->p) ~> true    [totality of implication]
841
/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
842
/// ```
843
///
844
717136
fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
845
    // Collect terms that are structurally identical to the rule input.
846
    // Then, try the rules on these terms, also checking the other conditions of the rules.
847

            
848
    // stores (p,q) in p -> q
849
717136
    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
850

            
851
    // stores (p,q) in p -> !q
852
717136
    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
853

            
854
1502930
    for term in or_terms.iter() {
855
1502930
        if let Expr::Imply(_, p, q) = term {
856
            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
857
            //
858
            // in general however, p -> !q would be in p_implies_q as (p,!q)
859
1000
            if let Expr::Not(_, q_1) = q.as_ref() {
860
40
                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
861
960
            } else {
862
960
                p_implies_q.push((p.as_ref(), q.as_ref()));
863
960
            }
864
1501930
        }
865
    }
866

            
867
    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
868
717136
    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
869
1440
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
870
40
            return true;
871
1400
        }
872
    }
873

            
874
    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
875
717096
    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
876
40
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
877
40
            return true;
878
        }
879
    }
880

            
881
717056
    false
882
717136
}