1
use std::collections::HashSet;
2

            
3
use conjure_macros::register_rule;
4

            
5
use crate::ast::{Expression as Expr, Factor, Literal::*};
6
use crate::rule_engine::{ApplicationResult, Reduction};
7
use crate::Model;
8

            
9
use super::utils::ToAuxVarOutput;
10

            
11
#[register_rule(("Base",9000))]
12
8222424
fn partial_evaluator(expr: &Expr, _: &Model) -> ApplicationResult {
13
    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
14
    use Expr::*;
15

            
16
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
17
    // rule infinitely!
18
    // This is why we always check whether we found a constant or not.
19
8222424
    match expr.clone() {
20
221
        Bubble(_, _, _) => Err(RuleNotApplicable),
21
6135147
        FactorE(_, _) => Err(RuleNotApplicable),
22
204
        Sum(m, vec) => {
23
204
            let mut acc = 0;
24
204
            let mut n_consts = 0;
25
204
            let mut new_vec: Vec<Expr> = Vec::new();
26
782
            for expr in vec {
27
272
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
28
272
                    acc += x;
29
272
                    n_consts += 1;
30
306
                } else {
31
306
                    new_vec.push(expr);
32
306
                }
33
            }
34
204
            if acc != 0 {
35
170
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
36
170
            }
37

            
38
204
            if n_consts <= 1 {
39
136
                Err(RuleNotApplicable)
40
            } else {
41
68
                Ok(Reduction::pure(Sum(m, new_vec)))
42
            }
43
        }
44
119
        Min(m, vec) => {
45
119
            let mut acc: Option<i32> = None;
46
119
            let mut n_consts = 0;
47
119
            let mut new_vec: Vec<Expr> = Vec::new();
48
357
            for expr in vec {
49
34
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
50
34
                    n_consts += 1;
51
34
                    acc = match acc {
52
17
                        Some(i) => {
53
17
                            if i > x {
54
                                Some(x)
55
                            } else {
56
17
                                Some(i)
57
                            }
58
                        }
59
17
                        None => Some(x),
60
                    };
61
204
                } else {
62
204
                    new_vec.push(expr);
63
204
                }
64
            }
65

            
66
119
            if let Some(i) = acc {
67
17
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(i))));
68
102
            }
69

            
70
119
            if n_consts <= 1 {
71
102
                Err(RuleNotApplicable)
72
            } else {
73
17
                Ok(Reduction::pure(Min(m, new_vec)))
74
            }
75
        }
76
68
        Max(m, vec) => {
77
68
            let mut acc: Option<i32> = None;
78
68
            let mut n_consts = 0;
79
68
            let mut new_vec: Vec<Expr> = Vec::new();
80
204
            for expr in vec {
81
17
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
82
17
                    n_consts += 1;
83
17
                    acc = match acc {
84
                        Some(i) => {
85
                            if i < x {
86
                                Some(x)
87
                            } else {
88
                                Some(i)
89
                            }
90
                        }
91
17
                        None => Some(x),
92
                    };
93
119
                } else {
94
119
                    new_vec.push(expr);
95
119
                }
96
            }
97

            
98
68
            if let Some(i) = acc {
99
17
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(i))));
100
51
            }
101

            
102
68
            if n_consts <= 1 {
103
68
                Err(RuleNotApplicable)
104
            } else {
105
                Ok(Reduction::pure(Max(m, new_vec)))
106
            }
107
        }
108
527
        Not(_, _) => Err(RuleNotApplicable),
109
428706
        Or(m, vec) => {
110
428706
            let mut new_vec: Vec<Expr> = Vec::new();
111
428706
            let mut has_const: bool = false;
112
2094332
            for expr in vec {
113
68
                if let Expr::FactorE(_, Factor::Literal(Bool(x))) = expr {
114
68
                    has_const = true;
115
68
                    if x {
116
34
                        return Ok(Reduction::pure(FactorE(
117
34
                            Default::default(),
118
34
                            Factor::Literal(Bool(true)),
119
34
                        )));
120
34
                    }
121
1665592
                } else {
122
1665592
                    new_vec.push(expr);
123
1665592
                }
124
            }
125

            
126
428672
            if !has_const {
127
428638
                Err(RuleNotApplicable)
128
            } else {
129
34
                Ok(Reduction::pure(Or(m, new_vec)))
130
            }
131
        }
132
19754
        And(m, vec) => {
133
19754
            let mut new_vec: Vec<Expr> = Vec::new();
134
19754
            let mut has_const: bool = false;
135
616896
            for expr in vec {
136
68
                if let Expr::FactorE(_, Factor::Literal(Bool(x))) = expr {
137
68
                    has_const = true;
138
68
                    if !x {
139
                        return Ok(Reduction::pure(FactorE(
140
                            Default::default(),
141
                            Factor::Literal(Bool(false)),
142
                        )));
143
68
                    }
144
597074
                } else {
145
597074
                    new_vec.push(expr);
146
597074
                }
147
            }
148

            
149
19754
            if !has_const {
150
19686
                Err(RuleNotApplicable)
151
            } else {
152
68
                Ok(Reduction::pure(And(m, new_vec)))
153
            }
154
        }
155
4539
        Eq(_, _, _) => Err(RuleNotApplicable),
