1
#![allow(dead_code)]
2
use std::collections::HashSet;
3

            
4
use conjure_core::ast::{Atom, Expression as Expr, Literal as Lit};
5
use conjure_core::metadata::Metadata;
6
use conjure_core::rule_engine::{
7
    register_rule, register_rule_set, ApplicationError, ApplicationError::RuleNotApplicable,
8
    ApplicationResult, Reduction,
9
};
10
use itertools::izip;
11

            
12
use crate::ast::SymbolTable;
13

            
14
// Register the "Constant" rule set; the constant-folding rules in this file declare
// membership via `#[register_rule(("Constant", ...))]`.
// NOTE(review): the `()` argument presumably lists rule sets this one depends on (none) —
// confirm against the `register_rule_set!` macro definition.
register_rule_set!("Constant", ());
15

            
16
#[register_rule(("Constant", 9001))]
17
747954
fn apply_eval_constant(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
18
346230
    if let Expr::Atomic(_, Atom::Literal(_)) = expr {
19
145422
        return Err(ApplicationError::RuleNotApplicable);
20
602532
    }
21
602532
    eval_constant(expr)
22
602532
        .map(|c| Reduction::pure(Expr::Atomic(Metadata::new(), Atom::Literal(c))))
23
602532
        .ok_or(ApplicationError::RuleNotApplicable)
24
747954
}
25

            
26
/// Simplify an expression to a constant if possible
27
/// Returns:
28
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
29
/// `Some(Const)` if the expression can be simplified to a constant
30
1132564
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
31
542486
    match expr {
32
40878
        Expr::AbstractLiteral(_, _) => None,
33
        // `fromSolution()` pulls a literal value from last found solution
34
        Expr::FromSolution(_, _) => None,
35
        // Same as Expr::Root, we should not replace the dominance relation with a constant
36
        Expr::DominanceRelation(_, _) => None,
37
14490
        Expr::UnsafeIndex(_, _, _) => None,
38
        // handled elsewhere
39
15930
        Expr::SafeIndex(_, _, _) => None,
40
3636
        Expr::UnsafeSlice(_, _, _) => None,
41
        // handled elsewhere
42
10386
        Expr::SafeSlice(_, _, _) => None,
43
630
        Expr::InDomain(_, e, domain) => {
44
630
            let Expr::Atomic(_, Atom::Literal(lit)) = e.as_ref() else {
45
                return None;
46
            };
47

            
48
630
            domain.contains(lit).map(Into::into)
49
        }
50
97940
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
51
444546
        Expr::Atomic(_, Atom::Reference(_c)) => None,
52
3798
        Expr::Abs(_, e) => un_op::<i32, i32>(|a| a.abs(), e).map(Lit::Int),
53
74556
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
54
74556
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
55
74556
            .map(Lit::Bool),
56
14490
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
57
3114
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
58
54
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
59
43632
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
60
2466
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
61

            
62
13212
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
63

            
64
15714
        Expr::And(_, e) => {
65
15714
            vec_lit_op::<bool, bool>(|e| e.iter().all(|&e| e), e.as_ref()).map(Lit::Bool)
66
        }
67
        // this is done elsewhere instead - root should return a new root with a literal inside it,
68
        // not a literal
69
19602
        Expr::Root(_, _) => None,
70
33984
        Expr::Or(_, e) => {
71
33984
            vec_lit_op::<bool, bool>(|e| e.iter().any(|&e| e), e.as_ref()).map(Lit::Bool)
72
        }
73
37260
        Expr::Imply(_, box1, box2) => {
74
37260
            let a: &Atom = (&**box1).try_into().ok()?;
75
9162
            let b: &Atom = (&**box2).try_into().ok()?;
76

            
77
6444
            let a: bool = a.try_into().ok()?;
78
            let b: bool = b.try_into().ok()?;
79

            
80
            if a {
81
                // true -> b ~> b
82
                Some(Lit::Bool(b))
83
            } else {
84
                // false -> b ~> true
85
                Some(Lit::Bool(true))
86
            }
87
        }
88

            
89
13230
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
90
6102
        Expr::Product(_, exprs) => vec_op::<i32, i32>(|e| e.iter().product(), exprs).map(Lit::Int),
91

            
92
15786
        Expr::FlatIneq(_, a, b, c) => {
93
15786
            let a: i32 = a.try_into().ok()?;
94
882
            let b: i32 = b.try_into().ok()?;
95
            let c: i32 = c.try_into().ok()?;
96

            
97
            Some(Lit::Bool(a <= b + c))
98
        }
99

            
100
1854
        Expr::FlatSumGeq(_, exprs, a) => {
101
1872
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
102
1872
                let n: i32 = atom.try_into().ok()?;
103
18
                let acc = acc + n;
104
18
                Some(acc)
105
1872
            })?;
