1
use std::collections::HashSet;
2

            
3
use conjure_core::ast::{Atom, Expression as Expr, Literal as Lit};
4
use conjure_core::metadata::Metadata;
5
use conjure_core::rule_engine::{
6
    register_rule, register_rule_set, ApplicationError, ApplicationResult, Reduction,
7
};
8
use conjure_core::Model;
9
use itertools::izip;
10

            
11
// Registers the "Constant" rule set, which the constant-evaluation rule below attaches to.
// NOTE(review): the `100` and `()` arguments are presumably the rule set's priority and its
// (empty) list of dependencies — confirm against the `register_rule_set!` documentation.
register_rule_set!("Constant", 100, ());
12

            
13
#[register_rule(("Constant", 9001))]
14
487458
fn apply_eval_constant(expr: &Expr, _: &Model) -> ApplicationResult {
15
252824
    if let Expr::Atomic(_, Atom::Literal(_)) = expr {
16
107491
        return Err(ApplicationError::RuleNotApplicable);
17
379967
    }
18
379967
    eval_constant(expr)
19
379967
        .map(|c| Reduction::pure(Expr::Atomic(Metadata::new(), Atom::Literal(c))))
20
379967
        .ok_or(ApplicationError::RuleNotApplicable)
21
487458
}
22

            
23
/// Simplify an expression to a constant if possible
24
/// Returns:
25
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
26
/// `Some(Const)` if the expression can be simplified to a constant
27
842779
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
28
455483
    match expr {
29
90017
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
30
365466
        Expr::Atomic(_, Atom::Reference(_c)) => None,
31
3587
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
32
59636
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
33
59636
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
34
59636
            .map(Lit::Bool),
35
13107
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
36
3043
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
37
34
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
38
39644
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
39
1768
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
40

            
41
7565
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
42

            
43
10387
        Expr::And(_, exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Lit::Bool),
44
21029
        Expr::Or(_, exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Lit::Bool),
45
38879
        Expr::Imply(_, box1, box2) => {
46
38879
            let a: &Atom = (&**box1).try_into().ok()?;
47
1411
            let b: &Atom = (&**box2).try_into().ok()?;
48

            
49
850
            let a: bool = a.try_into().ok()?;
50
            let b: bool = b.try_into().ok()?;
51

            
52
            if a {
53
                // true -> b ~> b
54
                Some(Lit::Bool(b))
55
            } else {
56
                // false -> b ~> true
57
                Some(Lit::Bool(true))
58
            }
59
        }
60

            
61
11169
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
62
5695
        Expr::Product(_, exprs) => vec_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int),
63

            
64
5865
        Expr::FlatIneq(_, a, b, c) => {
65
5865
            let a: i32 = a.try_into().ok()?;
66
408
            let b: i32 = b.try_into().ok()?;
67
            let c: i32 = c.try_into().ok()?;
68

            
69
            Some(Lit::Bool(a <= b + c))
70
        }
71

            
72
1309
        Expr::FlatSumGeq(_, exprs, a) => {
73
1326
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
74
1326
                let n: i32 = atom.try_into().ok()?;
75
17
                let acc = acc + n;
76
17
                Some(acc)
77
1326
            })?;
78

            
79
            Some(Lit::Bool(sum >= a.try_into().ok()?))
80
        }
81
2567
        Expr::FlatSumLeq(_, exprs, a) => {
82
2584
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
83
2584
                let n: i32 = atom.try_into().ok()?;
84
17
                let acc = acc + n;
85
17
                Some(acc)
86
2584
            })?;
87

            
88
            Some(Lit::Bool(sum >= a.try_into().ok()?))
89
        }
90
272
        Expr::Min(_, exprs) => {
91
272
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Lit::Int)
92
        }
93
136
        Expr::Max(_, exprs) => {
94
136
            opt_vec_op::<i32, i32>(|e| e.iter().max().copied(), exprs).map(Lit::Int)
95
        }
96
31383
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
97
47857
            if unwrap_expr::<i32>(b)? == 0 {
98
2
                return None;
99
34629
            }
100
34629
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
101
        }
102
42823
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
103
57902
            if unwrap_expr::<i32>(b)? == 0 {
104
                return None;
105
47804
            }
106
47804
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
107
47804
                .map(Lit::Int)
108
        }
109
3247
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
110
            // div always rounds down
