1
use conjure_core::ast::{Atom, Expression as Expr, Literal as Lit};
2
use conjure_core::metadata::Metadata;
3
use conjure_core::rule_engine::{
4
    register_rule, register_rule_set, ApplicationError, ApplicationResult, Reduction,
5
};
6
use conjure_core::Model;
7
use tracing::warn;
8

            
9
// Rule set for constant folding. `apply_eval_constant` joins it through its
// `register_rule(("Constant", ...))` attribute.
// NOTE(review): the remaining arguments are presumably the set's priority (100) and its
// dependencies (none) — confirm against the `register_rule_set!` macro documentation.
register_rule_set!("Constant", 100, ());
10

            
11
#[register_rule(("Constant", 9001))]
12
18
fn apply_eval_constant(expr: &Expr, _: &Model) -> ApplicationResult {
13
9
    if let Expr::Atomic(_, Atom::Literal(_)) = expr {
14
9
        return Err(ApplicationError::RuleNotApplicable);
15
9
    }
16
9
    eval_constant(expr)
17
9
        .map(|c| Reduction::pure(Expr::Atomic(Metadata::new(), Atom::Literal(c))))
18
9
        .ok_or(ApplicationError::RuleNotApplicable)
19
18
}
20

            
21
/// Simplify an expression to a constant if possible
22
/// Returns:
23
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
24
/// `Some(Const)` if the expression can be simplified to a constant
25
13
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
26
2
    match expr {
27
2
        Expr::Atomic(_, Atom::Literal(c)) => Some(c.clone()),
28
        Expr::Atomic(_, Atom::Reference(_c)) => None,
29
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
30
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
31
            .map(Lit::Bool),
32
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
33
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
34
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
35
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
36
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
37

            
38
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
39

            
40
        Expr::And(_, exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Lit::Bool),
41
9
        Expr::Or(_, exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Lit::Bool),
42

            
43
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
44

            
45
        Expr::Ineq(_, a, b, c) => {
46
            tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Lit::Bool)
47
        }
48

            
49
        Expr::SumGeq(_, exprs, a) => {
50
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Lit::Bool)
51
        }
52
        Expr::SumLeq(_, exprs, a) => {
53
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Lit::Bool)
54
        }
55
        // Expr::Div(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
56
        // Expr::SafeDiv(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
57
        Expr::Min(_, exprs) => {
58
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Lit::Int)
59
        }
60
        Expr::Max(_, exprs) => {
61
            opt_vec_op::<i32, i32>(|e| e.iter().max().copied(), exprs).map(Lit::Int)
62
        }
63
1
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
64
2
            if unwrap_expr::<i32>(b)? == 0 {
65
2
                return None;
66
            }
67
            bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int)
68
        }
69
        Expr::UnsafeMod(_, a, b) | Expr::SafeMod(_, a, b) => {
70
            if unwrap_expr::<i32>(b)? == 0 {
71
                return None;
72
            }
73
            bin_op::<i32, i32>(|a, b| a % b, a, b).map(Lit::Int)
74
        }
75
        Expr::DivEqUndefZero(_, a, b, c) => {
76
            let a = unwrap_atom::<i32>(a)?;
77
            let b = unwrap_atom::<i32>(b)?;
78
            let c = unwrap_atom::<i32>(c)?;
79

            
80
            if b == 0 {
81
                return None;
82
            }
83

            
84
            Some(Lit::Bool(a / b == c))
85
        }
86
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
87

            
88
        Expr::Reify(_, a, b) => bin_op::<bool, bool>(|a, b| a == b, a, b).map(Lit::Bool),
89
        _ => {
90
            warn!(%expr,"Unimplemented constant eval");
91
            None
92
        }
93
    }
94
13
}
95

            
96
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
97
where
98
    T: TryFrom<Lit>,
99
{
100
    let a = unwrap_expr::<T>(a)?;
101
    Some(f(a))
102
}
103

            
104
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
105
where
106
    T: TryFrom<Lit>,
107
{
108
    let a = unwrap_expr::<T>(a)?;
109
    let b = unwrap_expr::<T>(b)?;
110
    Some(f(a, b))
111
}
112

            
113
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
114
where
115
    T: TryFrom<Lit>,
116
{
117
    let a = unwrap_expr::<T>(a)?;
118
    let b = unwrap_expr::<T>(b)?;
119
    let c = unwrap_expr::<T>(c)?;
120
    Some(f(a, b, c))
121
}
122

            
123
9
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
124
9
where
125
9
    T: TryFrom<Lit>,
126
9
{
127
9
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
128
9
    Some(f(a))
129
9
}
130

            
131
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
132
where
133
    T: TryFrom<Lit>,
134
{
135
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
136
    f(a)
137
}
138

            
139
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
140
where
141
    T: TryFrom<Lit>,
142
{
143
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
144
    let b = unwrap_expr::<T>(b)?;
145
    Some(f(a, b))
146
}
147

            
148
2
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
149
2
    let c = eval_constant(expr)?;
150
2
    TryInto::<T>::try_into(c).ok()
151
2
}
152

            
153
fn unwrap_atom<T: TryFrom<Lit>>(atom: &Atom) -> Option<T> {
154
    let Atom::Literal(c) = atom else {
155
        return None;
156
    };
157
    TryInto::<T>::try_into(c.clone()).ok()
158
}
159

            
160
#[cfg(test)]
mod tests {
    use conjure_core::ast::{Atom, Expression, Literal};

    /// Builds a boxed atomic integer literal with default metadata.
    fn int(value: i32) -> Box<Expression> {
        Box::new(Expression::Atomic(
            Default::default(),
            Atom::Literal(Literal::Int(value)),
        ))
    }

    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&expr), None);
    }

    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&expr), None);
    }

    #[test]
    fn mod_by_zero() {
        let unsafe_mod = Expression::UnsafeMod(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&unsafe_mod), None);

        let safe_mod = Expression::SafeMod(Default::default(), int(1), int(0));
        assert_eq!(super::eval_constant(&safe_mod), None);
    }

    #[test]
    fn div_positive_operands() {
        // Both operands positive: truncating and floor division agree.
        let expr = Expression::UnsafeDiv(Default::default(), int(7), int(2));
        assert_eq!(super::eval_constant(&expr), Some(Literal::Int(3)));
    }
}