1
use std::collections::HashSet;
2

            
3
use conjure_macros::register_rule;
4

            
5
use crate::ast::{Atom, Expression as Expr, Literal::*};
6
use crate::rule_engine::{ApplicationResult, Reduction};
7
use crate::Model;
8

            
9
#[register_rule(("Base",9000))]
10
8229411
fn partial_evaluator(expr: &Expr, _: &Model) -> ApplicationResult {
11
    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
12
    use Expr::*;
13

            
14
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
15
    // rule infinitely!
16
    // This is why we always check whether we found a constant or not.
17
8229411
    match expr.clone() {
18
442
        Bubble(_, _, _) => Err(RuleNotApplicable),
19
6137153
        Atomic(_, _) => Err(RuleNotApplicable),
20
204
        Sum(m, vec) => {
21
204
            let mut acc = 0;
22
204
            let mut n_consts = 0;
23
204
            let mut new_vec: Vec<Expr> = Vec::new();
24
782
            for expr in vec {
25
272
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
26
272
                    acc += x;
27
272
                    n_consts += 1;
28
306
                } else {
29
306
                    new_vec.push(expr);
30
306
                }
31
            }
32
204
            if acc != 0 {
33
170
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
34
170
            }
35

            
36
204
            if n_consts <= 1 {
37
136
                Err(RuleNotApplicable)
38
            } else {
39
68
                Ok(Reduction::pure(Sum(m, new_vec)))
40
            }
41
        }
42

            
43
119
        Min(m, vec) => {
44
119
            let mut acc: Option<i32> = None;
45
119
            let mut n_consts = 0;
46
119
            let mut new_vec: Vec<Expr> = Vec::new();
47
357
            for expr in vec {
48
34
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
49
34
                    n_consts += 1;
50
34
                    acc = match acc {
51
17
                        Some(i) => {
52
17
                            if i > x {
53
                                Some(x)
54
                            } else {
55
17
                                Some(i)
56
                            }
57
                        }
58
17
                        None => Some(x),
59
                    };
60
204
                } else {
61
204
                    new_vec.push(expr);
62
204
                }
63
            }
64

            
65
119
            if let Some(i) = acc {
66
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
67
102
            }
68

            
69
119
            if n_consts <= 1 {
70
102
                Err(RuleNotApplicable)
71
            } else {
72
17
                Ok(Reduction::pure(Min(m, new_vec)))
73
            }
74
        }
75
68
        Max(m, vec) => {
76
68
            let mut acc: Option<i32> = None;
77
68
            let mut n_consts = 0;
78
68
            let mut new_vec: Vec<Expr> = Vec::new();
79
204
            for expr in vec {
80
17
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
81
17
                    n_consts += 1;
82
17
                    acc = match acc {
83
                        Some(i) => {
84
                            if i < x {
85
                                Some(x)
86
                            } else {
87
                                Some(i)
88
                            }
89
                        }
90
17
                        None => Some(x),
91
                    };
92
119
                } else {
93
119
                    new_vec.push(expr);
94
119
                }
95
            }
96

            
97
68
            if let Some(i) = acc {
98
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(i))));
99
51
            }
100

            
101
68
            if n_consts <= 1 {
102
68
                Err(RuleNotApplicable)
103
            } else {
104
                Ok(Reduction::pure(Max(m, new_vec)))
105
            }
106
        }
107
1020
        Not(_, _) => Err(RuleNotApplicable),
108
429114
        Or(m, vec) => {
109
429114
            let mut new_vec: Vec<Expr> = Vec::new();
110
429114
            let mut has_const: bool = false;
111
2095726
            for expr in vec {
112
68
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
113
68
                    has_const = true;
114
68
                    if x {
115
34
                        return Ok(Reduction::pure(Atomic(
116
34
                            Default::default(),
117
34
                            Atom::Literal(Bool(true)),
118
34
                        )));
119
34
                    }
120
1666578
                } else {
121
1666578
                    new_vec.push(expr);
122
1666578
                }
123
            }
124

            
125
429080
            if !has_const {
126
429046
                Err(RuleNotApplicable)
127
            } else {
128
34
                Ok(Reduction::pure(Or(m, new_vec)))
129
            }
130
        }
131
20842
        And(m, vec) => {
132
20842
            let mut new_vec: Vec<Expr> = Vec::new();
133
20842
            let mut has_const: bool = false;
134
620891
            for expr in vec {
135
68
                if let Expr::Atomic(_, Atom::Literal(Bool(x))) = expr {
136
68
                    has_const = true;
137
68
                    if !x {
138
                        return Ok(Reduction::pure(Atomic(
139
                            Default::default(),
140
                            Atom::Literal(Bool(false)),
141
                        )));
142
68
                    }
143
599981
                } else {
144
599981
                    new_vec.push(expr);
145
599981
                }
146
            }
147

            
148
20842
            if !has_const {
149
20774
                Err(RuleNotApplicable)
150
            } else {
151
68
                Ok(Reduction::pure(And(m, new_vec)))
152
            }
153
        }
154
5168
        Eq(_, _, _) => Err(RuleNotApplicable),
155
2108
        Neq(_, _, _) => Err(RuleNotApplicable),
156
272
        Geq(_, _, _) => Err(RuleNotApplicable),
157
357
        Leq(_, _, _) => Err(RuleNotApplicable),
