1
use std::collections::HashSet;
2

            
3
use conjure_macros::register_rule;
4

            
5
use crate::ast::{Expression as Expr, Factor, Literal::*};
6
use crate::rule_engine::{ApplicationResult, Reduction};
7
use crate::Model;
8

            
9
use super::utils::ToAuxVarOutput;
10

            
11
#[register_rule(("Base",9000))]
12
7250595
fn partial_evaluator(expr: &Expr, _: &Model) -> ApplicationResult {
13
    use conjure_core::rule_engine::ApplicationError::RuleNotApplicable;
14
    use Expr::*;
15

            
16
    // NOTE: If nothing changes, we must return RuleNotApplicable, or the rewriter will try this
17
    // rule infinitely!
18
    // This is why we always check whether we found a constant or not.
19
7250595
    match expr.clone() {
20
90
        Bubble(_, _, _) => Err(RuleNotApplicable),
21
5412165
        FactorE(_, _) => Err(RuleNotApplicable),
22
180
        Sum(m, vec) => {
23
180
            let mut acc = 0;
24
180
            let mut n_consts = 0;
25
180
            let mut new_vec: Vec<Expr> = Vec::new();
26
690
            for expr in vec {
27
240
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
28
240
                    acc += x;
29
240
                    n_consts += 1;
30
270
                } else {
31
270
                    new_vec.push(expr);
32
270
                }
33
            }
34
180
            if acc != 0 {
35
150
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
36
150
            }
37

            
38
180
            if n_consts <= 1 {
39
120
                Err(RuleNotApplicable)
40
            } else {
41
60
                Ok(Reduction::pure(Sum(m, new_vec)))
42
            }
43
        }
44
105
        Min(m, vec) => {
45
105
            let mut acc: Option<i32> = None;
46
105
            let mut n_consts = 0;
47
105
            let mut new_vec: Vec<Expr> = Vec::new();
48
315
            for expr in vec {
49
30
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
50
30
                    n_consts += 1;
51
30
                    acc = match acc {
52
15
                        Some(i) => {
53
15
                            if i > x {
54
                                Some(x)
55
                            } else {
56
15
                                Some(i)
57
                            }
58
                        }
59
15
                        None => Some(x),
60
                    };
61
180
                } else {
62
180
                    new_vec.push(expr);
63
180
                }
64
            }
65

            
66
105
            if let Some(i) = acc {
67
15
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(i))));
68
90
            }
69

            
70
105
            if n_consts <= 1 {
71
90
                Err(RuleNotApplicable)
72
            } else {
73
15
                Ok(Reduction::pure(Min(m, new_vec)))
74
            }
75
        }
76
60
        Max(m, vec) => {
77
60
            let mut acc: Option<i32> = None;
78
60
            let mut n_consts = 0;
79
60
            let mut new_vec: Vec<Expr> = Vec::new();
80
180
            for expr in vec {
81
15
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
82
15
                    n_consts += 1;
83
15
                    acc = match acc {
84
                        Some(i) => {
85
                            if i < x {
86
                                Some(x)
87
                            } else {
88
                                Some(i)
89
                            }
90
                        }
91
15
                        None => Some(x),
92
                    };
93
105
                } else {
94
105
                    new_vec.push(expr);
95
105
                }
96
            }
97

            
98
60
            if let Some(i) = acc {
99
15
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(i))));
100
45
            }
101

            
102
60
            if n_consts <= 1 {
103
60
                Err(RuleNotApplicable)
104
            } else {
105
                Ok(Reduction::pure(Max(m, new_vec)))
106
            }
107
        }
108
60
        Not(_, _) => Err(RuleNotApplicable),
109
377910
        Or(m, vec) => {
110
377910
            let mut new_vec: Vec<Expr> = Vec::new();
111
377910
            let mut has_const: bool = false;
112
1846710
            for expr in vec {
113
60
                if let Expr::FactorE(_, Factor::Literal(Bool(x))) = expr {
114
60
                    has_const = true;
115
60
                    if x {
116
30
                        return Ok(Reduction::pure(FactorE(
117
30
                            Default::default(),
118
30
                            Factor::Literal(Bool(true)),
119
30
                        )));
120
30
                    }
121
1468770
                } else {
122
1468770
                    new_vec.push(expr);
123
1468770
                }
124
            }
125

            
126
377880
            if !has_const {
127
377850
                Err(RuleNotApplicable)
128
            } else {
129
30
                Ok(Reduction::pure(Or(m, new_vec)))
130
            }
131
        }
132
16755
        And(m, vec) => {
133
16755
            let mut new_vec: Vec<Expr> = Vec::new();
134
16755
            let mut has_const: bool = false;
135
541650
            for expr in vec {
136
60
                if let Expr::FactorE(_, Factor::Literal(Bool(x))) = expr {
137
60
                    has_const = true;
138
60
                    if !x {
139
                        return Ok(Reduction::pure(FactorE(
140
                            Default::default(),
141
                            Factor::Literal(Bool(false)),
142
                        )));
143
60
                    }
144
524835
                } else {
145
524835
                    new_vec.push(expr);
146
524835
                }
147
            }
148

            
149
16755
            if !has_const {
150
16695
                Err(RuleNotApplicable)
151
            } else {
152
60
                Ok(Reduction::pure(And(m, new_vec)))
153
            }
154
        }
155
3675
        Eq(_, _, _) => Err(RuleNotApplicable),
156
195
        Neq(_, _, _) => Err(RuleNotApplicable),
