// conjure_core/rules/matrix/repr_matrix.rs

1use conjure_core::ast::Expression as Expr;
2use conjure_core::ast::{matrix, SymbolTable};
3use conjure_core::rule_engine::{
4    register_rule, ApplicationError::RuleNotApplicable, ApplicationResult, Reduction,
5};
6use itertools::{chain, izip, Itertools};
7use uniplate::Uniplate;
8
9use crate::ast::Domain;
10use crate::ast::Literal;
11use crate::ast::Name;
12use crate::ast::{Atom, Range};
13use crate::into_matrix_expr;
14use crate::metadata::Metadata;
15
16/// Using the `matrix_to_atom`  representation rule, rewrite matrix indexing.
17#[register_rule(("Base", 2000))]
18fn index_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
19    // is this an indexing operation?
20    let Expr::SafeIndex(_, subject, indices) = expr else {
21        return Err(RuleNotApplicable);
22    };
23
24    // ensure that we are indexing a decision variable with the representation "matrix_to_atom"
25    // selected for it.
26    let Expr::Atomic(_, Atom::Reference(Name::WithRepresentation(name, reprs))) = &**subject else {
27        return Err(RuleNotApplicable);
28    };
29
30    if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
31        return Err(RuleNotApplicable);
32    }
33
34    let repr = symbols
35        .get_representation(name, &["matrix_to_atom"])
36        .unwrap()[0]
37        .clone();
38
39    // ensure that the subject has a matrix domain.
40    let decl = symbols.lookup(name).unwrap();
41
42    // resolve index domains so that we can enumerate them later
43    let Some(Domain::DomainMatrix(_, index_domains)) =
44        decl.domain().cloned().map(|x| x.resolve(symbols))
45    else {
46        return Err(RuleNotApplicable);
47    };
48
49    // checks are all ok: do the actual rewrite!
50
51    // 1. indices are constant -> find the element being indexed and only return that variable.
52    // 2. indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
53
54    // are the indices constant?
55    let mut indices_are_const = true;
56    let mut indices_as_lits: Vec<Literal> = vec![];
57
58    for index in indices {
59        let Some(index) = index.clone().to_literal() else {
60            indices_are_const = false;
61            break;
62        };
63        indices_as_lits.push(index);
64    }
65
66    if indices_are_const {
67        // indices are constant -> find the element being indexed and only return that variable.
68        //
69        let indices_as_name = Name::RepresentedName(
70            name.clone(),
71            "matrix_to_atom".into(),
72            indices_as_lits.iter().join("_"),
73        );
74
75        let subject = repr.expression_down(symbols)?[&indices_as_name].clone();
76
77        Ok(Reduction::pure(subject))
78    } else {
79        // indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
80
81        // For now, only supports matrices with index domains in the form int(n..m).
82        //
83        // Assuming this, to turn some x[a,b] and x[a,b,c] into x'[z]:
84        //
85        // z =                               + size(b) * (a-lb(a)) + 1 * (b-lb(b))  + 1 [2d matrix]
86        // z = (size(b)*size(c))*(a−lb(a))   + size(c) * (b−lb(b)) + 1 * (c−lb(c))  + 1 [3d matrix]
87        //
88        // where lb(a) is the lower bound for a.
89        //
90        //
91        // TODO: For other cases, we should generate table constraints that map the flat indices to
92        // the real ones.
93
94        // only need to do this for >1d matrices.
95        let n_dims = index_domains.len();
96        if n_dims <= 1 {
97            return Err(RuleNotApplicable);
98        };
99
100        // some intermediate values we need to do the above..
101
102        // [(lb(a),ub(a)),(lb(b),ub(b)),(lb(c),ub(c),...]
103        let bounds = index_domains
104            .iter()
105            .map(|dom| {
106                let Domain::IntDomain(ranges) = dom else {
107                    return Err(RuleNotApplicable);
108                };
109
110                let &[Range::Bounded(from, to)] = &ranges[..] else {
111                    return Err(RuleNotApplicable);
112                };
113
114                Ok((from, to))
115            })
116            .process_results(|it| it.collect_vec())?;
117
118        // [size(a),size(b),size(c),..]
119        let sizes = bounds
120            .iter()
121            .map(|(from, to)| (to - from) + 1)
122            .collect_vec();
123
124        // [lb(a),lb(b),lb(c),..]
125        let lower_bounds = bounds.iter().map(|(from, _)| from).collect_vec();
126
127        // from the examples above:
128        //
129        // index = (coefficients . terms) + 1
130        //
131        // where coefficients = [size(b)*size(c), size(c), 1      ]
132        //       terms =        [a-lb(a)        , b-lb(b), c-lb(c)]
133
134        // building coefficients.
135        //
136        // starting with sizes==[size(a),size(b),size(c)]
137        //
138        // ~~ skip(1) ~~>
139        //
140        // [size(b),size(c)]
141        //
142        // ~~ rev ~~>
143        //
144        // [size(c),size(b)]
145        //
146        // ~~ chain!(std::iter::once(&1),...) ~~>
147        //
148        // [1,size(c),size(b)]
149        //
150        // ~~ scan * ~~>
151        //
152        // [1,1*size(c),1*size(c)*size(b)]
153        //
154        // ~~ reverse ~~>
155        //
156        // [size(b)*size(c),size(c),1]
157        let mut coeffs: Vec<Expr> = chain!(std::iter::once(&1), sizes.iter().skip(1).rev())
158            .scan(1, |state, &x| {
159                *state *= x;
160                Some(*state)
161            })
162            .map(|x| Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Int(x))))
163            .collect_vec();
164
165        coeffs.reverse();
166
167        // [(a-lb(a)),b-lb(b),c-lb(c)]
168        let terms: Vec<Expr> = izip!(indices, lower_bounds)
169            .map(|(i, lbi)| {
170                Expr::Minus(
171                    Metadata::new(),
172                    Box::new(i.clone()),
173                    Box::new(Expr::Atomic(
174                        Metadata::new(),
175                        Atom::Literal(Literal::Int(*lbi)),
176                    )),
177                )
178            })
179            .collect_vec();
180
181        // coeffs . terms
182        let mut sum_terms: Vec<Expr> = izip!(coeffs, terms)
183            .map(|(coeff, term)| Expr::Product(Metadata::new(), vec![coeff, term]))
184            .collect_vec();
185
186        // (coeffs . terms) + 1
187        sum_terms.push(Expr::Atomic(
188            Metadata::new(),
189            Atom::Literal(Literal::Int(1)),
190        ));
191
192        let flat_index = Expr::Sum(Metadata::new(), Box::new(into_matrix_expr![sum_terms]));
193
194        // now lets get the flat matrix.
195
196        let repr_exprs = repr.expression_down(symbols)?;
197        let flat_elems = matrix::enumerate_indices(index_domains.clone())
198            .map(|xs| {
199                Name::RepresentedName(
200                    name.clone(),
201                    "matrix_to_atom".into(),
202                    xs.into_iter().join("_"),
203                )
204            })
205            .map(|x| repr_exprs[&x].clone())
206            .collect_vec();
207
208        let flat_matrix = into_matrix_expr![flat_elems];
209
210        Ok(Reduction::pure(Expr::SafeIndex(
211            Metadata::new(),
212            Box::new(flat_matrix),
213            vec![flat_index],
214        )))
215    }
216}
217
218/// Using the `matrix_to_atom` representation rule, rewrite matrix slicing.
219#[register_rule(("Base", 2000))]
220fn slice_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
221    let Expr::SafeSlice(_, subject, indices) = expr else {
222        return Err(RuleNotApplicable);
223    };
224
225    let Expr::Atomic(_, Atom::Reference(Name::WithRepresentation(name, reprs))) = &**subject else {
226        return Err(RuleNotApplicable);
227    };
228
229    if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
230        return Err(RuleNotApplicable);
231    }
232
233    let decl = symbols.lookup(name).unwrap();
234    let repr = symbols
235        .get_representation(name, &["matrix_to_atom"])
236        .unwrap()[0]
237        .clone();
238
239    // resolve index domains so that we can enumerate them later
240    let Some(Domain::DomainMatrix(_, index_domains)) =
241        decl.domain().cloned().map(|x| x.resolve(symbols))
242    else {
243        return Err(RuleNotApplicable);
244    };
245
246    let mut indices_as_lits: Vec<Option<Literal>> = vec![];
247    let mut hole_dim: i32 = -1;
248    for (i, index) in indices.iter().enumerate() {
249        match index {
250            Some(e) => {
251                let lit = e.clone().to_literal().ok_or(RuleNotApplicable)?;
252                indices_as_lits.push(Some(lit.clone()));
253            }
254            None => {
255                indices_as_lits.push(None);
256                assert_eq!(hole_dim, -1);
257                hole_dim = i as _;
258            }
259        }
260    }
261
262    assert_ne!(hole_dim, -1);
263
264    let repr_values = repr.expression_down(symbols)?;
265
266    let slice = index_domains[hole_dim as usize]
267        .values()
268        .expect("index domain should be finite and enumerable")
269        .into_iter()
270        .map(|i| {
271            let mut indices_as_lits = indices_as_lits.clone();
272            indices_as_lits[hole_dim as usize] = Some(i);
273            let name = Name::RepresentedName(
274                name.clone(),
275                "matrix_to_atom".into(),
276                indices_as_lits.into_iter().map(|x| x.unwrap()).join("_"),
277            );
278            repr_values[&name].clone()
279        })
280        .collect_vec();
281
282    let new_expr = into_matrix_expr!(slice);
283
284    Ok(Reduction::pure(new_expr))
285}
286
287/// Converts a reference to a 1d-matrix not contained within an indexing or slicing expression to its atoms.
288#[register_rule(("Base", 2000))]
289fn matrix_ref_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
290    if let Expr::SafeSlice(_, _, _)
291    | Expr::UnsafeSlice(_, _, _)
292    | Expr::SafeIndex(_, _, _)
293    | Expr::UnsafeIndex(_, _, _) = expr
294    {
295        return Err(RuleNotApplicable);
296    };
297
298    for (child, ctx) in expr.holes() {
299        let Expr::Atomic(_, Atom::Reference(Name::WithRepresentation(name, reprs))) = child else {
300            continue;
301        };
302
303        if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
304            continue;
305        }
306
307        let decl = symbols.lookup(name.as_ref()).unwrap();
308        let repr = symbols
309            .get_representation(name.as_ref(), &["matrix_to_atom"])
310            .unwrap()[0]
311            .clone();
312
313        // resolve index domains so that we can enumerate them later
314        let Some(Domain::DomainMatrix(_, index_domains)) =
315            decl.domain().cloned().map(|x| x.resolve(symbols))
316        else {
317            continue;
318        };
319
320        if index_domains.len() > 1 {
321            continue;
322        }
323
324        let Ok(matrix_values) = repr.expression_down(symbols) else {
325            continue;
326        };
327
328        let flat_values = matrix::enumerate_indices(index_domains)
329            .map(|i| {
330                matrix_values[&Name::RepresentedName(
331                    name.clone(),
332                    "matrix_to_atom".into(),
333                    i.iter().join("_"),
334                )]
335                    .clone()
336            })
337            .collect_vec();
338        return Ok(Reduction::pure(ctx(into_matrix_expr![flat_values])));
339    }
340
341    Err(RuleNotApplicable)
342}