1
use std::collections::HashSet;
2

            
3
use conjure_macros::register_rule;
4

            
5
use crate::ast::{Atom, Expression as Expr, Literal::*};
6
use crate::rule_engine::{ApplicationResult, Reduction};
7
use crate::Model;
8

            
9
use super::utils::ToAuxVarOutput;
10

            
11
#[register_rule(("Base",9000))]
12
8229411
fn partial_evaluator(expr: &Expr, _: &Model) -> ApplicationResult {
13
    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
14
    use Expr::*;
15

            
16
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
17
    // rule infinitely!
18
    // This is why we always check whether we found a constant or not.
19
8229411
    match expr.clone() {
20
442
        Bubble(_, _, _) => Err(RuleNotApplicable),
21
6137153
        Atomic(_, _) => Err(RuleNotApplicable),
22
204
        Sum(m, vec) => {
23
204
            let mut acc = 0;
24
204
            let mut n_consts = 0;
25
204
            let mut new_vec: Vec<Expr> = Vec::new();
26
782
            for expr in vec {
27
272
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
28
272
                    acc += x;
29
272
                    n_consts += 1;
30
306
                } else {
31
306
                    new_vec.push(expr);
32
306
                }
33
            }
34
204
            if acc != 0 {
35
170
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
36
170
            }
37

            
38
204
            if n_consts <= 1 {
39
136
                Err(RuleNotApplicable)
40
            } else {
41
68
                Ok(Reduction::pure(Sum(m, new_vec)))
42
            }
43
        }
44

            
45
119
        Min(m, vec) => {
46
119
            let mut acc: Option<i32> = None;
47
119
            let mut n_consts = 0;
48
119
            let mut new_vec: Vec<Expr> = Vec::new();
49
357
            for expr in vec {
50
34
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
51
34
                    n_consts += 1;
52
34
                    acc = match acc {
53
17
                        Some(i) => {
54
17
                            if i > x {
55
                                Some(x)
56
                            } else {
57
17
                                Some(i)
58
                            }
59
                        }
60
17
                        None => Some(x),
61
                    };
62
204
                } else {
63
204
                    new_vec.push(expr);
64
204
                }
65
            }
66

            
67
119
            if let Some(i) = acc {
68
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
69
102
            }
70

            
71
119
            if n_consts <= 1 {
72
102
                Err(RuleNotApplicable)
73
            } else {
74
17
                Ok(Reduction::pure(Min(m, new_vec)))
75
            }
76
        }
77
68
        Max(m, vec) => {
78
68
            let mut acc: Option<i32> = None;
79
68
            let mut n_consts = 0;
80
68
            let mut new_vec: Vec<Expr> = Vec::new();
81
204
            for expr in vec {
82
17
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
83
17
                    n_consts += 1;
84
17
                    acc = match acc {
85
                        Some(i) => {
86
                            if i < x {
87
                                Some(x)
88
                            } else {
89
                                Some(i)
90
                            }
91
                        }
92
17
                        None => Some(x),
93
                    };
94
119
                } else {
95
119
                    new_vec.push(expr);
96
119
                }
97
            }
98

            
99
68
            if let Some(i) = acc {
100
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
101
51
            }
102

            
103
68
            if n_consts <= 1 {
104
68
                Err(RuleNotApplicable)
105
            } else {
106
                Ok(Reduction::pure(Max(m, new_vec)))
107
            }
108
        }
109
1020
        Not(_, _) => Err(RuleNotApplicable),
110
429114
        Or(m, vec) => {
111
429114
            let mut new_vec: Vec<Expr> = Vec::new();
112
429114
            let mut has_const: bool = false;
113
2095726
            for expr in vec {
114
68
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
115
68
                    has_const = true;
116
68
                    if x {
117
34
                        return Ok(Reduction::pure(Atomic(
118
34
                            Default::default(),
119
34
                            Atom::Literal(Bool(true)),
120
34
                        )));
121
34
                    }
122
1666578
                } else {
123
1666578
                    new_vec.push(expr);
124
1666578
                }
125
            }
126

            
127
429080
            if !has_const {
128
429046
                Err(RuleNotApplicable)
129
            } else {
130
34
                Ok(Reduction::pure(Or(m, new_vec)))
131
            }
132
        }
133
20842
        And(m, vec) => {
134
20842
            let mut new_vec: Vec<Expr> = Vec::new();
135
20842
            let mut has_const: bool = false;
136
620891
            for expr in vec {
137
68
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
138
68
                    has_const = true;
139
68
                    if !x {
140
                        return Ok(Reduction::pure(Atomic(
141
                            Default::default(),
142
                            Atom::Literal(Bool(false)),
143
                        )));
144
68
                    }
145
599981
                } else {
146
599981
                    new_vec.push(expr);
147
599981
                }
148
            }
149

            
150
20842
            if !has_const {
151
20774
                Err(RuleNotApplicable)
152
            } else {
153
68
                Ok(Reduction::pure(And(m, new_vec)))
154
            }
155
        }
156
5168
        Eq(_, _, _) => Err(RuleNotApplicable),
157
2108
        Neq(_, _, _) => Err(RuleNotApplicable),
