1
use conjure_core::ast::{Expression as Expr, Factor, Literal as Lit};
2
use conjure_core::metadata::Metadata;
3
use conjure_core::rule_engine::{
4
    register_rule, register_rule_set, ApplicationError, ApplicationResult, Reduction,
5
};
6
use conjure_core::Model;
7
use tracing::warn;
8

            
9
// Registers the "Constant" rule set that `apply_eval_constant` below is attached to.
// NOTE(review): the arguments are presumably (name, priority, dependency rule sets) — confirm
// against the `register_rule_set!` macro definition.
register_rule_set!("Constant", 100, ());
10

            
11
#[register_rule(("Constant", 9001))]
12
8217307
fn apply_eval_constant(expr: &Expr, _: &Model) -> ApplicationResult {
13
6133787
    if let Expr::FactorE(_, Factor::Literal(_)) = expr {
14
1624044
        return Err(ApplicationError::RuleNotApplicable);
15
6593263
    }
16
6593263
    eval_constant(expr)
17
6593263
        .map(|c| Reduction::pure(Expr::FactorE(Metadata::new(), Factor::Literal(c))))
18
6593263
        .ok_or(ApplicationError::RuleNotApplicable)
19
8217307
}
20

            
21
/// Simplify an expression to a constant if possible
22
/// Returns:
23
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
24
/// `Some(Const)` if the expression can be simplified to a constant
25
9164789
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
26
6596648
    match expr {
27
1617
        Expr::FactorE(_, Factor::Literal(c)) => Some(c.clone()),
28
6595031
        Expr::FactorE(_, Factor::Reference(c)) => None,
29
8432
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
30
8432
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
31
8432
            .map(Lit::Bool),
32
272
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
33
4726
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
34
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
35
357
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
36
272
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
37

            
38
170
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
39

            
40
22814
        Expr::And(_, exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Lit::Bool),
41
445859
        Expr::Or(_, exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Lit::Bool),
42

            
43
14926
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
44

            
45
627861
        Expr::Ineq(_, a, b, c) => {
46
627861
            tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Lit::Bool)
47
        }
48

            
49
740401
        Expr::SumGeq(_, exprs, a) => {
50
740401
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Lit::Bool)
51
        }
52
693498
        Expr::SumLeq(_, exprs, a) => {
53
693498
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Lit::Bool)
54
        }
55
        // Expr::Div(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
56
        // Expr::SafeDiv(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
57
374
        Expr::Min(_, exprs) => {
58
374
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Lit::Int)
59
        }
60
204
        Expr::Max(_, exprs) => {
61
204
            opt_vec_op::<i32, i32>(|e| e.iter().max().copied(), exprs).map(Lit::Int)
62
        }
63
596
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
64
852
            if unwrap_expr::<i32>(b)? == 0 {
65
2
                return None;
66
34
            }
67
34
            bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int)
68
        }
69
272
        Expr::DivEq(_, a, b, c) => {
70
272
            tern_op::<i32, bool>(|a, b, c| a == b * c, a, b, c).map(Lit::Bool)
71
        }
72
238
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
73

            
74
136
        Expr::Reify(_, a, b) => bin_op::<bool, bool>(|a, b| a == b, a, b).map(Lit::Bool),
75
        _ => {
76
6477
            warn!(%expr,"Unimplemented constant eval");
77
6477
            None
78
        }
79
    }
80
9164789
}
81

            
82
170
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
83
170
where
84
170
    T: TryFrom<Lit>,
85
170
{
86
170
    let a = unwrap_expr::<T>(a)?;
87
    Some(f(a))
88
170
}
89

            
90
22882
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
91
22882
where
92
22882
    T: TryFrom<Lit>,
93
22882
{
94
22882
    let a = unwrap_expr::<T>(a)?;
95
187
    let b = unwrap_expr::<T>(b)?;
96
68
    Some(f(a, b))
97
22882
}
98

            
99
628133
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
100
628133
where
101
628133
    T: TryFrom<Lit>,
102
628133
{
103
628133
    let a = unwrap_expr::<T>(a)?;
104
850
    let b = unwrap_expr::<T>(b)?;
105
    let c = unwrap_expr::<T>(c)?;
106
    Some(f(a, b, c))
107
628133
}
108

            
109
483599
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
110
483599
where
111
483599
    T: TryFrom<Lit>,
112
483599
{
113
483599
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
114
85
    Some(f(a))
115
483599
}
116

            
117
578
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
118
578
where
119
578
    T: TryFrom<Lit>,
120
578
{
121
578
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
122
51
    f(a)
123
578
}
124

            
125
1433899
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
126
1433899
where
127
1433899
    T: TryFrom<Lit>,
128
1433899
{
129
1433899
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
130
    let b = unwrap_expr::<T>(b)?;
131
    Some(f(a, b))
132
1433899
}
133

            
134
2571337
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
135
2571337
    let c = eval_constant(expr)?;
136
1634
    TryInto::<T>::try_into(c).ok()
137
2571337
}
138

            
139
#[cfg(test)]
mod tests {
    use conjure_core::ast::{Expression, Factor, Literal};

    /// Builds an integer literal expression with default metadata.
    fn int(value: i32) -> Expression {
        Expression::FactorE(Default::default(), Factor::Literal(Literal::Int(value)))
    }

    /// `1 / 0` using unsafe division must not be folded to a constant.
    #[test]
    fn div_by_zero() {
        let numerator = Box::new(int(1));
        let denominator = Box::new(int(0));
        let expr = Expression::UnsafeDiv(Default::default(), numerator, denominator);
        assert_eq!(super::eval_constant(&expr), None);
    }

    /// `1 / 0` using safe division must not be folded to a constant either.
    #[test]
    fn safediv_by_zero() {
        let numerator = Box::new(int(1));
        let denominator = Box::new(int(0));
        let expr = Expression::SafeDiv(Default::default(), numerator, denominator);
        assert_eq!(super::eval_constant(&expr), None);
    }
}
}