106

            
107
            Some(Lit::Bool(sum >= a.try_into().ok()?))
108
        }
109
4032
        Expr::FlatSumLeq(_, exprs, a) => {
110
4050
            let sum = exprs.iter().try_fold(0, |acc, atom: &Atom| {
111
4050
                let n: i32 = atom.try_into().ok()?;
112
18
                let acc = acc + n;
113
18
                Some(acc)
114
4050
            })?;
115

            
116
            Some(Lit::Bool(sum >= a.try_into().ok()?))
117
        }
118
576
        Expr::Min(_, e) => {
119
576
            opt_vec_lit_op::<i32, i32>(|e| e.iter().min().copied(), e.as_ref()).map(Lit::Int)
120
        }
121
432
        Expr::Max(_, e) => {
122
432
            opt_vec_lit_op::<i32, i32>(|e| e.iter().max().copied(), e.as_ref()).map(Lit::Int)
123
        }
124
34201
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
125
51428
            if unwrap_expr::<i32>(b)? == 0 {
126
2
                return None;
127
37314
            }
128
37314
            bin_op::<i32, i32>(|a, b| ((a as f32) / (b as f32)).floor() as i32, a, b).map(Lit::Int)
129
        }
130
46602
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
131
62226
            if unwrap_expr::<i32>(b)? == 0 {
132
                return None;
133
51498
            }
134
51498
            bin_op::<i32, i32>(|a, b| a - b * (a as f32 / b as f32).floor() as i32, a, b)
135
51498
                .map(Lit::Int)
136
        }
137
4914
        Expr::MinionDivEqUndefZero(_, a, b, c) => {
138
            // div always rounds down
139
4914
            let a: i32 = a.try_into().ok()?;
140
36
            let b: i32 = b.try_into().ok()?;
141
            let c: i32 = c.try_into().ok()?;
142

            
143
            if b == 0 {
144
                return None;
145
            }
146

            
147
            let a = a as f32;
148
            let b = b as f32;
149
            let div: i32 = (a / b).floor() as i32;
150
            Some(Lit::Bool(div == c))
151
        }
152
9306
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
153

            
154
15714
        Expr::MinionReify(_, a, b) => {
155
15714
            let result = eval_constant(a)?;
156

            
157
36
            let result: bool = result.try_into().ok()?;
158
36
            let b: bool = b.try_into().ok()?;
159

            
160
            Some(Lit::Bool(b == result))
161
        }
162

            
163
13464
        Expr::MinionReifyImply(_, a, b) => {
164
13464
            let result = eval_constant(a)?;
165

            
166
            let result: bool = result.try_into().ok()?;
167
            let b: bool = b.try_into().ok()?;
168

            
169
            if b {
170
                Some(Lit::Bool(result))
171
            } else {
172
                Some(Lit::Bool(true))
173
            }
174
        }
175
10854
        Expr::MinionModuloEqUndefZero(_, a, b, c) => {
176
            // From Savile Row. Same semantics as division.
177
            //
178
            //   a - (b * floor(a/b))
179
            //
180
            // We don't use % as it has the same semantics as /. We don't use / as we want to round
181
            // down instead, not towards zero.
182

            
183
10854
            let a: i32 = a.try_into().ok()?;
184
36
            let b: i32 = b.try_into().ok()?;
185
            let c: i32 = c.try_into().ok()?;
186

            
187
            if b == 0 {
188
                return None;
189
            }
190

            
191
            let modulo = a - b * (a as f32 / b as f32).floor() as i32;
192
            Some(Lit::Bool(modulo == c))
193
        }
194

            
195
774
        Expr::MinionPow(_, a, b, c) => {
196
            // only available for positive a b c
197

            
198
774
            let a: i32 = a.try_into().ok()?;
199
            let b: i32 = b.try_into().ok()?;
200
            let c: i32 = c.try_into().ok()?;
201

            
202
            if a <= 0 {
203
                return None;
204
            }
205

            
206
            if b <= 0 {
207
                return None;
208
            }
209

            
210
            if c <= 0 {
211
                return None;
212
            }
213

            
214
            Some(Lit::Bool(a ^ b == c))
215
        }
