conjure_core/rules/normalisers/
weighted_sums.rs1use std::collections::BTreeMap;
7
8use conjure_macros::register_rule;
9
10use crate::ast::{Atom, Expression as Expr, Literal as Lit, Name, SymbolTable};
11use crate::metadata::Metadata;
12use crate::rule_engine::ApplicationError::RuleNotApplicable;
13use crate::rule_engine::{ApplicationResult, Reduction};
14
15#[register_rule(("Base", 8400))]
23fn collect_like_terms(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24 let Expr::Sum(meta, exprs) = expr.clone() else {
25 return Err(RuleNotApplicable);
26 };
27
28 let mut weighted_terms: BTreeMap<Name, i32> = BTreeMap::new();
33 let mut other_terms: Vec<Expr> = Vec::new();
34
35 for expr in exprs.clone() {
39 match expr.clone() {
40 Expr::Product(_, exprs2) => {
41 match exprs2.as_slice() {
42 [Expr::Atomic(_, Atom::Reference(name)), Expr::Neg(_, e3)] => {
44 if let Expr::Atomic(_, Atom::Literal(Lit::Int(l))) = **e3 {
45 weighted_terms
46 .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) - l);
47 } else {
48 other_terms.push(expr);
49 };
50 }
51
52 [Expr::Atomic(_, Atom::Reference(name)), Expr::Atomic(_, Atom::Literal(Lit::Int(l)))] =>
54 {
55 weighted_terms
56 .insert(name.clone(), weighted_terms.get(name).unwrap_or(&0) + l);
57 }
58
59 _ => {
61 other_terms.push(expr);
62 }
63 }
64 }
65
66 _ => {
68 other_terms.push(expr);
69 }
70 }
71 }
72
73 if weighted_terms.is_empty() {
75 return Err(RuleNotApplicable);
76 }
77
78 let mut new_exprs = vec![];
79 for (name, coefficient) in weighted_terms {
80 new_exprs.push(Expr::Product(
81 Metadata::new(),
82 vec![
83 Expr::Atomic(Metadata::new(), name.into()),
84 Expr::Atomic(Metadata::new(), Atom::Literal(Lit::Int(coefficient))),
85 ],
86 ));
87 }
88
89 new_exprs.extend(other_terms);
90
91 if new_exprs.len() == exprs.len() {
93 return Err(RuleNotApplicable);
94 }
95
96 Ok(Reduction::pure(Expr::Sum(meta, new_exprs)))
97}