1
//! Normalising rules for weighted sums.
2
//!
3
//! Weighted sums are sums of the form c1*v1 + c2*v2 + ..., where the cx are literals and the vx are variable
4
//! references.
5

            
6
use std::collections::BTreeMap;
7

            
8
use conjure_cp::ast::Reference;
9
use conjure_cp::essence_expr;
10
use conjure_cp::rule_engine::register_rule;
11
use conjure_cp::{
12
    ast::Metadata,
13
    ast::{Atom, Expression as Expr, Literal as Lit, Moo, SymbolTable},
14
    into_matrix_expr,
15
    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
16
};
17

            
18
/// Collects like terms in a weighted sum.
19
///
20
/// For some variable v, and constants cx,
21
///
22
/// ```plain
23
/// (c1 * v)  + .. + (c2 * v) + ... ~> ((c1 + c2) * v) + ...
24
/// ```
25
#[register_rule(("Base", 8400))]
26
103283
fn collect_like_terms(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
27
103283
    let Expr::Sum(meta, exprs) = expr else {
28
100340
        return Err(RuleNotApplicable);
29
    };
30
2943
    let exprs = exprs.unwrap_list().ok_or(RuleNotApplicable)?;
31

            
32
    // Store:
33
    //  * map variable -> coefficient for weighted sum terms
34
    //  * a list of non-weighted sum terms
35

            
36
    #[allow(clippy::mutable_key_type)]
37
1557
    let mut weighted_terms: BTreeMap<Reference, i32> = BTreeMap::new();
38
1557
    let mut other_terms: Vec<Expr> = Vec::new();
39

            
40
    // Assume valid terms are in form constant*variable, as reorder_product and partial_eval
41
    // should've already ran.
42

            
43
3519
    for expr in exprs.iter() {
44
3519
        match expr {
45
300
            Expr::Product(_, exprs2) => {
46
300
                match exprs2.unwrap_list().ok_or(RuleNotApplicable)?.as_slice() {
47
                    // todo (gs248) It would be nice to generate these destructures by macro, like `essence_expr!` but in reverse
48
                    // -c*v
49
                    [Expr::Atomic(_, Atom::Reference(re)), Expr::Neg(_, e3)] => {
50
                        if let Expr::Atomic(_, Atom::Literal(Lit::Int(l))) = **e3 {
51
                            let curr_weight = weighted_terms.get(re).unwrap_or(&0);
52
                            weighted_terms.insert(re.clone(), curr_weight - l);
53
                        } else {
54
                            other_terms.push(expr.clone());
55
                        };
56
                    }
57

            
58
                    // c*v
59
                    [
60
                        Expr::Atomic(_, Atom::Reference(re)),
61
                        Expr::Atomic(_, Atom::Literal(Lit::Int(l))),
62
                    ] => {
63
                        let curr_weight = weighted_terms.get(re).unwrap_or(&0);
64
                        weighted_terms.insert(re.clone(), curr_weight + l);
65
                    }
66

            
67
                    // invalid
68
243
                    _ => {
69
243
                        other_terms.push(expr.clone());
70
243
                    }
71
                }
72
            }
73

            
74
            // not a product
75
3219
            _ => {
76
3219
                other_terms.push(expr.clone());
77
3219
            }
78
        }
79
    }
80

            
81
    // this rule has done nothing.
82
1500
    if weighted_terms.is_empty() {
83
1500
        return Err(RuleNotApplicable);
84
    }
85

            
86
    let mut new_exprs = vec![];
87
    for (re, coefficient) in weighted_terms {
88
        let atom = Expr::Atomic(Metadata::new(), Atom::Reference(re));
89
        new_exprs.push(essence_expr!(&atom * &coefficient));
90
    }
91

            
92
    new_exprs.extend(other_terms);
93

            
94
    // no change
95
    if new_exprs.len() == exprs.len() {
96
        return Err(RuleNotApplicable);
97
    }
98

            
99
    Ok(Reduction::pure(Expr::Sum(
100
        meta.clone(),
101
        Moo::new(into_matrix_expr![new_exprs]),
102
    )))
103
103283
}