conjure_core/rules/normalisers/weighted_sums.rs

//! Normalising rules for weighted sums.
//!
//! Weighted sums are sums of the form `c1*v1 + c2*v2 + ...`, where each `cx` is a literal and
//! each `vx` is a variable reference.
5
6use std::collections::BTreeMap;
7
8use conjure_macros::register_rule;
9
10use crate::ast::{Atom, Expression as Expr, Literal as Lit, Name, SymbolTable};
11use crate::into_matrix_expr;
12use crate::metadata::Metadata;
13use crate::rule_engine::ApplicationError::RuleNotApplicable;
14use crate::rule_engine::{ApplicationResult, Reduction};
15
16/// Collects like terms in a weighted sum.
17///
18/// For some variable v, and constants cx,
19///
20/// ```plain
21/// (c1 * v)  + .. + (c2 * v) + ... ~> ((c1 + c2) * v) + ...
22/// ```
23#[register_rule(("Base", 8400))]
24fn collect_like_terms(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
25    let Expr::Sum(meta, exprs) = expr.clone() else {
26        return Err(RuleNotApplicable);
27    };
28
29    let exprs = exprs.unwrap_list().ok_or(RuleNotApplicable)?;
30
31    // Store:
32    //  * map variable -> coefficient for weighted sum terms
33    //  * a list of non-weighted sum terms
34
35    let mut weighted_terms: BTreeMap<Name, i32> = BTreeMap::new();
36    let mut other_terms: Vec<Expr> = Vec::new();
37
38    // Assume valid terms are in form constant*variable, as reorder_product and partial_eval
39    // should've already ran.
40
41    for expr in exprs.clone() {
42        match expr.clone() {
43            Expr::Product(_, exprs2) => {
44                match exprs2.as_slice() {
45                    // -c*v
46                    [Expr::Atomic(_, Atom::Reference(name)), Expr::Neg(_, e3)] => {
47                        if let Expr::Atomic(_, Atom::Literal(Lit::Int(l))) = **e3 {
48                            weighted_terms
49                                .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) - l);
50                        } else {
51                            other_terms.push(expr);
52                        };
53                    }
54
55                    // c*v
56                    [Expr::Atomic(_, Atom::Reference(name)), Expr::Atomic(_, Atom::Literal(Lit::Int(l)))] =>
57                    {
58                        weighted_terms
59                            .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) + l);
60                    }
61
62                    // invalid
63                    _ => {
64                        other_terms.push(expr);
65                    }
66                }
67            }
68
69            // not a product
70            _ => {
71                other_terms.push(expr);
72            }
73        }
74    }
75
76    // this rule has done nothing.
77    if weighted_terms.is_empty() {
78        return Err(RuleNotApplicable);
79    }
80
81    let mut new_exprs = vec![];
82    for (name, coefficient) in weighted_terms {
83        new_exprs.push(Expr::Product(
84            Metadata::new(),
85            vec![
86                Expr::Atomic(Metadata::new(), name.into()),
87                Expr::Atomic(Metadata::new(), Atom::Literal(Lit::Int(coefficient))),
88            ],
89        ));
90    }
91
92    new_exprs.extend(other_terms);
93
94    // no change
95    if new_exprs.len() == exprs.len() {
96        return Err(RuleNotApplicable);
97    }
98
99    Ok(Reduction::pure(Expr::Sum(
100        meta,
101        Box::new(into_matrix_expr![new_exprs]),
102    )))
103}