216

            
217
18108
        Expr::AllDiff(_, e) => {
218
18108
            let es = e.clone().unwrap_list()?;
219
90
            let mut lits: HashSet<Lit> = HashSet::new();
220
90
            for expr in es {
221
90
                let Expr::Atomic(_, Atom::Literal(x)) = expr else {
222
90
                    return None;
223
                };
224
                match x {
225
                    Lit::Int(_) | Lit::Bool(_) => {
226
                        if lits.contains(&x) {
227
                            return Some(Lit::Bool(false));
228
                        } else {
229
                            lits.insert(x.clone());
230
                        }
231
                    }
232
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
233
                }
234
            }
235
            Some(Lit::Bool(true))
236
        }
237
90
        Expr::FlatAllDiff(_, es) => {
238
90
            let mut lits: HashSet<Lit> = HashSet::new();
239
90
            for atom in es {
240
90
                let Atom::Literal(x) = atom else {
241
90
                    return None;
242
                };
243

            
244
                match x {
245
                    Lit::Int(_) | Lit::Bool(_) => {
246
                        if lits.contains(x) {
247
                            return Some(Lit::Bool(false));
248
                        } else {
249
                            lits.insert(x.clone());
250
                        }
251
                    }
252
                    Lit::AbstractLiteral(_) => return None, // Reject AbstractLiteral cases
253
                }
254
            }
255
            Some(Lit::Bool(true))
256
        }
257
2430
        Expr::FlatWatchedLiteral(_, _, _) => None,
258
2700
        Expr::AuxDeclaration(_, _, _) => None,
259
3564
        Expr::Neg(_, a) => {
260
3564
            let a: &Atom = a.try_into().ok()?;
261
2448
            let a: i32 = a.try_into().ok()?;
262
774
            Some(Lit::Int(-a))
263
        }
264
270
        Expr::Minus(_, a, b) => {
265
270
            let a: &Atom = a.try_into().ok()?;
266
144
            let a: i32 = a.try_into().ok()?;
267

            
268
            let b: &Atom = b.try_into().ok()?;
269
            let b: i32 = b.try_into().ok()?;
270

            
271
            Some(Lit::Int(a - b))
272
        }
273
198
        Expr::FlatMinusEq(_, a, b) => {
274
198
            let a: i32 = a.try_into().ok()?;
275
            let b: i32 = b.try_into().ok()?;
276
            Some(Lit::Bool(a == -b))
277
        }
278
180
        Expr::FlatProductEq(_, a, b, c) => {
279
180
            let a: i32 = a.try_into().ok()?;
280
            let b: i32 = b.try_into().ok()?;
281
            let c: i32 = c.try_into().ok()?;
282
            Some(Lit::Bool(a * b == c))
283
        }
284
630
        Expr::FlatWeightedSumLeq(_, cs, vs, total) => {
285
630
            let cs: Vec<i32> = cs
286
630
                .iter()
287
3006
                .map(|x| TryInto::<i32>::try_into(x).ok())
288
630
                .collect::<Option<Vec<i32>>>()?;
289
630
            let vs: Vec<i32> = vs
290
630
                .iter()
291
630
                .map(|x| TryInto::<i32>::try_into(x).ok())
292
630
                .collect::<Option<Vec<i32>>>()?;
293
            let total: i32 = total.try_into().ok()?;
294

            
295
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
296

            
297
            Some(Lit::Bool(sum <= total))
298
        }
299

            
300
180
        Expr::FlatWeightedSumGeq(_, cs, vs, total) => {
301
180
            let cs: Vec<i32> = cs
302
180
                .iter()
303
540
                .map(|x| TryInto::<i32>::try_into(x).ok())
304
180
                .collect::<Option<Vec<i32>>>()?;
305
180
            let vs: Vec<i32> = vs
306
180
                .iter()
307
180
                .map(|x| TryInto::<i32>::try_into(x).ok())
308
180
                .collect::<Option<Vec<i32>>>()?;
309
            let total: i32 = total.try_into().ok()?;
310

            
311
            let sum: i32 = izip!(cs, vs).fold(0, |acc, (c, v)| acc + (c * v));
312

            
313
            Some(Lit::Bool(sum >= total))
314
        }
315
306
        Expr::FlatAbsEq(_, x, y) => {
316
306
            let x: i32 = x.try_into().ok()?;
317
18
            let y: i32 = y.try_into().ok()?;
318

            
319
            Some(Lit::Bool(x == y.abs()))
320
        }
