1
use std::collections::HashSet;
2

            
3
use conjure_core::ast::{Atom, Expression as Expr, Literal as Lit};
4
use conjure_core::metadata::Metadata;
5
use conjure_core::rule_engine::{
6
    register_rule, register_rule_set, ApplicationError, ApplicationError::RuleNotApplicable,
7
    ApplicationResult, Reduction,
8
};
9
use itertools::izip;
10

            
11
use crate::ast::SymbolTable;
12

            
13
// Declares the rule set named "Constant"; the rules in this file are registered under that name.
// NOTE(review): the second argument `()` is presumably an empty dependency list — confirm
// against the `register_rule_set!` macro documentation.
register_rule_set!("Constant", ());
14

            
15
#[register_rule(("Constant", 9001))]
16
504628
fn apply_eval_constant(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
17
249322
    if let Expr::Atomic(_, Atom::Literal(_)) = expr {
18
99977
        return Err(ApplicationError::RuleNotApplicable);
19
404651
    }
20
404651
    eval_constant(expr)
21
404651
        .map(|c| Reduction::pure(Expr::Atomic(Metadata::new(), Atom::Literal(c))))
22
404651
        .ok_or(ApplicationError::RuleNotApplicable)
23
504628
}
24

            
25
/// Simplify an expression to a constant if possible
26
/// Returns:
27
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
28
/// `Some(Const)` if the expression can be simplified to a constant
29
892742
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
30
470324
    match expr {
31
91921
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
32
378403
        Expr::Atomic(_, Atom::Reference(_c)) => None,
33
3587
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
34
62033
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
35
62033
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
36
62033
            .map(Lit::Bool),
37
13685
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
38
3791
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
39
34
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
40
40443
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
41
2363
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
42

            
43
8109
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
44

            
45
12376
        Expr::And(_, exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Lit::Bool),
46
        // this is done elsewhere instead - root should return a new root with a literal inside it,
47
        // not a literal
48
13838
        Expr::Root(_, _) => None,
49
23205
        Expr::Or(_, exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Lit::Bool),
50
43129
        Expr::Imply(_, box1, box2) => {
51
43129
            let a: &Atom = (&**box1).try_into().ok()?;
52
2159
            let b: &Atom = (&**box2).try_into().ok()?;
53

            
54
1241
            let a: bool = a.try_into().ok()?;
55
            let b: bool = b.try_into().ok()?;
56

            
57
            if a {
58
                // true -> b ~> b
59
                Some(Lit::Bool(b))
60
            } else {
61
                // false -> b ~> true
62
                Some(Lit::Bool(true))
63
            }
64
        }
65

            
66
11934
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
67
5763
        Expr::Product(_, exprs) => vec_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int),
68

            
69
6562
        Expr::FlatIneq(_, a, b, c) => {
70
6562
            let a: i32 = a.try_into().ok()?;
71
731
            let b: i32 = b.try_into().ok()?;
72
            let c: i32 = c.try_into().ok()?;
73

            
74
            Some(Lit::Bool(a <= b + c))
75
        }
76

            
77
1666
        Expr::FlatSumGeq(_, exprs, a) => {
78
1683
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
79
1683
                let n: i32 = atom.try_into().ok()?;
80
17
                let acc = acc + n;
81
17
                Some(acc)
82
1683
            })?;
83

            
84
            Some(Lit::Bool(sum >= a.try_into().ok()?))
85
        }
86
3689
        Expr::FlatSumLeq(_, exprs, a) => {
87
3706
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
88
3706
                let n: i32 = atom.try_into().ok()?;
89
17
                let acc = acc + n;
90
17
                Some(acc)
91
3706
            })?;
92

            
93
            Some(Lit::Bool(sum >= a.try_into().ok()?))
94
        }
95
272
        Expr::Min(_, exprs) => {
96
272
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Lit::Int)
97
        }
98
153
        Expr::Max(_, exprs) => {
99
153
            opt_vec_op::<i32, i32>(|e| e.iter().max().copied(), exprs).map(Lit::Int)
100
        }
101
31485
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
102
48571
            if unwrap_expr::<i32>(b)? == 0 {
103
2
                return None;
104
35241
            }
105
35241
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
106
        }
107
42857
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
108
58905
            if unwrap_expr::<i32>(b)? == 0 {
109
                return None;
110
48773
            }
111
48773
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
112
48773
                .map(Lit::Int)
113
        }
114
3417
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
115
            // div always rounds down
116
3417
            let a: i32 = a.try_into().ok()?;
117
34
            let b: i32 = b.try_into().ok()?;
118
            let c: i32 = c.try_into().ok()?;
119

            
120
            if b == 0 {
121
                return None;
122
            }
