1
use conjure_core::ast::{Expression as Expr, Factor, Literal as Lit};
2
use conjure_core::metadata::Metadata;
3
use conjure_core::rule_engine::{
4
    register_rule, register_rule_set, ApplicationError, ApplicationResult, Reduction,
5
};
6
use conjure_core::Model;
7
use tracing::warn;
8

            
9
// Declares the "Constant" rule set; `apply_eval_constant` below registers into it.
// NOTE(review): the meaning of `100` (presumably the set's priority) and `()` (presumably
// the set's dependencies) is defined by `register_rule_set!` — confirm there.
register_rule_set!("Constant", 100, ());
10

            
11
#[register_rule(("Constant", 9001))]
12
8222390
fn apply_eval_constant(expr: &Expr, _: &Model) -> ApplicationResult {
13
6135147
    if let Expr::FactorE(_, Factor::Literal(_)) = expr {
14
1624588
        return Err(ApplicationError::RuleNotApplicable);
15
6597802
    }
16
6597802
    eval_constant(expr)
17
6597802
        .map(|c| Reduction::pure(Expr::FactorE(Metadata::new(), Factor::Literal(c))))
18
6597802
        .ok_or(ApplicationError::RuleNotApplicable)
19
8222390
}
20

            
21
/// Simplify an expression to a constant if possible
22
/// Returns:
23
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
24
/// `Some(Const)` if the expression can be simplified to a constant
25
9179919
pub fn eval_constant(expr: &Expr) -> Option<Lit> {
26
6601357
    match expr {
27
1583
        Expr::FactorE(_, Factor::Literal(c)) => Some(c.clone()),
28
6599774
        Expr::FactorE(_, Factor::Reference(c)) => None,
29
9452
        Expr::Eq(_, a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
30
9452
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
31
9452
            .map(Lit::Bool),
32
2159
        Expr::Neq(_, a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Lit::Bool),
33
4726
        Expr::Lt(_, a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Lit::Bool),
34
        Expr::Gt(_, a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Lit::Bool),
35
357
        Expr::Leq(_, a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Lit::Bool),
36
272
        Expr::Geq(_, a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Lit::Bool),
37

            
38
697
        Expr::Not(_, expr) => un_op::<bool, bool>(|e| !e, expr).map(Lit::Bool),
39

            
40
23800
        Expr::And(_, exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Lit::Bool),
41
446556
        Expr::Or(_, exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Lit::Bool),
42

            
43
14926
        Expr::Sum(_, exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Lit::Int),
44

            
45
627861
        Expr::Ineq(_, a, b, c) => {
46
627861
            tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Lit::Bool)
47
        }
48

            
49
740401
        Expr::SumGeq(_, exprs, a) => {
50
740401
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Lit::Bool)
51
        }
52
693498
        Expr::SumLeq(_, exprs, a) => {
53
693498
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Lit::Bool)
54
        }
55
        // Expr::Div(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
56
        // Expr::SafeDiv(_, a, b) => bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int),
57
374
        Expr::Min(_, exprs) => {
58
374
            opt_vec_op::<i32, i32>(|e| e.iter().min().copied(), exprs).map(Lit::Int)
59
        }
60
204
        Expr::Max(_, exprs) => {
61
204
            opt_vec_op::<i32, i32>(|e| e.iter().max().copied(), exprs).map(Lit::Int)
62
        }
63
3469
        Expr::UnsafeDiv(_, a, b) | Expr::SafeDiv(_, a, b) => {
64
5034
            if unwrap_expr::<i32>(b)? == 0 {
65
2
                return None;
66
34
            }
67
34
            bin_op::<i32, i32>(|a, b| a / b, a, b).map(Lit::Int)
68
        }