158
        Gt(_, _, _) => Err(RuleNotApplicable),
159
3094
        Lt(_, _, _) => Err(RuleNotApplicable),
160
51
        SafeDiv(_, _, _) => Err(RuleNotApplicable),
161
340
        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
162
3196
        SumEq(m, vec, eq) => {
163
3196
            let mut acc = 0;
164
3196
            let mut new_vec: Vec<Expr> = Vec::new();
165
3196
            let mut n_consts = 0;
166
12767
            for expr in vec {
167
85
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
168
85
                    n_consts += 1;
169
85
                    acc += x;
170
9486
                } else {
171
9486
                    new_vec.push(expr);
172
9486
                }
173
            }
174

            
175
3179
            if let Expr::Atomic(_, Atom::Literal(Int(x))) = *eq {
176
3179
                if acc != 0 {
177
                    // when rhs is a constant, move lhs constants to rhs
178
34
                    return Ok(Reduction::pure(SumEq(
179
34
                        m,
180
34
                        new_vec,
181
34
                        Box::new(Expr::Atomic(
182
34
                            Default::default(),
183
34
                            Atom::Literal(Int(x - acc)),
184
34
                        )),
185
34
                    )));
186
3145
                }
187
17
            } else if acc != 0 {
188
17
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
189
17
            }
190

            
191
3162
            if n_consts <= 1 {
192
3162
                Err(RuleNotApplicable)
193
            } else {
194
                Ok(Reduction::pure(SumEq(m, new_vec, eq)))
195
            }
196
        }
197
642583
        SumGeq(m, vec, geq) => {
198
642583
            let mut acc = 0;
199
642583
            let mut new_vec: Vec<Expr> = Vec::new();
200
642583
            let mut n_consts = 0;
201
2570128
            for expr in vec {
202
153
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
203
153
                    n_consts += 1;
204
153
                    acc += x;
205
1927392
                } else {
206
1927392
                    new_vec.push(expr);
207
1927392
                }
208
            }
209

            
210
642413
            if let Expr::Atomic(_, Atom::Literal(Int(x))) = *geq {
211
642413
                if acc != 0 {
212
                    // when rhs is a constant, move lhs constants to rhs
213
                    return Ok(Reduction::pure(SumGeq(
214
                        m,
215
                        new_vec,
216
                        Box::new(Expr::Atomic(
217
                            Default::default(),
218
                            Atom::Literal(Int(x - acc)),
219
                        )),
220
                    )));
221
642413
                }
222
170
            } else if acc != 0 {
223
153
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
224
153
            }
225

            
226
642583
            if n_consts <= 1 {
227
642583
                Err(RuleNotApplicable)
228
            } else {
229
                Ok(Reduction::pure(SumGeq(m, new_vec, geq)))
230
            }
231
        }
232
620330
        SumLeq(m, vec, leq) => {
233
620330
            let mut acc = 0;
234
620330
            let mut new_vec: Vec<Expr> = Vec::new();
235
620330
            let mut n_consts = 0;
236
2480912
            for expr in vec {
237
170
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
238
170
                    n_consts += 1;
239
170
                    acc += x;
240
1860412
                } else {
241
1860412
                    new_vec.push(expr);
242
1860412
                }
243
            }
244

            
245
620177
            if let Expr::Atomic(_, Atom::Literal(Int(x))) = *leq {
246
                // when rhs is a constant, move lhs constants to rhs
247
620177
                if acc != 0 {
248
34
                    return Ok(Reduction::pure(SumLeq(
249
34
                        m,
250
34
                        new_vec,
251
34
                        Box::new(Expr::Atomic(
252
34
                            Default::default(),
253
34
                            Atom::Literal(Int(x - acc)),
254
34
                        )),
255
34
                    )));
256
620143
                }
257
153
            } else if acc != 0 {
258
136
                new_vec.push(Expr::Atomic(Default::default(), Atom::Literal(Int(acc))));
259
136
            }
260

            
261
620296
            if n_consts <= 1 {
262
620296
                Err(RuleNotApplicable)
263
            } else {
264
                Ok(Reduction::pure(SumLeq(m, new_vec, leq)))
265
            }
266
        }
267
476
        DivEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
268
360519
        Ineq(_, _, _, _) => Err(RuleNotApplicable),
269
        AllDiff(m, vec) => {
270
            let mut consts: HashSet<i32> = HashSet::new();
271

            
272
            // check for duplicate constant values which would fail the constraint
273
            for expr in &vec {
274
                if let Expr::Atomic(_, Atom::Literal(Int(x))) = expr {
275
                    if !consts.insert(*x) {
276
                        return Ok(Reduction::pure(Expr::Atomic(m, Atom::Literal(Bool(false)))));
277
                    }
278
                }
279
            }
280

            
281
            // nothing has changed
282
            Err(RuleNotApplicable)
283
        }
284
680
        WatchedLiteral(_, _, _) => Err(RuleNotApplicable),
285
        Reify(_, _, _) => Err(RuleNotApplicable),
286
408
        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
287
340
        UnsafeMod(_, _, _) => Err(RuleNotApplicable),
288
51
        SafeMod(_, _, _) => Err(RuleNotApplicable),
289
476
        ModuloEqUndefZero(_, _, _, _) => Err(RuleNotApplicable),
290
    }
291
8229411
}