1
use std::collections::{HashSet, VecDeque};
2

            
3
use conjure_macros::register_rule;
4
use itertools::iproduct;
5
use uniplate::Biplate;
6

            
7
use crate::ast::SymbolTable;
8
use crate::into_matrix_expr;
9
use crate::rule_engine::{ApplicationResult, Reduction};
10
use crate::{
11
    ast::{Atom, Expression as Expr, Literal as Lit, Literal::*},
12
    metadata::Metadata,
13
};
14

            
15
#[register_rule(("Base",9000))]
16
629532
fn partial_evaluator(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
17
    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
18
    use Expr::*;
19

            
20
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
21
    // rule infinitely!
22
    // This is why we always check whether we found a constant or not.
23
629532
    match expr.clone() {
24
35460
        AbstractLiteral(_, _) => Err(RuleNotApplicable),
25
        DominanceRelation(_, _) => Err(RuleNotApplicable),
26
        FromSolution(_, _) => Err(RuleNotApplicable),
27
4086
        UnsafeIndex(_, _, _) => Err(RuleNotApplicable),
28
2916
        UnsafeSlice(_, _, _) => Err(RuleNotApplicable),
29
2880
        SafeIndex(_, _, _) => Err(RuleNotApplicable),
30
6552
        SafeSlice(_, _, _) => Err(RuleNotApplicable),
31
        InDomain(_, _, _) => Err(RuleNotApplicable),
32
3294
        Bubble(_, _, _) => Err(RuleNotApplicable),
33
287280
        Atomic(_, _) => Err(RuleNotApplicable),
34
        Scope(_, _) => Err(RuleNotApplicable),
35
972
        Abs(m, e) => match *e {
36
36
            Neg(_, inner) => Ok(Reduction::pure(Abs(m, inner))),
37
936
            _ => Err(RuleNotApplicable),
38
        },
39
3690
        Sum(m, vec) => {
40
3690
            let mut acc = 0;
41
3690
            let mut n_consts = 0;
42
3690
            let mut new_vec: Vec<Expr> = Vec::new();
43
14400
            for expr in vec {
44
1710
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
45
1710
                    acc += x;
46
1710
                    n_consts += 1;
47
9000
                } else {
48
9000
                    new_vec.push(expr);
49
9000
                }
50
            }
51
3690
            if acc != 0 {
52
1494
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
53
2196
            }
54

            
55
3690
            if n_consts <= 1 {
56
3582
                Err(RuleNotApplicable)
57
            } else {
58
108
                Ok(Reduction::pure(Sum(m, new_vec)))
59
            }
60
        }
61

            
62
2844
        Product(m, vec) => {
63
2844
            let mut acc = 1;
64
2844
            let mut n_consts = 0;
65
2844
            let mut new_vec: Vec<Expr> = Vec::new();
66
8982
            for expr in vec {
67
2700
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
68
2700
                    acc *= x;
69
2700
                    n_consts += 1;
70
3438
                } else {
71
3438
                    new_vec.push(expr);
72
3438
                }
73
            }
74

            
75
2844
            if n_consts == 0 {
76
198
                return Err(RuleNotApplicable);
77
2646
            }
78
2646

            
79
2646
            new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
80
2646
            let new_product = Product(m, new_vec);
81
2646

            
82
2646
            if acc == 0 {
83
                // if safe, 0 * exprs ~> 0
84
                // otherwise, just return 0* exprs
85
                if new_product.is_safe() {
86
                    Ok(Reduction::pure(Expr::Atomic(
87
                        Default::default(),
88
                        Atom::Literal(Int(0)),
89
                    )))
90
                } else {
91
                    Ok(Reduction::pure(new_product))
92
                }
93
2646
            } else if n_consts == 1 {
94
                // acc !=0, only one constant
95
2610
                Err(RuleNotApplicable)
96
            } else {
97
                // acc !=0, multiple constants found
98
36
                Ok(Reduction::pure(new_product))
99
            }
100
        }
101

            
102
234
        Min(m, e) => {
103
234
            let Some(vec) = e.unwrap_list() else {
104
126
                return Err(RuleNotApplicable);
105
            };
106
108
            let mut acc: Option<i32> = None;
107
108
            let mut n_consts = 0;
108
108
            let mut new_vec: Vec<Expr> = Vec::new();
109
324
            for expr in vec {
110
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
111
                    n_consts += 1;
112
                    acc = match acc {
113
                        Some(i) => {
114
                            if i > x {
115
                                Some(x)
116
                            } else {
117
                                Some(i)
118
                            }
119
                        }
120
                        None => Some(x),
121
                    };
122
216
                } else {
123
216
                    new_vec.push(expr);
124
216
                }
125
            }
126

            
127
108
            if let Some(i) = acc {
128
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
129
108
            }
