conjure_core/rules/normalisers/
weighted_sums.rs
1use std::collections::BTreeMap;
7
8use conjure_macros::register_rule;
9
10use crate::ast::{Atom, Expression as Expr, Literal as Lit, Name, SymbolTable};
11use crate::into_matrix_expr;
12use crate::metadata::Metadata;
13use crate::rule_engine::ApplicationError::RuleNotApplicable;
14use crate::rule_engine::{ApplicationResult, Reduction};
15
16#[register_rule(("Base", 8400))]
24fn collect_like_terms(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
25 let Expr::Sum(meta, exprs) = expr.clone() else {
26 return Err(RuleNotApplicable);
27 };
28
29 let exprs = exprs.unwrap_list().ok_or(RuleNotApplicable)?;
30
31 let mut weighted_terms: BTreeMap<Name, i32> = BTreeMap::new();
36 let mut other_terms: Vec<Expr> = Vec::new();
37
38 for expr in exprs.clone() {
42 match expr.clone() {
43 Expr::Product(_, exprs2) => {
44 match exprs2.as_slice() {
45 [Expr::Atomic(_, Atom::Reference(name)), Expr::Neg(_, e3)] => {
47 if let Expr::Atomic(_, Atom::Literal(Lit::Int(l))) = **e3 {
48 weighted_terms
49 .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) - l);
50 } else {
51 other_terms.push(expr);
52 };
53 }
54
55 [Expr::Atomic(_, Atom::Reference(name)), Expr::Atomic(_, Atom::Literal(Lit::Int(l)))] =>
57 {
58 weighted_terms
59 .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) + l);
60 }
61
62 _ => {
64 other_terms.push(expr);
65 }
66 }
67 }
68
69 _ => {
71 other_terms.push(expr);
72 }
73 }
74 }
75
76 if weighted_terms.is_empty() {
78 return Err(RuleNotApplicable);
79 }
80
81 let mut new_exprs = vec![];
82 for (name, coefficient) in weighted_terms {
83 new_exprs.push(Expr::Product(
84 Metadata::new(),
85 vec![
86 Expr::Atomic(Metadata::new(), name.into()),
87 Expr::Atomic(Metadata::new(), Atom::Literal(Lit::Int(coefficient))),
88 ],
89 ));
90 }
91
92 new_exprs.extend(other_terms);
93
94 if new_exprs.len() == exprs.len() {
96 return Err(RuleNotApplicable);
97 }
98
99 Ok(Reduction::pure(Expr::Sum(
100 meta,
101 Box::new(into_matrix_expr![new_exprs]),
102 )))
103}