1
use std::collections::HashSet;
2

            
3
use conjure_macros::register_rule;
4
use itertools::iproduct;
5

            
6
use crate::rule_engine::{ApplicationResult, Reduction};
7
use crate::Model;
8
use crate::{
9
    ast::{Atom, Expression as Expr, Literal as Lit, Literal::*},
10
    metadata::Metadata,
11
};
12

            
13
#[register_rule(("Base",9000))]
14
415684
fn partial_evaluator(expr: &Expr, _: &Model) -> ApplicationResult {
15
    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
16
    use Expr::*;
17

            
18
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
19
    // rule infinitely!
20
    // This is why we always check whether we found a constant or not.
21
415684
    match expr.clone() {
22
2125
        Bubble(_, _, _) => Err(RuleNotApplicable),
23
213367
        Atomic(_, _) => Err(RuleNotApplicable),
24
918
        Abs(m, e) => match *e {
25
34
            Neg(_, inner) => Ok(Reduction::pure(Abs(m, inner))),
26
884
            _ => Err(RuleNotApplicable),
27
        },
28
2975
        Sum(m, vec) => {
29
2975
            let mut acc = 0;
30
2975
            let mut n_consts = 0;
31
2975
            let mut new_vec: Vec<Expr> = Vec::new();
32
11968
            for expr in vec {
33
1156
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
34
1156
                    acc += x;
35
1156
                    n_consts += 1;
36
7837
                } else {
37
7837
                    new_vec.push(expr);
38
7837
                }
39
            }
40
2975
            if acc != 0 {
41
1020
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
42
1955
            }
43

            
44
2975
            if n_consts <= 1 {
45
2907
                Err(RuleNotApplicable)
46
            } else {
47
68
                Ok(Reduction::pure(Sum(m, new_vec)))
48
            }
49
        }
50

            
51
2686
        Product(m, vec) => {
52
2686
            let mut acc = 1;
53
2686
            let mut n_consts = 0;
54
2686
            let mut new_vec: Vec<Expr> = Vec::new();
55
8483
            for expr in vec {
56
2550
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
57
2550
                    acc *= x;
58
2550
                    n_consts += 1;
59
3247
                } else {
60
3247
                    new_vec.push(expr);
61
3247
                }
62
            }
63

            
64
2686
            if n_consts == 0 {
65
187
                return Err(RuleNotApplicable);
66
2499
            }
67
2499

            
68
2499
            new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
69
2499
            let new_product = Product(m, new_vec);
70
2499

            
71
2499
            if acc == 0 {
72
                // if safe, 0 * exprs ~> 0
73
                // otherwise, just return 0* exprs
74
                if new_product.is_safe() {
75
                    Ok(Reduction::pure(Expr::Atomic(
76
                        Default::default(),
77
                        Atom::Literal(Int(0)),
78
                    )))
79
                } else {
80
                    Ok(Reduction::pure(new_product))
81
                }
82
2499
            } else if n_consts == 1 {
83
                // acc !=0, only one constant
84
2465
                Err(RuleNotApplicable)
85
            } else {
86
                // acc !=0, multiple constants found
87
34
                Ok(Reduction::pure(new_product))
88
            }
89
        }
90

            
91
102
        Min(m, vec) => {
92
102
            let mut acc: Option<i32> = None;
93
102
            let mut n_consts = 0;
94
102
            let mut new_vec: Vec<Expr> = Vec::new();
95
306
            for expr in vec {
96
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
97
                    n_consts += 1;
98
                    acc = match acc {
99
                        Some(i) => {
100
                            if i > x {
101
                                Some(x)
102
                            } else {
103
                                Some(i)
104
                            }
105
                        }
106
                        None => Some(x),
107
                    };
108
204
                } else {
109
204
                    new_vec.push(expr);
110
204
                }
111
            }
112

            
113
102
            if let Some(i) = acc {
114
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
115
102
            }
116

            
117
102
            if n_consts <= 1 {
118
102
                Err(RuleNotApplicable)
119
            } else {
120
                Ok(Reduction::pure(Min(m, new_vec)))
121
            }
122
        }