156
1088
        Neq(_, _, _) => Err(RuleNotApplicable),
157
255
        Geq(_, _, _) => Err(RuleNotApplicable),
158
357
        Leq(_, _, _) => Err(RuleNotApplicable),
159
        Gt(_, _, _) => Err(RuleNotApplicable),
160
3094
        Lt(_, _, _) => Err(RuleNotApplicable),
161
51
        SafeDiv(_, _, _) => Err(RuleNotApplicable),
162
340
        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
163
3196
        SumEq(m, vec, eq) => {
164
3196
            let mut acc = 0;
165
3196
            let mut new_vec: Vec<Expr> = Vec::new();
166
3196
            let mut n_consts = 0;
167
12767
            for expr in vec {
168
85
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
169
85
                    n_consts += 1;
170
85
                    acc += x;
171
9486
                } else {
172
9486
                    new_vec.push(expr);
173
9486
                }
174
            }
175

            
176
3179
            if let Expr::FactorE(_, Factor::Literal(Int(x))) = *eq {
177
3179
                if acc != 0 {
178
                    // when rhs is a constant, move lhs constants to rhs
179
34
                    return Ok(Reduction::pure(SumEq(
180
34
                        m,
181
34
                        new_vec,
182
34
                        Box::new(Expr::FactorE(
183
34
                            Default::default(),
184
34
                            Factor::Literal(Int(x - acc)),
185
34
                        )),
186
34
                    )));
187
3145
                }
188
17
            } else if acc != 0 {
189
17
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
190
17
            }
191

            
192
3162
            if n_consts <= 1 {
193
3162
                Err(RuleNotApplicable)
194
            } else {
195
                Ok(Reduction::pure(SumEq(m, new_vec, eq)))
196
            }
197
        }
198
642583
        SumGeq(m, vec, geq) => {
199
642583
            let mut acc = 0;
200
642583
            let mut new_vec: Vec<Expr> = Vec::new();
201
642583
            let mut n_consts = 0;
202
2570128
            for expr in vec {
203
153
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
204
153
                    n_consts += 1;
205
153
                    acc += x;
206
1927392
                } else {
207
1927392
                    new_vec.push(expr);
208
1927392
                }
209
            }
210

            
211
642413
            if let Expr::FactorE(_, Factor::Literal(Int(x))) = *geq {
212
642413
                if acc != 0 {
213
                    // when rhs is a constant, move lhs constants to rhs
214
                    return Ok(Reduction::pure(SumGeq(
215
                        m,
216
                        new_vec,
217
                        Box::new(Expr::FactorE(
218
                            Default::default(),
219
                            Factor::Literal(Int(x - acc)),
220
                        )),
221
                    )));
222
642413
                }
223
170
            } else if acc != 0 {
224
153
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
225
153
            }
226

            
227
642583
            if n_consts <= 1 {
228
642583
                Err(RuleNotApplicable)
229
            } else {
230
                Ok(Reduction::pure(SumGeq(m, new_vec, geq)))
231
            }
232
        }
233
620330
        SumLeq(m, vec, leq) => {
234
620330
            let mut acc = 0;
235
620330
            let mut new_vec: Vec<Expr> = Vec::new();
236
620330
            let mut n_consts = 0;
237
2480912
            for expr in vec {
238
170
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
239
170
                    n_consts += 1;
240
170
                    acc += x;
241
1860412
                } else {
242
1860412
                    new_vec.push(expr);
243
1860412
                }
244
            }
245

            
246
620177
            if let Expr::FactorE(_, Factor::Literal(Int(x))) = *leq {
247
                // when rhs is a constant, move lhs constants to rhs
248
620177
                if acc != 0 {
249
34
                    return Ok(Reduction::pure(SumLeq(
250
34
                        m,
251
34
                        new_vec,
252
34
                        Box::new(Expr::FactorE(
253
34
                            Default::default(),
254
34
                            Factor::Literal(Int(x - acc)),
255
34
                        )),
256
34
                    )));
257
620143
                }
258
153
            } else if acc != 0 {
259
136
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
260
136
            }
261

            
262
620296
            if n_consts <= 1 {
263
620296
                Err(RuleNotApplicable)
264
            } else {
265
                Ok(Reduction::pure(SumLeq(m, new_vec, leq)))
266
            }
267
        }
268
476
        DivEq(_, _, _, _) => Err(RuleNotApplicable),
269
360485
        Ineq(_, _, _, _) => Err(RuleNotApplicable),
270
        AllDiff(m, vec) => {
271
            let mut consts: HashSet<i32> = HashSet::new();
272

            
273
            // check for duplicate constant values which would fail the constraint
274
            for expr in &vec {
275
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
276
                    if !consts.insert(*x) {
277
                        return Ok(Reduction::pure(Expr::FactorE(
278
                            m,
279
                            Factor::Literal(Bool(false)),
280
                        )));
281
                    }
282
                }
283
            }
284

            
285
            // nothing has changed
286
            Err(RuleNotApplicable)
287
        }
288

            
289
680
        WatchedLiteral(_, _, _) => Err(RuleNotApplicable),
290
        Reify(_, _, _) => Err(RuleNotApplicable),
291
204
        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
292
    }
293
8222424
}