1
//! Normalising rules for `Product`
2

            
3
use conjure_cp::rule_engine::register_rule;
4

            
5
use conjure_cp::{
6
    ast::Metadata,
7
    ast::{Atom, Expression as Expr, Literal as Lit, Moo, SymbolTable, categories::CategoryOf},
8
    bug, into_matrix_expr,
9
    rule_engine::{ApplicationError::RuleNotApplicable, ApplicationResult, Reduction},
10
};
11

            
12
/// Reorders a product expression.
13
///
14
///  All literal coefficients in the product are folded together, and placed at the start of the
15
///  product.
16
///
17
/// Factors are first sorted by category. Then within each category, references are placed before
18
/// compound expressions.
19
///
20
/// # Justification
21
///
22
/// + Having a canonical ordering here is helpful in identifying weighted sums: 2x + 3y + 4d + ....
23
///
24
/// + Having constant, quantified, given references appear before decision variable references
25
///   means that we do not have to reorder the product again once those references get evaluated to
26
///   literals later on in the rewriting process.
27
///
28
#[register_rule(("Base", 8800))]
29
113726
fn reorder_product(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
30
113726
    let Expr::Product(meta, factors) = expr.clone() else {
31
112847
        return Err(RuleNotApplicable);
32
    };
33

            
34
879
    let (factors, index_domain) = Moo::unwrap_or_clone(factors)
35
879
        .unwrap_matrix_unchecked()
36
879
        .ok_or(RuleNotApplicable)?;
37
876
    let factors_copy = factors.clone();
38

            
39
    // Order variables by category.
40
    //
41
    // This ensures that references that will eventually become constants are in front of decision
42
    // variables, preventing them from needing to be moved again once they do become constant.
43
876
    let mut constant_exprs: Vec<Expr> = vec![];
44
876
    let mut bottom_exprs: Vec<Expr> = vec![];
45
876
    let mut parameter_exprs: Vec<Expr> = vec![];
46
876
    let mut quantified_exprs: Vec<Expr> = vec![];
47
876
    let mut decision_exprs: Vec<Expr> = vec![];
48

            
49
1815
    for expr in factors {
50
1815
        match expr.category_of() {
51
            conjure_cp::ast::categories::Category::Bottom => bottom_exprs.push(expr),
52
642
            conjure_cp::ast::categories::Category::Constant => constant_exprs.push(expr),
53
            conjure_cp::ast::categories::Category::Parameter => parameter_exprs.push(expr),
54
51
            conjure_cp::ast::categories::Category::Quantified => quantified_exprs.push(expr),
55
1122
            conjure_cp::ast::categories::Category::Decision => decision_exprs.push(expr),
56
        }
57
    }
58

            
59
876
    let mut coefficient = 1;
60

            
61
876
    let (i, constant_exprs) = order_by_complexity(constant_exprs);
62
876
    coefficient *= i;
63

            
64
876
    let (i, parameter_exprs) = order_by_complexity(parameter_exprs);
65
876
    coefficient *= i;
66

            
67
876
    let (i, quantified_exprs) = order_by_complexity(quantified_exprs);
68
876
    coefficient *= i;
69

            
70
876
    let (i, decision_exprs) = order_by_complexity(decision_exprs);
71
876
    coefficient *= i;
72

            
73
876
    let (i, bottom_exprs) = order_by_complexity(bottom_exprs);
74
876
    coefficient *= i;
75

            
76
876
    let mut factors = if coefficient != 1 {
77
630
        vec![Expr::Atomic(
78
630
            Metadata::new(),
79
630
            Atom::Literal(Lit::Int(coefficient)),
80
630
        )]
81
    } else {
82
246
        vec![]
83
    };
84

            
85
876
    factors.extend(constant_exprs);
86
876
    factors.extend(bottom_exprs);
87
876
    factors.extend(parameter_exprs);
88
876
    factors.extend(quantified_exprs);
89
876
    factors.extend(decision_exprs);
90

            
91
    // check if we have actually done anything
92
876
    if factors_copy != factors {
93
18
        Ok(Reduction::pure(Expr::Product(
94
18
            meta,
95
18
            Moo::new(into_matrix_expr!(factors;index_domain)),
96
18
        )))
97
    } else {
98
858
        Err(RuleNotApplicable)
99
    }
100
113726
}
101

            
102
// orders factors by "complexity":
103
//
104
// This returns an integer coefficient, created by folding all literals in `factors` into one
105
// value, as well a list of expressions ordered like so:
106
//
107
// 1. references
108
// 2. other expressions
109
4380
fn order_by_complexity(factors: Vec<Expr>) -> (i32, Vec<Expr>) {
110
    // literal coefficient
111
4380
    let mut literal: i32 = 1;
112
4380
    let mut variables: Vec<Expr> = vec![];
113
4380
    let mut compound_exprs: Vec<Expr> = vec![];
114

            
115
4380
    for expr in factors {
116
1434
        match expr {
117
642
            Expr::Atomic(_, Atom::Literal(lit)) => {
118
642
                let Lit::Int(i) = lit else {
119
                    bug!("Literals in a product operation should be integer, but got {lit}")
120
                };
121
642
                literal *= i;
122
            }
123

            
124
792
            Expr::Atomic(_, Atom::Reference(_)) => {
125
792
                variables.push(expr);
126
792
            }
127

            
128
            // -1 * literal
129
            Expr::Neg(_, expr2) if matches!(*expr2, Expr::Atomic(_, Atom::Literal(_))) => {
130
                let Expr::Atomic(_, Atom::Literal(lit)) = &*expr2 else {
131
                    unreachable!()
132
                };
133

            
134
                let Lit::Int(i) = lit else {
135
                    bug!("Literals in a product operation should be integer, but got {lit}")
136
                };
137

            
138
                literal *= -i;
139
            }
140

            
141
            // -1 * x
142
            Expr::Neg(_, expr2) if matches!(&*expr2, Expr::Atomic(_, Atom::Reference(_))) => {
143
                literal *= -1;
144
                variables.push(Moo::unwrap_or_clone(expr2));
145
            }
146

            
147
            // -1 * <expression>
148
            Expr::Neg(_, expr2) => {
149
                literal *= -1;
150
                compound_exprs.push(Moo::unwrap_or_clone(expr2));
151
            }
152
381
            _ => {
153
381
                compound_exprs.push(expr);
154
381
            }
155
        }
156
    }
157
4380
    variables.extend(compound_exprs);
158

            
159
4380
    (literal, variables)
160
4380
}
161

            
162
/// Removes products with a single argument.
163
///
164
/// ```text
165
/// product([a]) ~> a
166
/// ```
167
///
168
#[register_rule(("Base", 8800))]
169
113726
fn remove_unit_vector_products(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
170
113726
    match expr {
171
879
        Expr::Product(_, mat) => {
172
879
            let list = (**mat).clone().unwrap_list().ok_or(RuleNotApplicable)?;
173
363
            if list.len() == 1 {
174
6
                return Ok(Reduction::pure(list[0].clone()));
175
357
            }
176
357
            Err(RuleNotApplicable)
177
        }
178
112847
        _ => Err(RuleNotApplicable),
179
    }
180
113726
}