1
//! Normalising rules for `Product`
2

            
3
use conjure_cp::rule_engine::register_rule;
4

            
5
use conjure_cp::{
6
    ast::Metadata,
7
    ast::{Atom, Expression as Expr, Literal as Lit, Moo, SymbolTable, categories::CategoryOf},
8
    bug, into_matrix_expr,
9
    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
10
};
11

            
12
/// Reorders a product expression.
13
///
14
///  All literal coefficients in the product are folded together, and placed at the start of the
15
///  product.
16
///
17
/// Factors are first sorted by category. Then within each category, references are placed before
18
/// compound expressions.
19
///
20
/// # Justification
21
///
22
/// + Having a canonical ordering here is helpful in identifying weighted sums: 2x + 3y + 4d + ....
23
///
24
/// + Having constant, quantified, given references appear before decision variable references
25
///   means that we do not have to reorder the product again once those references get evaluated to
26
///   literals later on in the rewriting process.
27
///
28
#[register_rule(("Base", 8800))]
29
fn reorder_product(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
30
    let Expr::Product(meta, factors) = expr.clone() else {
31
        return Err(RuleNotApplicable);
32
    };
33

            
34
    let (factors, index_domain) = Moo::unwrap_or_clone(factors)
35
        .unwrap_matrix_unchecked()
36
        .ok_or(RuleNotApplicable)?;
37
    let factors_copy = factors.clone();
38

            
39
    // Order variables by category.
40
    //
41
    // This ensures that references that will eventually become constants are in front of decision
42
    // variables, preventing them from needing to be moved again once they do become constant.
43
    let mut constant_exprs: Vec<Expr> = vec![];
44
    let mut bottom_exprs: Vec<Expr> = vec![];
45
    let mut parameter_exprs: Vec<Expr> = vec![];
46
    let mut quantified_exprs: Vec<Expr> = vec![];
47
    let mut decision_exprs: Vec<Expr> = vec![];
48

            
49
    for expr in factors {
50
        match expr.category_of() {
51
            conjure_cp::ast::categories::Category::Bottom => bottom_exprs.push(expr),
52
            conjure_cp::ast::categories::Category::Constant => constant_exprs.push(expr),
53
            conjure_cp::ast::categories::Category::Parameter => parameter_exprs.push(expr),
54
            conjure_cp::ast::categories::Category::Quantified => quantified_exprs.push(expr),
55
            conjure_cp::ast::categories::Category::Decision => decision_exprs.push(expr),
56
        }
57
    }
58

            
59
    let mut coefficient = 1;
60

            
61
    let (i, constant_exprs) = order_by_complexity(constant_exprs);
62
    coefficient *= i;
63

            
64
    let (i, parameter_exprs) = order_by_complexity(parameter_exprs);
65
    coefficient *= i;
66

            
67
    let (i, quantified_exprs) = order_by_complexity(quantified_exprs);
68
    coefficient *= i;
69

            
70
    let (i, decision_exprs) = order_by_complexity(decision_exprs);
71
    coefficient *= i;
72

            
73
    let (i, bottom_exprs) = order_by_complexity(bottom_exprs);
74
    coefficient *= i;
75

            
76
    let mut factors = if coefficient != 1 {
77
        vec![Expr::Atomic(
78
            Metadata::new(),
79
            Atom::Literal(Lit::Int(coefficient)),
80
        )]
81
    } else {
82
        vec![]
83
    };
84

            
85
    factors.extend(constant_exprs);
86
    factors.extend(bottom_exprs);
87
    factors.extend(parameter_exprs);
88
    factors.extend(quantified_exprs);
89
    factors.extend(decision_exprs);
90

            
91
    // check if we have actually done anything
92
    if factors_copy != factors {
93
        Ok(Reduction::pure(Expr::Product(
94
            meta,
95
            Moo::new(into_matrix_expr!(factors;index_domain)),
96
        )))
97
    } else {
98
        Err(RuleNotApplicable)
99
    }
100
}
101

            
102
// orders factors by "complexity":
103
//
104
// This returns an integer coefficient, created by folding all literals in `factors` into one
105
// value, as well a list of expressions ordered like so:
106
//
107
// 1. references
108
// 2. other expressions
109
fn order_by_complexity(factors: Vec<Expr>) -> (i32, Vec<Expr>) {
110
    // literal coefficient
111
    let mut literal: i32 = 1;
112
    let mut variables: Vec<Expr> = vec![];
113
    let mut compound_exprs: Vec<Expr> = vec![];
114

            
115
    for expr in factors {
116
        match expr {
117
            Expr::Atomic(_, Atom::Literal(lit)) => {
118
                let Lit::Int(i) = lit else {
119
                    bug!("Literals in a product operation should be integer, but got {lit}")
120
                };
121
                literal *= i;
122
            }
123

            
124
            Expr::Atomic(_, Atom::Reference(_)) => {
125
                variables.push(expr);
126
            }
127

            
128
            // -1 * literal
129
            Expr::Neg(_, expr2) if matches!(*expr2, Expr::Atomic(_, Atom::Literal(_))) => {
130
                let Expr::Atomic(_, Atom::Literal(lit)) = &*expr2 else {
131
                    unreachable!()
132
                };
133

            
134
                let Lit::Int(i) = lit else {
135
                    bug!("Literals in a product operation should be integer, but got {lit}")
136
                };
137

            
138
                literal *= -i;
139
            }
140

            
141
            // -1 * x
142
            Expr::Neg(_, expr2) if matches!(&*expr2, Expr::Atomic(_, Atom::Reference(_))) => {
143
                literal *= -1;
144
                variables.push(Moo::unwrap_or_clone(expr2));
145
            }
146

            
147
            // -1 * <expression>
148
            Expr::Neg(_, expr2) => {
149
                literal *= -1;
150
                compound_exprs.push(Moo::unwrap_or_clone(expr2));
151
            }
152
            _ => {
153
                compound_exprs.push(expr);
154
            }
155
        }
156
    }
157
    variables.extend(compound_exprs);
158

            
159
    (literal, variables)
160
}
161

            
162
/// Removes products with a single argument.
163
///
164
/// ```text
165
/// product([a]) ~> a
166
/// ```
167
///
168
#[register_rule(("Base", 8800))]
169
fn remove_unit_vector_products(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
170
    match expr {
171
        Expr::Product(_, mat) => {
172
            let list = (**mat).clone().unwrap_list().ok_or(RuleNotApplicable)?;
173
            if list.len() == 1 {
174
                return Ok(Reduction::pure(list[0].clone()));
175
            }
176
            Err(RuleNotApplicable)
177
        }
178
        _ => Err(RuleNotApplicable),
179
    }
180
}