321

            
322
1962
        Expr::UnsafePow(_, a, b) | Expr::SafePow(_, a, b) => {
323
2898
            let a: &Atom = a.try_into().ok()?;
324
1836
            let a: i32 = a.try_into().ok()?;
325

            
326
            let b: &Atom = b.try_into().ok()?;
327
            let b: i32 = b.try_into().ok()?;
328

            
329
            if (a != 0 || b != 0) && b >= 0 {
330
                Some(Lit::Int(a ^ b))
331
            } else {
332
                None
333
            }
334
        }
335
        Expr::Scope(_, _) => None,
336
    }
337
1132564
}
338

            
339
/// Evaluate the root expression.
340
///
341
/// This returns either Expr::Root([true]) or Expr::Root([false]).
342
#[register_rule(("Constant", 9001))]
343
747954
fn eval_root(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
344
    // this is its own rule not part of apply_eval_constant, because root should return a new root
345
    // with a literal inside it, not just a literal
346

            
347
747954
    let Expr::Root(_, exprs) = expr else {
348
728352
        return Err(RuleNotApplicable);
349
    };
350

            
351
19602
    match exprs.len() {
352
36
        0 => Ok(Reduction::pure(Expr::Root(
353
36
            Metadata::new(),
354
36
            vec![true.into()],
355
36
        ))),
356
7956
        1 => Err(RuleNotApplicable),
357
        _ => {
358
36
            let lit =
359
11610
                vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).ok_or(RuleNotApplicable)?;
360

            
361
36
            Ok(Reduction::pure(Expr::Root(
362
36
                Metadata::new(),
363
36
                vec![lit.into()],
364
36
            )))
365
        }
366
    }
367
747954
}
368

            
369
17010
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
370
17010
where
371
17010
    T: TryFrom<Lit>,
372
17010
{
373
17010
    let a = unwrap_expr::<T>(a)?;
374
18
    Some(f(a))
375
17010
}
376

            
377
310662
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
378
310662
where
379
310662
    T: TryFrom<Lit>,
380
310662
{
381
310662
    let a = unwrap_expr::<T>(a)?;
382
1260
    let b = unwrap_expr::<T>(b)?;
383
936
    Some(f(a, b))
384
310662
}
385

            
386
#[allow(dead_code)]
387
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
388
where
389
    T: TryFrom<Lit>,
390
{
391
    let a = unwrap_expr::<T>(a)?;
392
    let b = unwrap_expr::<T>(b)?;
393
    let c = unwrap_expr::<T>(c)?;
394
    Some(f(a, b, c))
395
}
396

            
397
30942
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
398
30942
where
399
30942
    T: TryFrom<Lit>,
400
30942
{
401
30942
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
402
396
    Some(f(a))
403
30942
}
404

            
405
49698
fn vec_lit_op<T, A>(f: fn(Vec<T>) -> A, a: &Expr) -> Option<A>
406
49698
where
407
49698
    T: TryFrom<Lit>,
408
49698
{
409
49698
    let a = a.clone().unwrap_list()?;
410
20178
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
411
1476
    Some(f(a))
412
49698
}
413

            
414
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
415
where
416
    T: TryFrom<Lit>,
417
{
418
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
419
    f(a)
420
}
421

            
422
1008
fn opt_vec_lit_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &Expr) -> Option<A>
423
1008
where
424
1008
    T: TryFrom<Lit>,
425
1008
{
426
1008
    let a = a.clone().unwrap_list()?;
427
    // FIXME: deal with explicit matrix domains
428
414
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
429
54
    f(a)
430
1008
}
431

            
432
#[allow(dead_code)]
433
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
434
where
435
    T: TryFrom<Lit>,
436
{
437
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
438
    let b = unwrap_expr::<T>(b)?;
439
    Some(f(a, b))
440
}
441

            
442
500456
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
443
500456
    let c = eval_constant(expr)?;
444
99722
    TryInto::<T>::try_into(c).ok()
445
500456
}
446

            
447
#[cfg(test)]
mod tests {
    use crate::rules::eval_constant;
    use conjure_core::ast::{Atom, Expression, Literal};

    /// Build a boxed integer-literal expression with default metadata.
    fn int(value: i32) -> Box<Expression> {
        Box::new(Expression::Atomic(
            Default::default(),
            Atom::Literal(Literal::Int(value)),
        ))
    }

    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }

    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int(1), int(0));
        assert_eq!(eval_constant(&expr), None);
    }
}