use conjure_core::ast::{Expression as Expr, Factor, Literal as Lit};
use conjure_core::metadata::Metadata;
use conjure_core::rule_engine::{
    register_rule, register_rule_set, ApplicationError, ApplicationResult, Reduction,
};
use conjure_core::Model;
use tracing::warn;

// Rule set that holds the constant-folding rules; `apply_eval_constant` below registers into it.
// NOTE(review): the meaning of `100` and `()` is defined by `register_rule_set!` (presumably the
// set's priority and its list of dependency rule sets) — confirm against the macro.
register_rule_set!("Constant", 100, ());

#[register_rule(("Constant", 9001))]
12
7250565
fn apply_eval_constant(expr: &Expr, _: &Model) -> ApplicationResult {
13
5412165
    if let Expr::FactorE(_, Factor::Literal(_)) = expr {
14
1432980
        return Err(ApplicationError::RuleNotApplicable);
15
5817585
    }
16
5817585
    eval_constant(expr)
17
5817585
        .map(|c| Reduction::pure(Expr::FactorE(Metadata::new(), Factor::Literal(c))))
18
5817585
        .ok_or(ApplicationError::RuleNotApplicable)
19
7250565
}
20

            
21
/// Simplify an expression to a constant if possible
22
/// Returns:
23
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
24
/// `Some(Const)` if the expression can be simplified to a constant
25
8086579
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
26
5820572
    match expr {
27
1427
        Expr::FactorE(_, Factor::Literal(c)) => Some(c.clone()),
28
5819145
        Expr::FactorE(_, Factor::Reference(c)) => None,
29
7440
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
30
7440
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
31
7440
            .map(Lit::Bool),
32
240
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
33
4170
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
34
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
35
315
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
36
240
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
37

            
38
150
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
39

            
40
20130
        Expr::And(_, exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Lit::Bool),
41
393405
        Expr::Or(_, exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Lit::Bool),
42

            
43
13170
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
44

            
45
553995
        Expr::Ineq(_, a, b, c) => {
46
553995
            tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Lit::Bool)
47
        }
48

            
49
653295
        Expr::SumGeq(_, exprs, a) => {
50
653295
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Lit::Bool)
51
        }
52
611910
        Expr::SumLeq(_, exprs, a) => {
53
611910
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Lit::Bool)
54
        }
55
        // Expr::Div(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
56
        // Expr::SafeDiv(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
57
330
        Expr::Min(_, exprs) => {
58
330
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Lit::Int)
59
        }
60
180
        Expr::Max(_, exprs) => {
61
180
            opt_vec_op::<i32, i32>(|e| e.iter().max().copied(), exprs).map(Lit::Int)
62
        }
63
526
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
64
752
            if unwrap_expr::<i32>(b)? == 0 {
65
2
                return None;
66
30
            }
67
30
            bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int)
68
        }
69
240
        Expr::DivEq(_, a, b, c) => {
70
240
            tern_op::<i32, bool>(|a, b, c| a == b * c, a, b, c).map(Lit::Bool)
71
        }
72
210
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
73

            
74
120
        Expr::Reify(_, a, b) => bin_op::<bool, bool>(|a, b| a == b, a, b).map(Lit::Bool),
75
        _ => {
76
5715
            warn!(%expr,"Unimplemented constant eval");
77
5715
            None
78
        }
79
    }
80
8086579
}
81

            
82
150
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
83
150
where
84
150
    T: TryFrom<Lit>,
85
150
{
86
150
    let a = unwrap_expr::<T>(a)?;
87
    Some(f(a))
88
150
}
89

            
90
20190
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
91
20190
where
92
20190
    T: TryFrom<Lit>,
93
20190
{
94
20190
    let a = unwrap_expr::<T>(a)?;
95
165
    let b = unwrap_expr::<T>(b)?;
96
60
    Some(f(a, b))
97
20190
}
98

            
99
554235
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
100
554235
where
101
554235
    T: TryFrom<Lit>,
102
554235
{
103
554235
    let a = unwrap_expr::<T>(a)?;
104
750
    let b = unwrap_expr::<T>(b)?;
105
    let c = unwrap_expr::<T>(c)?;
106
    Some(f(a, b, c))
107
554235
}
108

            
109
426705
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
110
426705
where
111
426705
    T: TryFrom<Lit>,
112
426705
{
113
426705
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
114
75
    Some(f(a))
115
426705
}
116

            
117
510
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
118
510
where
119
510
    T: TryFrom<Lit>,
120
510
{
121
510
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
122
45
    f(a)
123
510
}
124

            
125
1265205
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
126
1265205
where
127
1265205
    T: TryFrom<Lit>,
128
1265205
{
129
1265205
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
130
    let b = unwrap_expr::<T>(b)?;
131
    Some(f(a, b))
132
1265205
}
133

            
134
2268827
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
135
2268827
    let c = eval_constant(expr)?;
136
1442
    TryInto::<T>::try_into(c).ok()
137
2268827
}
138

            
139
#[cfg(test)]
mod tests {
    use conjure_core::ast::{Expression, Factor, Literal};

    /// Build a boxed integer-literal expression with default metadata.
    fn int(value: i32) -> Box<Expression> {
        Box::new(Expression::FactorE(
            Default::default(),
            Factor::Literal(Literal::Int(value)),
        ))
    }

    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&expr), None);
    }

    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&expr), None);
    }
}