1
use conjure_core::{ast::Constant as Const, ast::Expression as Expr, rule::RuleApplicationError};
2
use conjure_rules::register_rule;
3

            
4
#[register_rule]
5
fn apply_eval_constant(expr: &Expr) -> Result<Expr, RuleApplicationError> {
6
    if expr.is_constant() {
7
        return Err(RuleApplicationError::RuleNotApplicable);
8
    }
9
    let res = eval_constant(expr)
10
        .map(Expr::Constant)
11
        .ok_or(RuleApplicationError::RuleNotApplicable);
12
    res
13
}
14

            
15
/// Simplify an expression to a constant if possible
16
/// Returns:
17
/// `None` if the expression cannot be simplified to a constant (e.g. if it contains a variable)
18
/// `Some(Const)` if the expression can be simplified to a constant
19
pub fn eval_constant(expr: &Expr) -> Option<Const> {
20
    match expr {
21
        Expr::Constant(c) => Some(c.clone()),
22
        Expr::Reference(_) => None,
23

            
24
        Expr::Eq(a, b) => bin_op::<i32, bool>(|a, b| a == b, a, b)
25
            .or_else(|| bin_op::<bool, bool>(|a, b| a == b, a, b))
26
            .map(Const::Bool),
27
        Expr::Neq(a, b) => bin_op::<i32, bool>(|a, b| a != b, a, b).map(Const::Bool),
28
        Expr::Lt(a, b) => bin_op::<i32, bool>(|a, b| a < b, a, b).map(Const::Bool),
29
        Expr::Gt(a, b) => bin_op::<i32, bool>(|a, b| a > b, a, b).map(Const::Bool),
30
        Expr::Leq(a, b) => bin_op::<i32, bool>(|a, b| a <= b, a, b).map(Const::Bool),
31
        Expr::Geq(a, b) => bin_op::<i32, bool>(|a, b| a >= b, a, b).map(Const::Bool),
32

            
33
        Expr::Not(expr) => un_op::<bool, bool>(|e| !e, expr).map(Const::Bool),
34

            
35
        Expr::And(exprs) => vec_op::<bool, bool>(|e| e.iter().all(|&e| e), exprs).map(Const::Bool),
36
        Expr::Or(exprs) => vec_op::<bool, bool>(|e| e.iter().any(|&e| e), exprs).map(Const::Bool),
37

            
38
        Expr::Sum(exprs) => vec_op::<i32, i32>(|e| e.iter().sum(), exprs).map(Const::Int),
39

            
40
        Expr::Ineq(a, b, c) => {
41
            tern_op::<i32, bool>(|a, b, c| a <= (b + c), a, b, c).map(Const::Bool)
42
        }
43

            
44
        Expr::SumGeq(exprs, a) => {
45
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() >= a, exprs, a).map(Const::Bool)
46
        }
47
        Expr::SumLeq(exprs, a) => {
48
            flat_op::<i32, bool>(|e, a| e.iter().sum::<i32>() <= a, exprs, a).map(Const::Bool)
49
        }
50
        _ => {
51
            println!("WARNING: Unimplemented constant eval: {:?}", expr);
52
            None
53
        }
54
    }
55
}
56

            
57
fn un_op<T, A>(f: fn(T) -> A, a: &Expr) -> Option<A>
58
where
59
    T: TryFrom<Const>,
60
{
61
    let a = unwrap_expr::<T>(a)?;
62
    Some(f(a))
63
}
64

            
65
fn bin_op<T, A>(f: fn(T, T) -> A, a: &Expr, b: &Expr) -> Option<A>
66
where
67
    T: TryFrom<Const>,
68
{
69
    let a = unwrap_expr::<T>(a)?;
70
    let b = unwrap_expr::<T>(b)?;
71
    Some(f(a, b))
72
}
73

            
74
fn tern_op<T, A>(f: fn(T, T, T) -> A, a: &Expr, b: &Expr, c: &Expr) -> Option<A>
75
where
76
    T: TryFrom<Const>,
77
{
78
    let a = unwrap_expr::<T>(a)?;
79
    let b = unwrap_expr::<T>(b)?;
80
    let c = unwrap_expr::<T>(c)?;
81
    Some(f(a, b, c))
82
}
83

            
84
fn vec_op<T, A>(f: fn(Vec<T>) -> A, a: &Vec<Expr>) -> Option<A>
85
where
86
    T: TryFrom<Const>,
87
{
88
    let a = a
89
        .iter()
90
        .map(unwrap_expr)
91
        .into_iter()
92
        .collect::<Option<Vec<T>>>()?;
93
    Some(f(a))
94
}
95

            
96
fn flat_op<T, A>(f: fn(Vec<T>, T) -> A, a: &Vec<Expr>, b: &Expr) -> Option<A>
97
where
98
    T: TryFrom<Const>,
99
{
100
    let a = a
101
        .iter()
102
        .map(unwrap_expr)
103
        .into_iter()
104
        .collect::<Option<Vec<T>>>()?;
105
    let b = unwrap_expr::<T>(b)?;
106
    Some(f(a, b))
107
}
108

            
109
fn unwrap_expr<T: TryFrom<Const>>(expr: &Expr) -> Option<T> {
110
    let c = eval_constant(expr)?;
111
    TryInto::<T>::try_into(c).ok()
112
}