1
use conjure_core::ast::{Expression as Expr, Factor, Literal as Lit};
2
use conjure_core::metadata::Metadata;
3
use conjure_core::rule_engine::{
4
    register_rule, register_rule_set, ApplicationError, ApplicationResult, Reduction,
5
};
6
use conjure_core::Model;
7

            
8
// Register the "Constant" rule set; `apply_eval_constant` below registers into it by name.
// NOTE(review): `100` is presumably the set's priority and `()` its (empty) dependency list —
// their meaning is defined by the `register_rule_set!` macro, not this file; confirm there.
register_rule_set!("Constant", 100, ());
9

            
10
#[register_rule(("Constant", 9001))]
11
7250565
fn apply_eval_constant(expr: &Expr, _: &Model) -> ApplicationResult {
12
5412165
    if let Expr::FactorE(_, Factor::Literal(_)) = expr {
13
1432980
        return Err(ApplicationError::RuleNotApplicable);
14
5817585
    }
15
5817585
    eval_constant(expr)
16
5817585
        .map(|c| Reduction::pure(Expr::FactorE(Metadata::new(), Factor::Literal(c))))
17
5817585
        .ok_or(ApplicationError::RuleNotApplicable)
18
7250565
}
19

            
20
/// Simplify an expression to a constant if possible
21
/// Returns:
22
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
23
/// `Some(Const)` if the expression can be simplified to a constant
24
8086579
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
25
5820572
    match expr {
26
1427
        Expr::FactorE(_, Factor::Literal(c)) => Some(c.clone()),
27
5819145
        Expr::FactorE(_, Factor::Reference(c)) => None,
28
7440
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
29
7440
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
30
7440
            .map(Lit::Bool),
31
240
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
32
4170
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
33
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
34
315
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
35
240
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
36

            
37
150
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
38

            
39
20130
        Expr::And(_, exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Lit::Bool),
40
393405
        Expr::Or(_, exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Lit::Bool),
41

            
42
13170
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
43

            
44
553995
        Expr::Ineq(_, a, b, c) => {
45
553995
            tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Lit::Bool)
46
        }
47

            
48
653295
        Expr::SumGeq(_, exprs, a) => {
49
653295
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Lit::Bool)
50
        }
51
611910
        Expr::SumLeq(_, exprs, a) => {
52
611910
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Lit::Bool)
53
        }
54
        // Expr::Div(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
55
        // Expr::SafeDiv(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
56
330
        Expr::Min(_, exprs) => {
57
330
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Lit::Int)
58
        }
59
180
        Expr::Max(_, exprs) => {
60
180
            opt_vec_op::<i32, i32>(|e| e.iter().max().copied(), exprs).map(Lit::Int)
61
        }
62
526
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
63
752
            if unwrap_expr::<i32>(b)? == 0 {
64
2
                return None;
65
30
            }
66
30
            bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int)
67
        }
68
240
        Expr::DivEq(_, a, b, c) => {
69
240
            tern_op::<i32, bool>(|a, b, c| a == b * c, a, b, c).map(Lit::Bool)
70
        }
71
210
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
72

            
73
120
        Expr::Reify(_, a, b) => bin_op::<bool, bool>(|a, b| a == b, a, b).map(Lit::Bool),
74
        _ => {
75
5715
            println!("WARNING: Unimplemented constant eval: {:?}", expr);
76
5715
            None
77
        }
78
    }
79
8086579
}
80

            
81
150
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
82
150
where
83
150
    T: TryFrom<Lit>,
84
150
{
85
150
    let a = unwrap_expr::<T>(a)?;
86
    Some(f(a))
87
150
}
88

            
89
20190
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
90
20190
where
91
20190
    T: TryFrom<Lit>,
92
20190
{
93
20190
    let a = unwrap_expr::<T>(a)?;
94
165
    let b = unwrap_expr::<T>(b)?;
95
60
    Some(f(a, b))
96
20190
}
97

            
98
554235
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
99
554235
where
100
554235
    T: TryFrom<Lit>,
101
554235
{
102
554235
    let a = unwrap_expr::<T>(a)?;
103
750
    let b = unwrap_expr::<T>(b)?;
104
    let c = unwrap_expr::<T>(c)?;
105
    Some(f(a, b, c))
106
554235
}
107

            
108
426705
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
109
426705
where
110
426705
    T: TryFrom<Lit>,
111
426705
{
112
426705
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
113
75
    Some(f(a))
114
426705
}
115

            
116
510
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
117
510
where
118
510
    T: TryFrom<Lit>,
119
510
{
120
510
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
121
45
    f(a)
122
510
}
123

            
124
1265205
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
125
1265205
where
126
1265205
    T: TryFrom<Lit>,
127
1265205
{
128
1265205
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
129
    let b = unwrap_expr::<T>(b)?;
130
    Some(f(a, b))
131
1265205
}
132

            
133
2268827
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
134
2268827
    let c = eval_constant(expr)?;
135
1442
    TryInto::<T>::try_into(c).ok()
136
2268827
}
137

            
138
#[cfg(test)]
mod tests {
    use conjure_core::ast::{Expression, Factor, Literal};

    /// Build a boxed integer-literal expression with default metadata.
    fn int_literal(value: i32) -> Box<Expression> {
        Box::new(Expression::FactorE(
            Default::default(),
            Factor::Literal(Literal::Int(value)),
        ))
    }

    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int_literal(1), int_literal(0));
        assert_eq!(super::eval_constant(&expr), None);
    }

    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int_literal(1), int_literal(0));
        assert_eq!(super::eval_constant(&expr), None);
    }
}