1
use std::collections::{HashSet, VecDeque};
2

            
3
use conjure_macros::register_rule;
4
use itertools::iproduct;
5
use uniplate::Biplate;
6

            
7
use crate::ast::SymbolTable;
8
use crate::rule_engine::{ApplicationResult, Reduction};
9
use crate::{
10
    ast::{Atom, Expression as Expr, Literal as Lit, Literal::*},
11
    metadata::Metadata,
12
};
13

            
14
#[register_rule(("Base",9000))]
15
425731
fn partial_evaluator(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
16
    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
17
    use Expr::*;
18

            
19
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
20
    // rule infinitely!
21
    // This is why we always check whether we found a constant or not.
22
425731
    match expr.clone() {
23
2125
        Bubble(_, _, _) => Err(RuleNotApplicable),
24
207825
        Atomic(_, _) => Err(RuleNotApplicable),
25
918
        Abs(m, e) => match *e {
26
34
            Neg(_, inner) => Ok(Reduction::pure(Abs(m, inner))),
27
884
            _ => Err(RuleNotApplicable),
28
        },
29
3162
        Sum(m, vec) => {
30
3162
            let mut acc = 0;
31
3162
            let mut n_consts = 0;
32
3162
            let mut new_vec: Vec<Expr> = Vec::new();
33
12563
            for expr in vec {
34
1275
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
35
1275
                    acc += x;
36
1275
                    n_consts += 1;
37
8126
                } else {
38
8126
                    new_vec.push(expr);
39
8126
                }
40
            }
41
3162
            if acc != 0 {
42
1139
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
43
2023
            }
44

            
45
3162
            if n_consts <= 1 {
46
3094
                Err(RuleNotApplicable)
47
            } else {
48
68
                Ok(Reduction::pure(Sum(m, new_vec)))
49
            }
50
        }
51

            
52
2686
        Product(m, vec) => {
53
2686
            let mut acc = 1;
54
2686
            let mut n_consts = 0;
55
2686
            let mut new_vec: Vec<Expr> = Vec::new();
56
8483
            for expr in vec {
57
2550
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
58
2550
                    acc *= x;
59
2550
                    n_consts += 1;
60
3247
                } else {
61
3247
                    new_vec.push(expr);
62
3247
                }
63
            }
64

            
65
2686
            if n_consts == 0 {
66
187
                return Err(RuleNotApplicable);
67
2499
            }
68
2499

            
69
2499
            new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
70
2499
            let new_product = Product(m, new_vec);
71
2499

            
72
2499
            if acc == 0 {
73
                // if safe, 0 * exprs ~> 0
74
                // otherwise, just return 0* exprs
75
                if new_product.is_safe() {
76
                    Ok(Reduction::pure(Expr::Atomic(
77
                        Default::default(),
78
                        Atom::Literal(Int(0)),
79
                    )))
80
                } else {
81
                    Ok(Reduction::pure(new_product))
82
                }
83
2499
            } else if n_consts == 1 {
84
                // acc !=0, only one constant
85
2465
                Err(RuleNotApplicable)
86
            } else {
87
                // acc !=0, multiple constants found
88
34
                Ok(Reduction::pure(new_product))
89
            }
90
        }
91

            
92
102
        Min(m, vec) => {
93
102
            let mut acc: Option<i32> = None;
94
102
            let mut n_consts = 0;
95
102
            let mut new_vec: Vec<Expr> = Vec::new();
96
306
            for expr in vec {
97
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
98
                    n_consts += 1;
99
                    acc = match acc {
100
                        Some(i) => {
101
                            if i > x {
102
                                Some(x)
103
                            } else {
104
                                Some(i)
105
                            }
106
                        }
107
                        None => Some(x),
108
                    };
109
204
                } else {
110
204
                    new_vec.push(expr);
111
204
                }
112
            }
113

            
114
102
            if let Some(i) = acc {
115
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
116
102
            }
117

            
118
102
            if n_consts <= 1 {
119
102
                Err(RuleNotApplicable)
120
            } else {
121
                Ok(Reduction::pure(Min(m, new_vec)))
122
            }
123
        }
