conjure_core/rules/normalisers/
product.rs

1//! Normalising rules for `Product`
2
3use std::iter;
4
5use conjure_macros::register_rule;
6
7use crate::ast::{Atom, Expression as Expr, Literal as Lit, SymbolTable};
8use crate::metadata::Metadata;
9use crate::rule_engine::ApplicationError::RuleNotApplicable;
10use crate::rule_engine::{ApplicationResult, Reduction};
11
12/// Reorders a product expression.
13///
14/// The resulting product will have the following order:
15///
16/// 1. Constant coefficients
17/// 2. Variables
18/// 3. Compound terms
19///
20/// The order of items within each category is undefined.
21///
22/// # Justification
23///
24/// Having a canonical ordering here is helpful in identifying weighted sums: 2x + 3y + 4d + ....
25#[register_rule(("Base", 8800))]
26fn reorder_product(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
27    let Expr::Product(meta, exprs) = expr.clone() else {
28        return Err(RuleNotApplicable);
29    };
30
31    let mut constant_coefficients: Vec<Expr> = vec![];
32    let mut variables: Vec<Expr> = vec![];
33    let mut compound_exprs: Vec<Expr> = vec![];
34
35    for expr in exprs.clone() {
36        match expr {
37            Expr::Atomic(_, Atom::Literal(_)) => {
38                constant_coefficients.push(expr);
39            }
40            Expr::Atomic(_, Atom::Reference(_)) => {
41                variables.push(expr);
42            }
43
44            // -1 is a constant coefficient
45            Expr::Neg(_, ref expr2) if matches!(**expr2, Expr::Atomic(_, Atom::Literal(_))) => {
46                constant_coefficients.push(expr);
47            }
48
49            // -x === -1 * x
50            Expr::Neg(_, expr2) if matches!(*expr2, Expr::Atomic(_, Atom::Reference(_))) => {
51                constant_coefficients
52                    .push(Expr::Atomic(Metadata::new(), Atom::Literal(Lit::Int(-1))));
53                variables.push(*expr2);
54            }
55
56            _ => {
57                compound_exprs.push(expr);
58            }
59        }
60    }
61
62    constant_coefficients.extend(variables);
63    constant_coefficients.extend(compound_exprs);
64
65    // check if we have actually done anything
66    // TODO: check order before doing all this
67    let mut changed: bool = false;
68    for (e1, e2) in iter::zip(exprs, constant_coefficients.clone()) {
69        if e1 != e2 {
70            changed = true;
71            break;
72        }
73    }
74
75    if !changed {
76        return Err(RuleNotApplicable);
77    }
78
79    Ok(Reduction::pure(Expr::Product(meta, constant_coefficients)))
80}
81
82/// Removes products with a single argument.
83///
84/// ```text
85/// product([a]) ~> a
86/// ```
87///
88#[register_rule(("Base", 8800))]
89fn remove_unit_vector_products(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
90    match expr {
91        Expr::Product(_, exprs) => {
92            if exprs.len() == 1 {
93                return Ok(Reduction::pure(exprs[0].clone()));
94            }
95            Err(RuleNotApplicable)
96        }
97        _ => Err(RuleNotApplicable),
98    }
99}