111
3247
            let a: i32 = a.try_into().ok()?;
112
34
            let b: i32 = b.try_into().ok()?;
113
            let c: i32 = c.try_into().ok()?;
114

            
115
            if b == 0 {
116
                return None;
117
            }
118

            
119
            let a = a as f32;
120
            let b = b as f32;
121
            let div: i32 = (a / b).floor() as i32;
122
            Some(Lit::Bool(div == c))
123
        }
124
5797
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
125

            
126
12172
        Expr::MinionReify(_, a, b) => {
127
12172
            let result = eval_constant(a)?;
128

            
129
17
            let result: bool = result.try_into().ok()?;
130
17
            let b: bool = b.try_into().ok()?;
131

            
132
            Some(Lit::Bool(b == result))
133
        }
134

            
135
13464
        Expr::MinionReifyImply(_, a, b) => {
136
13464
            let result = eval_constant(a)?;
137

            
138
            let result: bool = result.try_into().ok()?;
139
            let b: bool = b.try_into().ok()?;
140

            
141
            if b {
142
                Some(Lit::Bool(result))
143
            } else {
144
                Some(Lit::Bool(true))
145
            }
146
        }
147
9962
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
148
            // From Savile Row. Same semantics as division.
149
            //
150
            //   a - (b * floor(a/b))
151
            //
152
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
153
            // down instead, not towards zero.
154

            
155
9962
            let a: i32 = a.try_into().ok()?;
156
34
            let b: i32 = b.try_into().ok()?;
157
            let c: i32 = c.try_into().ok()?;
158

            
159
            if b == 0 {
160
                return None;
161
            }
162

            
163
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
164
            Some(Lit::Bool(modulo == c))
165
        }
166

            
167
578
        Expr::MinionPow(_, a, b, c) => {
168
            // only available for positive a b c
169

            
170
578
            let a: i32 = a.try_into().ok()?;
171
            let b: i32 = b.try_into().ok()?;
172
            let c: i32 = c.try_into().ok()?;
173

            
174
            if a <= 0 {
175
                return None;
176
            }
177

            
178
            if b <= 0 {
179
                return None;
180
            }
181

            
182
            if c <= 0 {
183
                return None;
184
            }
185

            
186
            Some(Lit::Bool(a ^ b == c))
187
        }
188

            
189
34
        Expr::AllDiff(_, es) => {
190
34
            let mut lits: HashSet<Lit> = HashSet::new();
191
136
            for expr in es {
192
102
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
193
                    return None;
194
                };
195
102
                if lits.contains(x) {
196
                    return Some(Lit::Bool(false));
197
102
                } else {
198
102
                    lits.insert(x.clone());
199
102
                }
200
            }
201
34
            Some(Lit::Bool(true))
202
        }
203
714
        Expr::FlatWatchedLiteral(_, _, _) => None,
204
2346
        Expr::AuxDeclaration(_, _, _) => None,
205
3366
        Expr::Neg(_, a) => {
206
3366
            let a: &Atom = a.try_into().ok()?;
207
2312
            let a: i32 = a.try_into().ok()?;
208
731
            Some(Lit::Int(-a))
209
        }
210
221
        Expr::Minus(_, a, b) => {
211
221
            let a: &Atom = a.try_into().ok()?;
212
102
            let a: i32 = a.try_into().ok()?;
213

            
214
            let b: &Atom = b.try_into().ok()?;
215
            let b: i32 = b.try_into().ok()?;
216

            
217
            Some(Lit::Int(a - b))
218
        }
219
153
        Expr::FlatMinusEq(_, a, b) => {
220
153
            let a: i32 = a.try_into().ok()?;
221
            let b: i32 = b.try_into().ok()?;
222
            Some(Lit::Bool(a == -b))
223
        }
224
153
        Expr::FlatProductEq(_, a, b, c) => {
225
153
            let a: i32 = a.try_into().ok()?;
226
            let b: i32 = b.try_into().ok()?;
227
            let c: i32 = c.try_into().ok()?;
228
            Some(Lit::Bool(a * b == c))
229
        }
230
510
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
231
510
            let cs: Vec<i32> = cs
232
510
                .iter()
233
2244
                .map(|x| TryInto::<i32>::try_into(x).ok())
234
510
                .collect::<Option<Vec<i32>>>()?;