124

            
125
68
        Max(m, vec) => {
126
68
            let mut acc: Option<i32> = None;
127
68
            let mut n_consts = 0;
128
68
            let mut new_vec: Vec<Expr> = Vec::new();
129
204
            for expr in vec {
130
17
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
131
17
                    n_consts += 1;
132
17
                    acc = match acc {
133
                        Some(i) => {
134
                            if i < x {
135
                                Some(x)
136
                            } else {
137
                                Some(i)
138
                            }
139
                        }
140
17
                        None => Some(x),
141
                    };
142
119
                } else {
143
119
                    new_vec.push(expr);
144
119
                }
145
            }
146

            
147
68
            if let Some(i) = acc {
148
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
149
51
            }
150

            
151
68
            if n_consts <= 1 {
152
68
                Err(RuleNotApplicable)
153
            } else {
154
                Ok(Reduction::pure(Max(m, new_vec)))
155
            }
156
        }
157
5780
        Not(_, _) => Err(RuleNotApplicable),
158
16388
        Or(m, terms) => {
159
16388
            let mut has_changed = false;
160
16388

            
161
16388
            // 2. boolean literals
162
16388
            let mut new_terms = vec![];
163
62696
            for expr in terms {
164
34
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
165
34
                    has_changed = true;
166
34

            
167
34
                    // true ~~> entire or is true
168
34
                    // false ~~> remove false from the or
169
34
                    if x {
170
                        return Ok(Reduction::pure(true.into()));
171
34
                    }
172
46274
                } else {
173
46274
                    new_terms.push(expr);
174
46274
                }
175
            }
176

            
177
            // 2. check pairwise tautologies.
178
16388
            if check_pairwise_or_tautologies(&new_terms) {
179
68
                return Ok(Reduction::pure(true.into()));
180
16320
            }
181
16320

            
182
16320
            // 3. empty or ~~> false
183
16320
            if new_terms.is_empty() {
184
                return Ok(Reduction::pure(false.into()));
185
16320
            }
186
16320

            
187
16320
            if !has_changed {
188
16286
                return Err(RuleNotApplicable);
189
34
            }
190
34

            
191
34
            Ok(Reduction::pure(Or(m, new_terms)))
192
        }
193
8687
        And(_, vec) => {
194
8687
            let mut new_vec: Vec<Expr> = Vec::new();
195
8687
            let mut has_const: bool = false;
196
24820
            for expr in vec {
197
272
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
198
272
                    has_const = true;
199
272
                    if !x {
200
                        return Ok(Reduction::pure(Atomic(
201
                            Default::default(),
202
                            Atom::Literal(Bool(false)),
203
                        )));
204
272
                    }
205
15861
                } else {
206
15861
                    new_vec.push(expr);
207
15861
                }
208
            }
209

            
210
8687
            if !has_const {
211
8415
                Err(RuleNotApplicable)
212
            } else {
213
272
                Ok(Reduction::pure(
214
272
                    expr.with_children_bi(VecDeque::from([new_vec])),
215
272
                ))
216
            }
217
        }
218

            
219
        // similar to And, but booleans are returned wrapped in Root.
220
12631
        Root(_, vec) => {
221
12631
            // root([true]) / root([false]) are already evaluated
222
12631
            if vec.len() < 2 {
223
6647
                return Err(RuleNotApplicable);
224
5984
            }
225
5984

            
226
5984
            let mut new_vec: Vec<Expr> = Vec::new();
227
5984
            let mut has_const: bool = false;
228
53890
            for expr in vec {
229
119
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
230
119
                    has_const = true;
231
119
                    if !x {
232
                        return Ok(Reduction::pure(Root(
233
                            Metadata::new(),
234
                            vec![Atomic(Default::default(), Atom::Literal(Bool(false)))],
235
                        )));
236
119
                    }
237
47787
                } else {
238
47787
                    new_vec.push(expr);
239
47787
                }
240
            }
241

            
242
5984
            if !has_const {
243
5865
                Err(RuleNotApplicable)
244
            } else {
245
119
                if new_vec.is_empty() {
246
                    new_vec.push(true.into());
247
119
                }
248
119
                Ok(Reduction::pure(
249
119
                    expr.with_children_bi(VecDeque::from([new_vec])),
250
119
                ))
251
            }
252
        }