123

            
124
68
        Max(m, vec) => {
125
68
            let mut acc: Option<i32> = None;
126
68
            let mut n_consts = 0;
127
68
            let mut new_vec: Vec<Expr> = Vec::new();
128
204
            for expr in vec {
129
17
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
130
17
                    n_consts += 1;
131
17
                    acc = match acc {
132
                        Some(i) => {
133
                            if i < x {
134
                                Some(x)
135
                            } else {
136
                                Some(i)
137
                            }
138
                        }
139
17
                        None => Some(x),
140
                    };
141
119
                } else {
142
119
                    new_vec.push(expr);
143
119
                }
144
            }
145

            
146
68
            if let Some(i) = acc {
147
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
148
51
            }
149

            
150
68
            if n_consts <= 1 {
151
68
                Err(RuleNotApplicable)
152
            } else {
153
                Ok(Reduction::pure(Max(m, new_vec)))
154
            }
155
        }
156
5712
        Not(_, _) => Err(RuleNotApplicable),
157
16388
        Or(m, terms) => {
158
16388
            let mut has_changed = false;
159
16388

            
160
16388
            // 2. boolean literals
161
16388
            let mut new_terms = vec![];
162
62696
            for expr in terms {
163
34
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
164
34
                    has_changed = true;
165
34

            
166
34
                    // true ~~> entire or is true
167
34
                    // false ~~> remove false from the or
168
34
                    if x {
169
                        return Ok(Reduction::pure(true.into()));
170
34
                    }
171
46274
                } else {
172
46274
                    new_terms.push(expr);
173
46274
                }
174
            }
175

            
176
            // 2. check pairwise tautologies.
177
16388
            if check_pairwise_or_tautologies(&new_terms) {
178
68
                return Ok(Reduction::pure(true.into()));
179
16320
            }
180
16320

            
181
16320
            // 3. empty or ~~> false
182
16320
            if new_terms.is_empty() {
183
                return Ok(Reduction::pure(false.into()));
184
16320
            }
185
16320

            
186
16320
            if !has_changed {
187
16286
                return Err(RuleNotApplicable);
188
34
            }
189
34

            
190
34
            Ok(Reduction::pure(Or(m, new_terms)))
191
        }
192
8347
        And(m, vec) => {
193
8347
            let mut new_vec: Vec<Expr> = Vec::new();
194
8347
            let mut has_const: bool = false;
195
23817
            for expr in vec {
196
272
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
197
272
                    has_const = true;
198
272
                    if !x {
199
                        return Ok(Reduction::pure(Atomic(
200
                            Default::default(),
201
                            Atom::Literal(Bool(false)),
202
                        )));
203
272
                    }
204
15198
                } else {
205
15198
                    new_vec.push(expr);
206
15198
                }
207
            }
208

            
209
8347
            if !has_const {
210
8075
                Err(RuleNotApplicable)
211
            } else {
212
272
                Ok(Reduction::pure(And(m, new_vec)))
213
            }
214
        }
215
20383
        Imply(_m, x, y) => {
216
            if let Expr::Atomic(_, Atom::Literal(Lit::Bool(x))) = *x {
217
                if x {
218
                    // (true) -> y ~~> y
219
                    return Ok(Reduction::pure(*y));
220
                } else {
221
                    // (false) -> y ~~> true
222
                    return Ok(Reduction::pure(Expr::Atomic(Metadata::new(), true.into())));
223
                }
224
20383
            };
225
20383

            
226
20383
            // reflexivity: p -> p ~> true
227
20383

            
228
20383
            // instead of checking syntactic equivalence of a possibly deep expression,
229
20383
            // let identical-CSE turn them into identical variables first. Then, check if they are
230
20383
            // identical variables.
231
20383

            
232
20383
            if x.identical_atom_to(y.as_ref()) {
233
34
                return Ok(Reduction::pure(true.into()));
234
20349
            }
235
20349

            
236
20349
            Err(RuleNotApplicable)
237
        }
238
38097
        Eq(_, _, _) => Err(RuleNotApplicable),
239
8109
        Neq(_, _, _) => Err(RuleNotApplicable),
240
1598
        Geq(_, _, _) => Err(RuleNotApplicable),
241
23341
        Leq(_, _, _) => Err(RuleNotApplicable),
242
34
        Gt(_, _, _) => Err(RuleNotApplicable),
243
1904
        Lt(_, _, _) => Err(RuleNotApplicable),
244
9248
        SafeDiv(_, _, _) => Err(RuleNotApplicable),
245
4794
        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
246
        AllDiff(m, vec) => {
247
            let mut consts: HashSet<i32> = HashSet::new();
248

            
249
            // check for duplicate constant values which would fail the constraint
250
            for expr in &vec {
251
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
252
                    if !consts.insert(*x) {
253
                        return Ok(Reduction::pure(Expr::Atomic(m, Atom::Literal(Bool(false)))));
254
                    }
255
                }
256
            }