130

            
131
108
            if n_consts <= 1 {
132
108
                Err(RuleNotApplicable)
133
            } else {
134
                Ok(Reduction::pure(Min(
135
                    m,
136
                    Box::new(into_matrix_expr![new_vec]),
137
                )))
138
            }
139
        }
140

            
141
198
        Max(m, e) => {
142
198
            let Some(vec) = e.unwrap_list() else {
143
144
                return Err(RuleNotApplicable);
144
            };
145

            
146
54
            let mut acc: Option<i32> = None;
147
54
            let mut n_consts = 0;
148
54
            let mut new_vec: Vec<Expr> = Vec::new();
149
162
            for expr in vec {
150
18
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
151
18
                    n_consts += 1;
152
18
                    acc = match acc {
153
                        Some(i) => {
154
                            if i < x {
155
                                Some(x)
156
                            } else {
157
                                Some(i)
158
                            }
159
                        }
160
18
                        None => Some(x),
161
                    };
162
90
                } else {
163
90
                    new_vec.push(expr);
164
90
                }
165
            }
166

            
167
54
            if let Some(i) = acc {
168
18
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
169
36
            }
170

            
171
54
            if n_consts <= 1 {
172
54
                Err(RuleNotApplicable)
173
            } else {
174
                Ok(Reduction::pure(Max(
175
                    m,
176
                    Box::new(into_matrix_expr![new_vec]),
177
                )))
178
            }
179
        }
180
9162
        Not(_, _) => Err(RuleNotApplicable),
181
24786
        Or(m, e) => {
182
24786
            let Some(terms) = e.unwrap_list() else {
183
21456
                return Err(RuleNotApplicable);
184
            };
185

            
186
3330
            let mut has_changed = false;
187
3330

            
188
3330
            // 2. boolean literals
189
3330
            let mut new_terms = vec![];
190
10638
            for expr in terms {
191
72
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
192
72
                    has_changed = true;
193
72

            
194
72
                    // true ~~> entire or is true
195
72
                    // false ~~> remove false from the or
196
72
                    if x {
197
36
                        return Ok(Reduction::pure(true.into()));
198
36
                    }
199
7272
                } else {
200
7272
                    new_terms.push(expr);
201
7272
                }
202
            }
203

            
204
            // 2. check pairwise tautologies.
205
3294
            if check_pairwise_or_tautologies(&new_terms) {
206
                return Ok(Reduction::pure(true.into()));
207
3294
            }
208
3294

            
209
3294
            // 3. empty or ~~> false
210
3294
            if new_terms.is_empty() {
211
                return Ok(Reduction::pure(false.into()));
212
3294
            }
213
3294

            
214
3294
            if !has_changed {
215
3258
                return Err(RuleNotApplicable);
216
36
            }
217
36

            
218
36
            Ok(Reduction::pure(Or(
219
36
                m,
220
36
                Box::new(into_matrix_expr![new_terms]),
221
36
            )))
222
        }
223
10476
        And(_, e) => {
224
10476
            let Some(vec) = e.unwrap_list() else {
225
216
                return Err(RuleNotApplicable);
226
            };
227
10260
            let mut new_vec: Vec<Expr> = Vec::new();
228
10260
            let mut has_const: bool = false;
229
29034
            for expr in vec {
230
810
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
231
810
                    has_const = true;
232
810
                    if !x {
233
36
                        return Ok(Reduction::pure(Atomic(
234
36
                            Default::default(),
235
36
                            Atom::Literal(Bool(false)),
236
36
                        )));
237
774
                    }
238
18000
                } else {
239
18000
                    new_vec.push(expr);
240
18000
                }
241
            }
242

            
243
10224
            if !has_const {
244
9450
                Err(RuleNotApplicable)
245
            } else {
246
774
                Ok(Reduction::pure(Expr::And(
247
774
                    Metadata::new(),
248
774
                    Box::new(into_matrix_expr![new_vec]),
249
774
                )))
250
            }
251
        }
252

            
253
        // similar to And, but booleans are returned wrapped in Root.
254
17316
        Root(_, vec) => {
255
17316
            // root([true]) / root([false]) are already evaluated
256
17316
            if vec.len() < 2 {
257
7488
                return Err(RuleNotApplicable);
258
9828
            }
259
9828

            
260
9828
            let mut new_vec: Vec<Expr> = Vec::new();
261
9828
            let mut has_const: bool = false;
262
89514
            for expr in vec {
263
108
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
264
108
                    has_const = true;
265
108
                    if !x {
266
36
                        return Ok(Reduction::pure(Root(
267
36
                            Metadata::new(),
268
36
                            vec![Atomic(Default::default(), Atom::Literal(Bool(false)))],
269
36
                        )));
270
72
                    }
271
79614
                } else {
272
79614
                    new_vec.push(expr);
273
79614
                }
274
            }