253
20791
        Imply(_m, x, y) => {
254
34
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = *x {
255
34
                if x {
256
                    // (true) -> y ~~> y
257
34
                    return Ok(Reduction::pure(*y));
258
                } else {
259
                    // (false) -> y ~~> true
260
                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
261
                }
262
20757
            };
263
20757

            
264
20757
            // reflexivity: p -> p ~> true
265
20757

            
266
20757
            // instead of checking syntactic equivalence of a possibly deep expression,
267
20757
            // let identical-CSE turn them into identical variables first. Then, check if they are
268
20757
            // identical variables.
269
20757

            
270
20757
            if x.identical_atom_to(y.as_ref()) {
271
34
                return Ok(Reduction::pure(true.into()));
272
20723
            }
273
20723

            
274
20723
            Err(RuleNotApplicable)
275
        }
276
38199
        Eq(_, _, _) => Err(RuleNotApplicable),
277
8109
        Neq(_, _, _) => Err(RuleNotApplicable),
278
1768
        Geq(_, _, _) => Err(RuleNotApplicable),
279
23732
        Leq(_, _, _) => Err(RuleNotApplicable),
280
34
        Gt(_, _, _) => Err(RuleNotApplicable),
281
1989
        Lt(_, _, _) => Err(RuleNotApplicable),
282
9248
        SafeDiv(_, _, _) => Err(RuleNotApplicable),
283
4794
        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
284
        AllDiff(m, vec) => {
285
            let mut consts: HashSet<i32> = HashSet::new();
286

            
287
            // check for duplicate constant values which would fail the constraint
288
            for expr in &vec {
289
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
290
                    if !consts.insert(*x) {
291
                        return Ok(Reduction::pure(Expr::Atomic(m, Atom::Literal(Bool(false)))));
292
                    }
293
                }
294
            }
295

            
296
            // nothing has changed
297
            Err(RuleNotApplicable)
298
        }
299
1972
        Neg(_, _) => Err(RuleNotApplicable),
300
2414
        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
301
3740
        UnsafeMod(_, _, _) => Err(RuleNotApplicable),
302
10829
        SafeMod(_, _, _) => Err(RuleNotApplicable),
303
170
        UnsafePow(_, _, _) => Err(RuleNotApplicable),
304
374
        SafePow(_, _, _) => Err(RuleNotApplicable),
305
85
        Minus(_, _, _) => Err(RuleNotApplicable),
306

            
307
        // As these are in a low level solver form, I'm assuming that these have already been
308
        // simplified and partially evaluated.
309
289
        FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
310
4352
        FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
311
119
        FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
312
153
        FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
313
1581
        FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
314
1632
        FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
315
340
        FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
316
272
        FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
317
170
        FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
318
1955
        MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
319
4590
        MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
320
289
        MinionPow(_, _, _, _) => Err(RuleNotApplicable),
321
12325
        MinionReify(_, _, _) => Err(RuleNotApplicable),
322
9044
        MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
323
    }
324
425731
}
325

            
326
/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
327
///
328
/// This applies the following rules:
329
///
330
/// ```text
331
/// (p->q) \/ (q->p) ~> true    [totality of implication]
332
/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
333
/// ```
334
///
335
16388
fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
336
16388
    // Collect terms that are structurally identical to the rule input.
337
16388
    // Then, try the rules on these terms, also checking the other conditions of the rules.
338
16388

            
339
16388
    // stores (p,q) in p -> q
340
16388
    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
341
16388

            
342
16388
    // stores (p,q) in p -> !q
343
16388
    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
344

            
345
46274
    for term in or_terms.iter() {
346
46274
        if let Expr::Imply(_, p, q) = term {
347
            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
348
            //
349
            // in general however, p -> !q would be in p_implies_q as (p,!q)
350
18360
            if let Expr::Not(_, q_1) = q.as_ref() {
351
3468
                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
352
14892
            } else {
353
14892
                p_implies_q.push((p.as_ref(), q.as_ref()));
354
14892
            }
355
27914
        }
356
    }
357

            
358
    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
359
26010
    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
360
26010
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
361
34
            return true;
362
25976
        }
363
    }
364

            
365
    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
366
16354
    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
367
3315
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
368
34
            return true;
369
3281
        }
370
    }
371

            
372
16320
    false
373
16388
}