158
272
        Geq(_, _, _) => Err(RuleNotApplicable),
159
357
        Leq(_, _, _) => Err(RuleNotApplicable),
160
        Gt(_, _, _) => Err(RuleNotApplicable),
161
3094
        Lt(_, _, _) => Err(RuleNotApplicable),
162
51
        SafeDiv(_, _, _) => Err(RuleNotApplicable),
163
340
        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
164
3196
        SumEq(m, vec, eq) => {
165
3196
            let mut acc = 0;
166
3196
            let mut new_vec: Vec<Expr> = Vec::new();
167
3196
            let mut n_consts = 0;
168
12767
            for expr in vec {
169
85
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
170
85
                    n_consts += 1;
171
85
                    acc += x;
172
9486
                } else {
173
9486
                    new_vec.push(expr);
174
9486
                }
175
            }
176

            
177
3179
            if let Expr::Atomic(_, Atom::Literal(Int(x))) = *eq {
178
3179
                if acc != 0 {
179
                    // when rhs is a constant, move lhs constants to rhs
180
34
                    return Ok(Reduction::pure(SumEq(
181
34
                        m,
182
34
                        new_vec,
183
34
                        Box::new(Expr::Atomic(
184
34
                            Default::default(),
185
34
                            Atom::Literal(Int(x - acc)),
186
34
                        )),
187
34
                    )));
188
3145
                }
189
17
            } else if acc != 0 {
190
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
191
17
            }
192

            
193
3162
            if n_consts <= 1 {
194
3162
                Err(RuleNotApplicable)
195
            } else {
196
                Ok(Reduction::pure(SumEq(m, new_vec, eq)))
197
            }
198
        }
199
642583
        SumGeq(m, vec, geq) => {
200
642583
            let mut acc = 0;
201
642583
            let mut new_vec: Vec<Expr> = Vec::new();
202
642583
            let mut n_consts = 0;
203
2570128
            for expr in vec {
204
153
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
205
153
                    n_consts += 1;
206
153
                    acc += x;
207
1927392
                } else {
208
1927392
                    new_vec.push(expr);
209
1927392
                }
210
            }
211

            
212
642413
            if let Expr::Atomic(_, Atom::Literal(Int(x))) = *geq {
213
642413
                if acc != 0 {
214
                    // when rhs is a constant, move lhs constants to rhs
215
                    return Ok(Reduction::pure(SumGeq(
216
                        m,
217
                        new_vec,
218
                        Box::new(Expr::Atomic(
219
                            Default::default(),
220
                            Atom::Literal(Int(x - acc)),
221
                        )),
222
                    )));
223
642413
                }
224
170
            } else if acc != 0 {
225
153
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
226
153
            }
227

            
228
642583
            if n_consts <= 1 {
229
642583
                Err(RuleNotApplicable)
230
            } else {
231
                Ok(Reduction::pure(SumGeq(m, new_vec, geq)))
232
            }
233
        }
234
620330
        SumLeq(m, vec, leq) => {
235
620330
            let mut acc = 0;
236
620330
            let mut new_vec: Vec<Expr> = Vec::new();
237
620330
            let mut n_consts = 0;
238
2480912
            for expr in vec {
239
170
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
240
170
                    n_consts += 1;
241
170
                    acc += x;
242
1860412
                } else {
243
1860412
                    new_vec.push(expr);
244
1860412
                }
245
            }
246

            
247
620177
            if let Expr::Atomic(_, Atom::Literal(Int(x))) = *leq {
248
                // when rhs is a constant, move lhs constants to rhs
249
620177
                if acc != 0 {
250
34
                    return Ok(Reduction::pure(SumLeq(
251
34
                        m,
252
34
                        new_vec,
253
34
                        Box::new(Expr::Atomic(
254
34
                            Default::default(),
255
34
                            Atom::Literal(Int(x - acc)),
256
34
                        )),
257
34
                    )));
258
620143
                }
259
153
            } else if acc != 0 {
260
136
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
261
136
            }
262

            
263
620296
            if n_consts <= 1 {
264
620296
                Err(RuleNotApplicable)
265
            } else {
266
                Ok(Reduction::pure(SumLeq(m, new_vec, leq)))
267
            }
268
        }
269
476
        DivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
270
360519
        Ineq(_, _, _, _) => Err(RuleNotApplicable),
271
        AllDiff(m, vec) => {
272
            let mut consts: HashSet<i32> = HashSet::new();
273

            
274
            // check for duplicate constant values which would fail the constraint
275
            for expr in &vec {
276
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
277
                    if !consts.insert(*x) {
278
                        return Ok(Reduction::pure(Expr::Atomic(m, Atom::Literal(Bool(false)))));
279
                    }
280
                }
281
            }
282

            
283
            // nothing has changed
284
            Err(RuleNotApplicable)
285
        }
286
680
        WatchedLiteral(_, _, _) => Err(RuleNotApplicable),
287
        Reify(_, _, _) => Err(RuleNotApplicable),
288
408
        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
289
340
        UnsafeMod(_, a, b) => Err(RuleNotApplicable),
290
51
        SafeMod(_, a, b) => Err(RuleNotApplicable),
291
476
        ModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
292
    }
293
8229411
}