257

            
258
            // nothing has changed
259
            Err(RuleNotApplicable)
260
        }
261
1972
        Neg(_, _) => Err(RuleNotApplicable),
262
2346
        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
263
3740
        UnsafeMod(_, _, _) => Err(RuleNotApplicable),
264
10829
        SafeMod(_, _, _) => Err(RuleNotApplicable),
265
170
        UnsafePow(_, _, _) => Err(RuleNotApplicable),
266
374
        SafePow(_, _, _) => Err(RuleNotApplicable),
267
85
        Minus(_, _, _) => Err(RuleNotApplicable),
268

            
269
        // As these are in a low level solver form, I'm assuming that these have already been
270
        // simplified and partially evaluated.
271
289
        FlatAbsEq(_, _, _) => Err(RuleNotApplicable),
272
4233
        FlatIneq(_, _, _, _) => Err(RuleNotApplicable),
273
119
        FlatMinusEq(_, _, _) => Err(RuleNotApplicable),
274
153
        FlatProductEq(_, _, _, _) => Err(RuleNotApplicable),
275
1258
        FlatSumLeq(_, _, _) => Err(RuleNotApplicable),
276
1309
        FlatSumGeq(_, _, _) => Err(RuleNotApplicable),
277
340
        FlatWatchedLiteral(_, _, _) => Err(RuleNotApplicable),
278
272
        FlatWeightedSumLeq(_, _, _, _) => Err(RuleNotApplicable),
279
170
        FlatWeightedSumGeq(_, _, _, _) => Err(RuleNotApplicable),
280
1955
        MinionDivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
281
4590
        MinionModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
282
289
        MinionPow(_, _, _, _) => Err(RuleNotApplicable),
283
12172
        MinionReify(_, _, _) => Err(RuleNotApplicable),
284
8823
        MinionReifyImply(_, _, _) => Err(RuleNotApplicable),
285
    }
286
415684
}
287

            
288
/// Checks for tautologies involving pairs of terms inside an or, returning true if one is found.
289
///
290
/// This applies the following rules:
291
///
292
/// ```text
293
/// (p->q) \/ (q->p) ~> true    [totality of implication]
294
/// (p->q) \/ (p-> !q) ~> true  [conditional excluded middle]
295
/// ```
296
///
297
16388
fn check_pairwise_or_tautologies(or_terms: &[Expr]) -> bool {
298
16388
    // Collect terms that are structurally identical to the rule input.
299
16388
    // Then, try the rules on these terms, also checking the other conditions of the rules.
300
16388

            
301
16388
    // stores (p,q) in p -> q
302
16388
    let mut p_implies_q: Vec<(&Expr, &Expr)> = vec![];
303
16388

            
304
16388
    // stores (p,q) in p -> !q
305
16388
    let mut p_implies_not_q: Vec<(&Expr, &Expr)> = vec![];
306

            
307
46274
    for term in or_terms.iter() {
308
46274
        if let Expr::Imply(_, p, q) = term {
309
            // we use identical_atom_to for equality later on, so these sets are mutually exclusive.
310
            //
311
            // in general however, p -> !q would be in p_implies_q as (p,!q)
312
18360
            if let Expr::Not(_, q_1) = q.as_ref() {
313
3468
                p_implies_not_q.push((p.as_ref(), q_1.as_ref()));
314
14892
            } else {
315
14892
                p_implies_q.push((p.as_ref(), q.as_ref()));
316
14892
            }
317
27914
        }
318
    }
319

            
320
    // `(p->q) \/ (q->p) ~> true    [totality of implication]`
321
26010
    for ((p1, q1), (q2, p2)) in iproduct!(p_implies_q.iter(), p_implies_q.iter()) {
322
26010
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
323
34
            return true;
324
25976
        }
325
    }
326

            
327
    // `(p->q) \/ (p-> !q) ~> true`    [conditional excluded middle]
328
16354
    for ((p1, q1), (p2, q2)) in iproduct!(p_implies_q.iter(), p_implies_not_q.iter()) {
329
3315
        if p1.identical_atom_to(p2) && q1.identical_atom_to(q2) {
330
34
            return true;
331
3281
        }
332
    }
333

            
334
16320
    false
335
16388
}