275

            
276
9792
            if !has_const {
277
9720
                Err(RuleNotApplicable)
278
            } else {
279
72
                if new_vec.is_empty() {
280
                    new_vec.push(true.into());
281
72
                }
282
72
                Ok(Reduction::pure(
283
72
                    expr.with_children_bi(VecDeque::from([new_vec])),
284
72
                ))
285
            }
286
        }
287
28278
        Imply(_m, x, y) => {
288
36
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = *x {
289
36
                if x {
290
                    // (true) -> y ~~> y
291
36
                    return Ok(Reduction::pure(*y));
292
                } else {
293
                    // (false) -> y ~~> true
294
                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
295
                }
296
28242
            };
297
28242

            
298
28242
            // reflexivity: p -> p ~> true
299
28242

            
300
28242
            // instead of checking syntactic equivalence of a possibly deep expression,
301
28242
            // let identical-CSE turn them into identical variables first. Then, check if they are
302
28242
            // identical variables.
303
28242

            
304
28242
            if x.identical_atom_to(y.as_ref()) {
305
36
                return Ok(Reduction::pure(true.into()));
306
28206
            }
307
28206

            
308
28206
            Err(RuleNotApplicable)
309
        }
310
48816
        Eq(_, _, _) => Err(RuleNotApplicable),
311
8586
        Neq(_, _, _) => Err(RuleNotApplicable),
312
1800
        Geq(_, _, _) => Err(RuleNotApplicable),
313
26424
        Leq(_, _, _) => Err(RuleNotApplicable),
314
54
        Gt(_, _, _) => Err(RuleNotApplicable),
315
2142
        Lt(_, _, _) => Err(RuleNotApplicable),
316
10080
        SafeDiv(_, _, _) => Err(RuleNotApplicable),
317
5076
        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
318
9756
        AllDiff(m, e) => {
319
9756
            let Some(vec) = e.unwrap_list() else {
320
9702
                return Err(RuleNotApplicable);
321
            };
322

            
323
54
            let mut consts: HashSet<i32> = HashSet::new();
324

            
325
            // check for duplicate constant values which would fail the constraint
326
216
            for expr in vec {
327
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
328
                    if !consts.insert(x) {
329
                        return Ok(Reduction::pure(Expr::Atomic(m, Atom::Literal(Bool(false)))));
330
                    }
331
162
                }
332
            }
333

            
334
            // nothing has changed
335
54
            Err(RuleNotApplicable)
336
        }
337
2088
        Neg(_, _) => Err(RuleNotApplicable),
338
2700
        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
339
3960
        UnsafeMod(_, _, _) => Err(RuleNotApplicable),
340
11970
        SafeMod(_, _, _) => Err(RuleNotApplicable),
341
180
        UnsafePow(_, _, _) => Err(RuleNotApplicable),
342
396
        SafePow(_, _, _) => Err(RuleNotApplicable),
343
90
        Minus(_, _, _) => Err(RuleNotApplicable),
344

            
345
        // As these are in a low level solver form, I'm assuming that these have already been
346
        // simplified and partially evaluated.
347
54
        FlatAllDiff(_, _) => Err(RuleNotApplicable),
348
306
        FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
349
11574
        FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
350
126
        FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
351
162
        FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
352
1728
        FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
353
1818
        FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
354
1206
        FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
355
288
        FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
356
180
        FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
357
2718
        MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
358
5994
        MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
359
306
        MinionPow(_, _, _, _) => Err(RuleNotApplicable),
360
15696
        MinionReify(_, _, _) => Err(RuleNotApplicable),
361
12834
        MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
362
    }
363
629532
}
364

            
365
/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
366
///
367
/// This applies the following rules:
368
///
369
/// ```text
370
/// (p->q) \/ (q->p) ~> true    [totality of implication]
371
/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
372
/// ```
373
///
374
3294
fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
375
3294
    // Collect terms that are structurally identical to the rule input.
376
3294
    // Then, try the rules on these terms, also checking the other conditions of the rules.
377
3294

            
378
3294
    // stores (p,q) in p -> q
379
3294
    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
380
3294

            
381
3294
    // stores (p,q) in p -> !q
382
3294
    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
383

            
384
7236
    for term in or_terms.iter() {
385
7236
        if let Expr::Imply(_, p, q) = term {
386
            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
387
            //
388
            // in general however, p -> !q would be in p_implies_q as (p,!q)
389
            if let Expr::Not(_, q_1) = q.as_ref() {
390
                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
391
            } else {
392
                p_implies_q.push((p.as_ref(), q.as_ref()));
393
            }
394
7236
        }
395
    }
396

            
397
    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
398
3294
    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
399
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
400
            return true;
401
        }
402
    }
403

            
404
    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
405
3294
    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
406
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
407
            return true;
408
        }
409
    }
410

            
411
3294
    false
412
3294
}