69
612
        Expr::DivEq(_, a, b, c) => {
70
612
            let a = unwrap_factor::<i32>(a)?;
71
34
            let b = unwrap_factor::<i32>(b)?;
72
            let c = unwrap_factor::<i32>(c)?;
73

            
74
            if b == 0 {
75
                return None;
76
            }
77

            
78
            Some(Lit::Bool(a / b == c))
79
        }
80
952
        Expr::Bubble(_, a, b) => bin_op::<bool, bool>(|a, b| a && b, a, b).map(Lit::Bool),
81

            
82
        Expr::Reify(_, a, b) => bin_op::<bool, bool>(|a, b| a == b, a, b).map(Lit::Bool),
83
        _ => {
84
6681
            warn!(%expr,"Unimplemented constant eval");
85
6681
            None
86
        }
87
    }
88
9179919
}
89

            
90
697
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
91
697
where
92
697
    T: TryFrom<Lit>,
93
697
{
94
697
    let a = unwrap_expr::<T>(a)?;
95
    Some(f(a))
96
697
}
97

            
98
27387
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
99
27387
where
100
27387
    T: TryFrom<Lit>,
101
27387
{
102
27387
    let a = unwrap_expr::<T>(a)?;
103
187
    let b = unwrap_expr::<T>(b)?;
104
68
    Some(f(a, b))
105
27387
}
106

            
107
627861
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
108
627861
where
109
627861
    T: TryFrom<Lit>,
110
627861
{
111
627861
    let a = unwrap_expr::<T>(a)?;
112
816
    let b = unwrap_expr::<T>(b)?;
113
    let c = unwrap_expr::<T>(c)?;
114
    Some(f(a, b, c))
115
627861
}
116

            
117
485282
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &[Expr]) -> Option<A>
118
485282
where
119
485282
    T: TryFrom<Lit>,
120
485282
{
121
485282
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
122
85
    Some(f(a))
123
485282
}
124

            
125
578
fn opt_vec_op<T, A>(f: fn(Vec<T>) -> Option<A>, a: &[Expr]) -> Option<A>
126
578
where
127
578
    T: TryFrom<Lit>,
128
578
{
129
578
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
130
51
    f(a)
131
578
}
132

            
133
1433899
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &[Expr], b: &Expr) -> Option<A>
134
1433899
where
135
1433899
    T: TryFrom<Lit>,
136
1433899
{
137
1433899
    let a = a.iter().map(unwrap_expr).collect::<Option<Vec<T>>>()?;
138
    let b = unwrap_expr::<T>(b)?;
139
    Some(f(a, b))
140
1433899
}
141

            
142
2581928
fn unwrap_expr<T: TryFrom<Lit>>(expr: &Expr) -> Option<T> {
143
2581928
    let c = eval_constant(expr)?;
144
1600
    TryInto::<T>::try_into(c).ok()
145
2581928
}
146

            
147
646
fn unwrap_factor<T: TryFrom<Lit>>(factor: &Factor) -> Option<T> {
148
646
    let Factor::Literal(c) = factor else {
149
612
        return None;
150
    };
151
34
    TryInto::<T>::try_into(c.clone()).ok()
152
646
}
153

            
154
#[cfg(test)]
mod tests {
    use conjure_core::ast::{Expression, Factor, Literal};

    /// Builds a boxed integer-literal expression with default metadata.
    fn int_lit(value: i32) -> Box<Expression> {
        Box::new(Expression::FactorE(
            Default::default(),
            Factor::Literal(Literal::Int(value)),
        ))
    }

    /// `1 / 0` with unchecked division must not fold (and must not panic).
    #[test]
    fn div_by_zero() {
        let expr = Expression::UnsafeDiv(Default::default(), int_lit(1), int_lit(0));
        assert_eq!(super::eval_constant(&expr), None);
    }

    /// `1 / 0` with safe division must not fold either.
    #[test]
    fn safediv_by_zero() {
        let expr = Expression::SafeDiv(Default::default(), int_lit(1), int_lit(0));
        assert_eq!(super::eval_constant(&expr), None);
    }
}