1
//! Normalising rules for weighted sums.
2
//!
3
//! Weighted sums are sums of the form c1*v1 + c2*v2 + ..., where the cx are literals and the vx are variable
4
//! references.
5

            
6
use std::collections::BTreeMap;
7

            
8
use conjure_cp::essence_expr;
9
use conjure_cp::rule_engine::register_rule;
10
use conjure_cp::{
11
    ast::Metadata,
12
    ast::{Atom, Expression as Expr, Literal as Lit, Moo, Name, SymbolTable},
13
    into_matrix_expr,
14
    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
15
};
16

            
17
/// Collects like terms in a weighted sum.
18
///
19
/// For some variable v, and constants cx,
20
///
21
/// ```plain
22
/// (c1 * v)  + .. + (c2 * v) + ... ~> ((c1 + c2) * v) + ...
23
/// ```
24
#[register_rule(("Base", 8400))]
25
fn collect_like_terms(expr: &Expr, st: &SymbolTable) -> ApplicationResult {
26
    let Expr::Sum(meta, exprs) = expr.clone() else {
27
        return Err(RuleNotApplicable);
28
    };
29

            
30
    let exprs = Moo::unwrap_or_clone(exprs)
31
        .unwrap_list()
32
        .ok_or(RuleNotApplicable)?;
33

            
34
    // Store:
35
    //  * map variable -> coefficient for weighted sum terms
36
    //  * a list of non-weighted sum terms
37

            
38
    let mut weighted_terms: BTreeMap<Name, i32> = BTreeMap::new();
39
    let mut other_terms: Vec<Expr> = Vec::new();
40

            
41
    // Assume valid terms are in form constant*variable, as reorder_product and partial_eval
42
    // should've already ran.
43

            
44
    for expr in exprs.clone() {
45
        match expr.clone() {
46
            Expr::Product(_, exprs2) => {
47
                match Moo::unwrap_or_clone(exprs2)
48
                    .unwrap_list()
49
                    .ok_or(RuleNotApplicable)?
50
                    .as_slice()
51
                {
52
                    // todo (gs248) It would be nice to generate these destructures by macro, like `essence_expr!` but in reverse
53
                    // -c*v
54
                    [Expr::Atomic(_, Atom::Reference(decl)), Expr::Neg(_, e3)] => {
55
                        let name: &Name = &decl.name();
56
                        if let Expr::Atomic(_, Atom::Literal(Lit::Int(l))) = **e3 {
57
                            weighted_terms
58
                                .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) - l);
59
                        } else {
60
                            other_terms.push(expr);
61
                        };
62
                    }
63

            
64
                    // c*v
65
                    [
66
                        Expr::Atomic(_, Atom::Reference(decl)),
67
                        Expr::Atomic(_, Atom::Literal(Lit::Int(l))),
68
                    ] => {
69
                        let name: &Name = &decl.name();
70
                        weighted_terms
71
                            .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) + l);
72
                    }
73

            
74
                    // invalid
75
                    _ => {
76
                        other_terms.push(expr);
77
                    }
78
                }
79
            }
80

            
81
            // not a product
82
            _ => {
83
                other_terms.push(expr);
84
            }
85
        }
86
    }
87

            
88
    // this rule has done nothing.
89
    if weighted_terms.is_empty() {
90
        return Err(RuleNotApplicable);
91
    }
92

            
93
    let mut new_exprs = vec![];
94
    for (name, coefficient) in weighted_terms {
95
        let decl = st.lookup(&name).ok_or(RuleNotApplicable)?;
96
        let atom = Expr::Atomic(
97
            Metadata::new(),
98
            Atom::Reference(conjure_cp::ast::Reference::new(decl)),
99
        );
100
        new_exprs.push(essence_expr!(&atom * &coefficient));
101
    }
102

            
103
    new_exprs.extend(other_terms);
104

            
105
    // no change
106
    if new_exprs.len() == exprs.len() {
107
        return Err(RuleNotApplicable);
108
    }
109

            
110
    Ok(Reduction::pure(Expr::Sum(
111
        meta,
112
        Moo::new(into_matrix_expr![new_exprs]),
113
    )))
114
}