235
510
            let vs: Vec<i32> = vs
236
510
                .iter()
237
510
                .map(|x| TryInto::<i32>::try_into(x).ok())
238
510
                .collect::<Option<Vec<i32>>>()?;
239
            let total: i32 = total.try_into().ok()?;
240

            
241
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
242

            
243
            Some(Lit::Bool(sum <= total))
244
        }
245

            
246
170
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
247
170
            let cs: Vec<i32> = cs
248
170
                .iter()
249
510
                .map(|x| TryInto::<i32>::try_into(x).ok())
250
170
                .collect::<Option<Vec<i32>>>()?;
251
170
            let vs: Vec<i32> = vs
252
170
                .iter()
253
170
                .map(|x| TryInto::<i32>::try_into(x).ok())
254
170
                .collect::<Option<Vec<i32>>>()?;
255
            let total: i32 = total.try_into().ok()?;
256

            
257
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
258

            
259
            Some(Lit::Bool(sum >= total))
260
        }
261
289
        Expr::FlatAbsEq(_, x, y) => {
262
289
            let x: i32 = x.try_into().ok()?;
263
17
            let y: i32 = y.try_into().ok()?;
264

            
265
            Some(Lit::Bool(x == y.abs()))
266
        }
267

            
268
1785
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
269
2669
            let a: &Atom = a.try_into().ok()?;
270
1666
            let a: i32 = a.try_into().ok()?;
271

            
272
            let b: &Atom = b.try_into().ok()?;
273
            let b: i32 = b.try_into().ok()?;
274

            
275
            if (a != 0 || b != 0) && b >= 0 {
276
                Some(Lit::Int(a ^ b))
277
            } else {
278
                None
279
            }
280
        }
281
    }
282
842779
}
283

            
284
11152
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
285
11152
where
286
11152
    T: TryFrom<Lit>,
287
11152
{
288
11152
    let a = unwrap_expr::<T>(a)?;
289
    Some(f(a))
290
11152
}
291

            
292
264809
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
293
264809
where
294
264809
    T: TryFrom<Lit>,
295
264809
{
296
264809
    let a = unwrap_expr::<T>(a)?;
297
1139
    let b = unwrap_expr::<T>(b)?;
298
850
    Some(f(a, b))
299
264809
}
300

            
301
#[allow(dead_code)]
302
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
303
where
304
    T: TryFrom<Lit>,
305
{
306
    let a = unwrap_expr::<T>(a)?;
307
    let b = unwrap_expr::<T>(b)?;
308
    let c = unwrap_expr::<T>(c)?;
309
    Some(f(a, b, c))
310
}
311

            
312
48280
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
313
48280
where
314
48280
    T: TryFrom<Lit>,
315
48280
{
316
48280
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
317
663
    Some(f(a))
318
48280
}
319

            
320
408
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
321
408
where
322
408
    T: TryFrom<Lit>,
323
408
{
324
408
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
325
51
    f(a)
326
408
}
327

            
328
#[allow(dead_code)]
329
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
330
where
331
    T: TryFrom<Lit>,
332
{
333
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
334
    let b = unwrap_expr::<T>(b)?;
335
    Some(f(a, b))
336
}
337

            
338
436987
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
339
436987
    let c = eval_constant(expr)?;
340
90918
    TryInto::<T>::try_into(c).ok()
341
436987
}
342

            
343
#[cfg(test)]
mod tests {
    use crate::rules::eval_constant;
    use conjure_core::ast::{Atom, Expression, Literal};

    /// Builds a boxed integer literal expression, ready to be used as an operand.
    fn int_operand(value: i32) -> Box<Expression> {
        Box::new(Expression::Atomic(
            Default::default(),
            Atom::Literal(Literal::Int(value)),
        ))
    }

    /// An unsafe division by a literal zero must not be folded.
    #[test]
    fn div_by_zero() {
        let one_over_zero =
            Expression::UnsafeDiv(Default::default(), int_operand(1), int_operand(0));
        assert_eq!(eval_constant(&one_over_zero), None);
    }

    /// A safe division by a literal zero must not be folded either.
    #[test]
    fn safediv_by_zero() {
        let one_over_zero =
            Expression::SafeDiv(Default::default(), int_operand(1), int_operand(0));
        assert_eq!(eval_constant(&one_over_zero), None);
    }
}