157
225
        Geq(_, _, _) => Err(RuleNotApplicable),
158
315
        Leq(_, _, _) => Err(RuleNotApplicable),
159
        Gt(_, _, _) => Err(RuleNotApplicable),
160
2730
        Lt(_, _, _) => Err(RuleNotApplicable),
161
        SafeDiv(_, _, _) => Err(RuleNotApplicable),
162
105
        UnsafeDiv(_, _, _) => Err(RuleNotApplicable),
163
2820
        SumEq(m, vec, eq) => {
164
2820
            let mut acc = 0;
165
2820
            let mut new_vec: Vec<Expr> = Vec::new();
166
2820
            let mut n_consts = 0;
167
11265
            for expr in vec {
168
75
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
169
75
                    n_consts += 1;
170
75
                    acc += x;
171
8370
                } else {
172
8370
                    new_vec.push(expr);
173
8370
                }
174
            }
175

            
176
2805
            if let Expr::FactorE(_, Factor::Literal(Int(x))) = *eq {
177
2805
                if acc != 0 {
178
                    // when rhs is a constant, move lhs constants to rhs
179
30
                    return Ok(Reduction::pure(SumEq(
180
30
                        m,
181
30
                        new_vec,
182
30
                        Box::new(Expr::FactorE(
183
30
                            Default::default(),
184
30
                            Factor::Literal(Int(x - acc)),
185
30
                        )),
186
30
                    )));
187
2775
                }
188
15
            } else if acc != 0 {
189
15
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
190
15
            }
191

            
192
2790
            if n_consts <= 1 {
193
2790
                Err(RuleNotApplicable)
194
            } else {
195
                Ok(Reduction::pure(SumEq(m, new_vec, eq)))
196
            }
197
        }
198
566985
        SumGeq(m, vec, geq) => {
199
566985
            let mut acc = 0;
200
566985
            let mut new_vec: Vec<Expr> = Vec::new();
201
566985
            let mut n_consts = 0;
202
2267760
            for expr in vec {
203
135
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
204
135
                    n_consts += 1;
205
135
                    acc += x;
206
1700640
                } else {
207
1700640
                    new_vec.push(expr);
208
1700640
                }
209
            }
210

            
211
566835
            if let Expr::FactorE(_, Factor::Literal(Int(x))) = *geq {
212
566835
                if acc != 0 {
213
                    // when rhs is a constant, move lhs constants to rhs
214
                    return Ok(Reduction::pure(SumGeq(
215
                        m,
216
                        new_vec,
217
                        Box::new(Expr::FactorE(
218
                            Default::default(),
219
                            Factor::Literal(Int(x - acc)),
220
                        )),
221
                    )));
222
566835
                }
223
150
            } else if acc != 0 {
224
135
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
225
135
            }
226

            
227
566985
            if n_consts <= 1 {
228
566985
                Err(RuleNotApplicable)
229
            } else {
230
                Ok(Reduction::pure(SumGeq(m, new_vec, geq)))
231
            }
232
        }
233
547350
        SumLeq(m, vec, leq) => {
234
547350
            let mut acc = 0;
235
547350
            let mut new_vec: Vec<Expr> = Vec::new();
236
547350
            let mut n_consts = 0;
237
2189040
            for expr in vec {
238
150
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
239
150
                    n_consts += 1;
240
150
                    acc += x;
241
1641540
                } else {
242
1641540
                    new_vec.push(expr);
243
1641540
                }
244
            }
245

            
246
547215
            if let Expr::FactorE(_, Factor::Literal(Int(x))) = *leq {
247
                // when rhs is a constant, move lhs constants to rhs
248
547215
                if acc != 0 {
249
30
                    return Ok(Reduction::pure(SumLeq(
250
30
                        m,
251
30
                        new_vec,
252
30
                        Box::new(Expr::FactorE(
253
30
                            Default::default(),
254
30
                            Factor::Literal(Int(x - acc)),
255
30
                        )),
256
30
                    )));
257
547185
                }
258
135
            } else if acc != 0 {
259
120
                new_vec.push(Expr::FactorE(Default::default(), Factor::Literal(Int(acc))));
260
120
            }
261

            
262
547320
            if n_consts <= 1 {
263
547320
                Err(RuleNotApplicable)
264
            } else {
265
                Ok(Reduction::pure(SumLeq(m, new_vec, leq)))
266
            }
267
        }
268
90
        DivEq(_, _, _, _) => Err(RuleNotApplicable),
269
318075
        Ineq(_, _, _, _) => Err(RuleNotApplicable),
270
        AllDiff(m, vec) => {
271
            let mut consts: HashSet<i32> = HashSet::new();
272

            
273
            // check for duplicate constant values which would fail the constraint
274
            for expr in &vec {
275
                if let Expr::FactorE(_, Factor::Literal(Int(x))) = expr {
276
                    if !consts.insert(*x) {
277
                        return Ok(Reduction::pure(Expr::FactorE(
278
                            m,
279
                            Factor::Literal(Bool(false)),
280
                        )));
281
                    }
282
                }
283
            }
284

            
285
            // nothing has changed
286
            Err(RuleNotApplicable)
287
        }
288

            
289
600
        WatchedLiteral(_, _, _) => Err(RuleNotApplicable),
290
105
        Reify(_, _, _) => Err(RuleNotApplicable),
291
        AuxDeclaration(_, _, _) => Err(RuleNotApplicable),
292
    }
293
7250595
}