1
use conjure_core::ast::{Constant as Const, Expression as Expr};
2
use conjure_core::metadata::Metadata;
3
use conjure_core::rule_engine::{
4
    register_rule, register_rule_set, ApplicationError, ApplicationResult, Reduction,
5
};
6
use conjure_core::Model;
7

            
8
// Registers the "Constant" rule set, which `apply_eval_constant` below joins via
// `#[register_rule(("Constant", 255))]`.
// NOTE(review): `255` presumably is the set's priority and `()` an empty list of
// dependency rule sets — confirm against the `register_rule_set!` macro definition.
register_rule_set!("Constant", 255, ());
9

            
10
#[register_rule(("Constant", 255))]
11
fn apply_eval_constant(expr: &Expr, _: &Model) -> ApplicationResult {
12
    if let Expr::Constant(_, _) = expr {
13
        return Err(ApplicationError::RuleNotApplicable);
14
    }
15
    eval_constant(expr)
16
        .map(|c| Reduction::pure(Expr::Constant(Metadata::new(), c)))
17
        .ok_or(ApplicationError::RuleNotApplicable)
18
}
19

            
20
/// Simplify an expression to a constant if possible
21
/// Returns:
22
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
23
/// `Some(Const)` if the expression can be simplified to a constant
24
pub fn eval_constant(expr: &Expr) -> Option<Const> {
25
    match expr {
26
        Expr::Constant(_, c) => Some(c.clone()),
27
        Expr::Reference(_, _) => None,
28
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
29
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
30
            .map(Const::Bool),
31
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Const::Bool),
32
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Const::Bool),
33
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Const::Bool),
34
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Const::Bool),
35
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Const::Bool),
36

            
37
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Const::Bool),
38

            
39
        Expr::And(_, exprs) => {
40
            vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Const::Bool)
41
        }
42
        Expr::Or(_, exprs) => {
43
            vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Const::Bool)
44
        }
45

            
46
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Const::Int),
47

            
48
        Expr::Ineq(_, a, b, c) => {
49
            tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Const::Bool)
50
        }
51

            
52
        Expr::SumGeq(_, exprs, a) => {
53
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Const::Bool)
54
        }
55
        Expr::SumLeq(_, exprs, a) => {
56
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Const::Bool)
57
        }
58
        // Expr::Div(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Const::Int),
59
        // Expr::SafeDiv(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Const::Int),
60
        Expr::Min(_, exprs) => {
61
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Const::Int)
62
        }
63
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
64
            if unwrap_expr::<i32>(b)? == 0 {
65
                return None;
66
            }
67
            bin_op::<i32, i32>(|a, b| a / b, a, b).map(Const::Int)
68
        }
69
        Expr::DivEq(_, a, b, c) => {
70
            tern_op::<i32, bool>(|a, b, c| a == b * c, a, b, c).map(Const::Bool)
71
        }
72
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Const::Bool),
73
        _ => {
74
            println!("WARNING: Unimplemented constant eval: {:?}", expr);
75
            None
76
        }
77
    }
78
}
79

            
80
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
81
where
82
    T: TryFrom<Const>,
83
{
84
    let a = unwrap_expr::<T>(a)?;
85
    Some(f(a))
86
}
87

            
88
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
89
where
90
    T: TryFrom<Const>,
91
{
92
    let a = unwrap_expr::<T>(a)?;
93
    let b = unwrap_expr::<T>(b)?;
94
    Some(f(a, b))
95
}
96

            
97
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
98
where
99
    T: TryFrom<Const>,
100
{
101
    let a = unwrap_expr::<T>(a)?;
102
    let b = unwrap_expr::<T>(b)?;
103
    let c = unwrap_expr::<T>(c)?;
104
    Some(f(a, b, c))
105
}
106

            
107
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
108
where
109
    T: TryFrom<Const>,
110
{
111
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
112
    Some(f(a))
113
}
114

            
115
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
116
where
117
    T: TryFrom<Const>,
118
{
119
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
120
    f(a)
121
}
122

            
123
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
124
where
125
    T: TryFrom<Const>,
126
{
127
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
128
    let b = unwrap_expr::<T>(b)?;
129
    Some(f(a, b))
130
}
131

            
132
fn unwrap_expr<T: TryFrom<Const>>(expr: &Expr) -> Option<T> {
133
    let c = eval_constant(expr)?;
134
    TryInto::<T>::try_into(c).ok()
135
}
136

            
137
#[cfg(test)]
mod tests {
    use conjure_core::ast::{Constant, Expression};

    /// Builds a boxed integer-literal expression with default metadata.
    fn int(value: i32) -> Box<Expression> {
        Box::new(Expression::Constant(Default::default(), Constant::Int(value)))
    }

    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&expr), None);
    }

    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&expr), None);
    }
}