conjure_core/rules/normalisers/
associative_commutative.rs

1//! Generic normalising rules for associative-commutative operators.
2
3use std::collections::VecDeque;
4use std::mem::Discriminant;
5
6use conjure_core::ast::Expression as Expr;
7use conjure_core::rule_engine::{
8    register_rule, ApplicationError::RuleNotApplicable, ApplicationResult, Reduction,
9};
10use uniplate::Biplate;
11
12use crate::ast::SymbolTable;
13
14/// Normalises associative_commutative operations.
15///
16/// For now, this just removes nested expressions by associativity.
17///
18/// ```text
19/// v(v(a,b,...),c,d,...) ~> v(a,b,c,d)
20/// where v is an AC vector operator
21/// ```
22#[register_rule(("Base", 8900))]
23fn normalise_associative_commutative(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24    if !expr.is_associative_commutative_operator() {
25        return Err(RuleNotApplicable);
26    }
27
28    // remove nesting deeply
29    fn recurse_deeply(
30        root_discriminant: Discriminant<Expr>,
31        expr: Expr,
32        changed: &mut bool,
33    ) -> Vec<Expr> {
34        // if expr a different expression type, stop recursing
35        if std::mem::discriminant(&expr) != root_discriminant {
36            return vec![expr];
37        }
38
39        let child_vecs: VecDeque<Vec<Expr>> = expr.children_bi();
40
41        // empty expression
42        if child_vecs.is_empty() {
43            return vec![expr];
44        }
45
46        // go deeper
47        let children = child_vecs[0].clone();
48        let old_len = children.len();
49
50        let new_children = children
51            .into_iter()
52            .flat_map(|child| recurse_deeply(root_discriminant, child, changed))
53            .collect::<Vec<_>>();
54        if new_children.len() != old_len {
55            *changed = true;
56        }
57
58        new_children
59    }
60
61    let child_vecs: VecDeque<Vec<Expr>> = expr.children_bi();
62    if child_vecs.is_empty() {
63        return Err(RuleNotApplicable);
64    }
65
66    let mut changed = false;
67    let new_children = recurse_deeply(std::mem::discriminant(expr), expr.clone(), &mut changed);
68
69    if !changed {
70        return Err(RuleNotApplicable);
71    }
72
73    let new_expr = expr.with_children_bi(vec![new_children].into());
74
75    Ok(Reduction::pure(new_expr))
76}