conjure_core/rules/normalisers/
weighted_sums.rs

//! Normalising rules for weighted sums.
//!
//! Weighted sums are sums of the form c1*v1 + c2*v2 + ..., where each cx is a literal and each vx
//! is a variable reference.
5
6use std::collections::BTreeMap;
7
8use conjure_macros::register_rule;
9
10use crate::ast::{Atom, Expression as Expr, Literal as Lit, Name, SymbolTable};
11use crate::metadata::Metadata;
12use crate::rule_engine::ApplicationError::RuleNotApplicable;
13use crate::rule_engine::{ApplicationResult, Reduction};
14
15/// Collects like terms in a weighted sum.
16///
17/// For some variable v, and constants cx,
18///
19/// ```plain
20/// (c1 * v)  + .. + (c2 * v) + ... ~> ((c1 + c2) * v) + ...
21/// ```
22#[register_rule(("Base", 8400))]
23fn collect_like_terms(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24    let Expr::Sum(meta, exprs) = expr.clone() else {
25        return Err(RuleNotApplicable);
26    };
27
28    // Store:
29    //  * map variable -> coefficient for weighted sum terms
30    //  * a list of non-weighted sum terms
31
32    let mut weighted_terms: BTreeMap<Name, i32> = BTreeMap::new();
33    let mut other_terms: Vec<Expr> = Vec::new();
34
35    // Assume valid terms are in form constant*variable, as reorder_product and partial_eval
36    // should've already ran.
37
38    for expr in exprs.clone() {
39        match expr.clone() {
40            Expr::Product(_, exprs2) => {
41                match exprs2.as_slice() {
42                    // -c*v
43                    [Expr::Atomic(_, Atom::Reference(name)), Expr::Neg(_, e3)] => {
44                        if let Expr::Atomic(_, Atom::Literal(Lit::Int(l))) = **e3 {
45                            weighted_terms
46                                .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) - l);
47                        } else {
48                            other_terms.push(expr);
49                        };
50                    }
51
52                    // c*v
53                    [Expr::Atomic(_, Atom::Reference(name)), Expr::Atomic(_, Atom::Literal(Lit::Int(l)))] =>
54                    {
55                        weighted_terms
56                            .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) + l);
57                    }
58
59                    // invalid
60                    _ => {
61                        other_terms.push(expr);
62                    }
63                }
64            }
65
66            // not a product
67            _ => {
68                other_terms.push(expr);
69            }
70        }
71    }
72
73    // this rule has done nothing.
74    if weighted_terms.is_empty() {
75        return Err(RuleNotApplicable);
76    }
77
78    let mut new_exprs = vec![];
79    for (name, coefficient) in weighted_terms {
80        new_exprs.push(Expr::Product(
81            Metadata::new(),
82            vec![
83                Expr::Atomic(Metadata::new(), name.into()),
84                Expr::Atomic(Metadata::new(), Atom::Literal(Lit::Int(coefficient))),
85            ],
86        ));
87    }
88
89    new_exprs.extend(other_terms);
90
91    // no change
92    if new_exprs.len() == exprs.len() {
93        return Err(RuleNotApplicable);
94    }
95
96    Ok(Reduction::pure(Expr::Sum(meta, new_exprs)))
97}