123

            
124
            let a = a as f32;
125
            let b = b as f32;
126
            let div: i32 = (a / b).floor() as i32;
127
            Some(Lit::Bool(div == c))
128
        }
129
5797
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
130

            
131
12342
        Expr::MinionReify(_, a, b) => {
132
12342
            let result = eval_constant(a)?;
133

            
134
34
            let result: bool = result.try_into().ok()?;
135
34
            let b: bool = b.try_into().ok()?;
136

            
137
            Some(Lit::Bool(b == result))
138
        }
139

            
140
15045
        Expr::MinionReifyImply(_, a, b) => {
141
15045
            let result = eval_constant(a)?;
142

            
143
            let result: bool = result.try_into().ok()?;
144
            let b: bool = b.try_into().ok()?;
145

            
146
            if b {
147
                Some(Lit::Bool(result))
148
            } else {
149
                Some(Lit::Bool(true))
150
            }
151
        }
152
10030
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
153
            // From Savile Row. Same semantics as division.
154
            //
155
            //   a - (b * floor(a/b))
156
            //
157
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
158
            // down instead, not towards zero.
159

            
160
10030
            let a: i32 = a.try_into().ok()?;
161
34
            let b: i32 = b.try_into().ok()?;
162
            let c: i32 = c.try_into().ok()?;
163

            
164
            if b == 0 {
165
                return None;
166
            }
167

            
168
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
169
            Some(Lit::Bool(modulo == c))
170
        }
171

            
172
731
        Expr::MinionPow(_, a, b, c) => {
173
            // only available for positive a b c
174

            
175
731
            let a: i32 = a.try_into().ok()?;
176
            let b: i32 = b.try_into().ok()?;
177
            let c: i32 = c.try_into().ok()?;
178

            
179
            if a <= 0 {
180
                return None;
181
            }
182

            
183
            if b <= 0 {
184
                return None;
185
            }
186

            
187
            if c <= 0 {
188
                return None;
189
            }
190

            
191
            Some(Lit::Bool(a ^ b == c))
192
        }
193

            
194
34
        Expr::AllDiff(_, es) => {
195
34
            let mut lits: HashSet<Lit> = HashSet::new();
196
136
            for expr in es {
197
102
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
198
                    return None;
199
                };
200
102
                if lits.contains(x) {
201
                    return Some(Lit::Bool(false));
202
102
                } else {
203
102
                    lits.insert(x.clone());
204
102
                }
205
            }
206
34
            Some(Lit::Bool(true))
207
        }
208
731
        Expr::FlatWatchedLiteral(_, _, _) => None,
209
2414
        Expr::AuxDeclaration(_, _, _) => None,
210
3366
        Expr::Neg(_, a) => {
211
3366
            let a: &Atom = a.try_into().ok()?;
212
2312
            let a: i32 = a.try_into().ok()?;
213
731
            Some(Lit::Int(-a))
214
        }
215
255
        Expr::Minus(_, a, b) => {
216
255
            let a: &Atom = a.try_into().ok()?;
217
136
            let a: i32 = a.try_into().ok()?;
218

            
219
            let b: &Atom = b.try_into().ok()?;
220
            let b: i32 = b.try_into().ok()?;
221

            
222
            Some(Lit::Int(a - b))
223
        }
224
187
        Expr::FlatMinusEq(_, a, b) => {
225
187
            let a: i32 = a.try_into().ok()?;
226
            let b: i32 = b.try_into().ok()?;
227
            Some(Lit::Bool(a == -b))
228
        }
229
170
        Expr::FlatProductEq(_, a, b, c) => {
230
170
            let a: i32 = a.try_into().ok()?;
231
            let b: i32 = b.try_into().ok()?;
232
            let c: i32 = c.try_into().ok()?;
233
            Some(Lit::Bool(a * b == c))
234
        }
235
595
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
236
595
            let cs: Vec<i32> = cs
237
595
                .iter()
238
2839
                .map(|x| TryInto::<i32>::try_into(x).ok())
239
595
                .collect::<Option<Vec<i32>>>()?;
240
595
            let vs: Vec<i32> = vs
241
595
                .iter()
242
595
                .map(|x| TryInto::<i32>::try_into(x).ok())
243
595
                .collect::<Option<Vec<i32>>>()?;
244
            let total: i32 = total.try_into().ok()?;
245

            
246
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
247

            
248
            Some(Lit::Bool(sum <= total))
249
        }
250

            
251
170
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
252
170
            let cs: Vec<i32> = cs
253
170
                .iter()
254
510
                .map(|x| TryInto::<i32>::try_into(x).ok())
255
170
                .collect::<Option<Vec<i32>>>()?;
256
170
            let vs: Vec<i32> = vs
257
170
                .iter()
258
170
                .map(|x| TryInto::<i32>::try_into(x).ok())
259
170
                .collect::<Option<Vec<i32>>>()?;
260
            let total: i32 = total.try_into().ok()?;
261

            
262
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
263

            
264
            Some(Lit::Bool(sum >= total))
265
        }
266
289
        Expr::FlatAbsEq(_, x, y) => {
267
289
            let x: i32 = x.try_into().ok()?;
268
17
            let y: i32 = y.try_into().ok()?;
269

            
270
            Some(Lit::Bool(x == y.abs()))
271
        }
272

            
273
1853
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
274
2737
            let a: &Atom = a.try_into().ok()?;
275
1734
            let a: i32 = a.try_into().ok()?;
276

            
277
            let b: &Atom = b.try_into().ok()?;
278
            let b: i32 = b.try_into().ok()?;
279

            
280
            if (a != 0 || b != 0) && b >= 0 {
281
                Some(Lit::Int(a ^ b))
282
            } else {
283
                None
284
            }
285
        }
286
    }
287
892742
}
288

            
289
/// Evaluate the root expression.
290
///
291
/// This returns either Expr::Root([true]) or Expr::Root([false]).
292
#[register_rule(("Constant", 9001))]
293
504628
fn eval_root(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
294
    // this is its own rule not part of apply_eval_constant, because root should return a new root
295
    // with a literal inside it, not just a literal
296

            
297
504628
    let Expr::Root(_, exprs) = expr else {
298
490790
        return Err(RuleNotApplicable);
299
    };
300

            
301
13838
    match exprs.len() {
302
34
        0 => Ok(Reduction::pure(Expr::Root(
303
34
            Metadata::new(),
304
34
            vec![true.into()],
305
34
        ))),
306
7089
        1 => Err(RuleNotApplicable),
307
        _ => {
308
34
            let lit =
309
6715
                vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).ok_or(RuleNotApplicable)?;
310

            
311
34
            Ok(Reduction::pure(Expr::Root(
312
34
                Metadata::new(),
313
34
                vec![lit.into()],
314
34
            )))
315
        }
316
    }
317
504628
}
318

            
319
11696
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
320
11696
where
321
11696
    T: TryFrom<Lit>,
322
11696
{
323
11696
    let a = unwrap_expr::<T>(a)?;
324
17
    Some(f(a))
325
11696
}
326

            
327
273904
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
328
273904
where
329
273904
    T: TryFrom<Lit>,
330
273904
{
331
273904
    let a = unwrap_expr::<T>(a)?;
332
1139
    let b = unwrap_expr::<T>(b)?;
333
850
    Some(f(a, b))
334
273904
}
335

            
336
#[allow(dead_code)]
337
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
338
where
339
    T: TryFrom<Lit>,
340
{
341
    let a = unwrap_expr::<T>(a)?;
342
    let b = unwrap_expr::<T>(b)?;
343
    let c = unwrap_expr::<T>(c)?;
344
    Some(f(a, b, c))
345
}
346

            
347
59993
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
348
59993
where
349
59993
    T: TryFrom<Lit>,
350
59993
{
351
59993
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
352
782
    Some(f(a))
353
59993
}
354

            
355
425
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
356
425
where
357
425
    T: TryFrom<Lit>,
358
425
{
359
425
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
360
51
    f(a)
361
425
}
362

            
363
#[allow(dead_code)]
364
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
365
where
366
    T: TryFrom<Lit>,
367
{
368
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
369
    let b = unwrap_expr::<T>(b)?;
370
    Some(f(a, b))
371
}
372

            
373
460515
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
374
460515
    let c = eval_constant(expr)?;
375
93077
    TryInto::<T>::try_into(c).ok()
376
460515
}
377

            
378
#[cfg(test)]
mod tests {
    use crate::rules::eval_constant;
    use conjure_core::ast::{Atom, Expression, Literal};

    /// Builds a boxed integer-literal expression with default metadata.
    fn int_lit(value: i32) -> Box<Expression> {
        Box::new(Expression::Atomic(
            Default::default(),
            Atom::Literal(Literal::Int(value)),
        ))
    }

    /// Dividing by a literal zero must not evaluate to a constant.
    #[test]
    fn div_by_zero() {
        let one_over_zero = Expression::UnsafeDiv(Default::default(), int_lit(1), int_lit(0));
        assert_eq!(eval_constant(&one_over_zero), None);
    }

    /// The safe division variant is likewise left unevaluated for a zero divisor.
    #[test]
    fn safediv_by_zero() {
        let one_over_zero = Expression::SafeDiv(Default::default(), int_lit(1), int_lit(0));
        assert_eq!(eval_constant(&one_over_zero), None);
    }
}