1
use conjure_cp::ast::categories::{Category, CategoryOf};
2
use conjure_cp::ast::{
3
    Atom, Expression as Expr, GroundDomain, Literal, Metadata, Moo, Name, Range, SymbolTable,
4
    matrix,
5
};
6
use conjure_cp::essence_expr;
7
use conjure_cp::into_matrix_expr;
8
use conjure_cp::rule_engine::{
9
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
10
};
11
use itertools::{Itertools, chain, izip};
12
use uniplate::Uniplate;
13

            
14
use crate::bottom_up_adaptor::as_bottom_up;
15

            
16
/// Using the `matrix_to_atom`  representation rule, rewrite matrix indexing.
17
#[register_rule(("Base", 5000))]
18
589518
fn index_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
19
589518
    (as_bottom_up(index_matrix_to_atom_impl))(expr, symbols)
20
589518
}
21

            
22
588885
fn index_matrix_to_atom_impl(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
23
    // is this an indexing operation?
24
588885
    if let Expr::SafeIndex(_, subject, indices) = expr
25

            
26
    // ensure that we are indexing a decision variable with the representation "matrix_to_atom"
27
    // selected for it.
28
    //
29
28290
    && let Expr::Atomic(_, Atom::Reference(decl)) = &**subject
30
10950
    && let Name::WithRepresentation(name, reprs) =  &decl.name() as &Name
31
3090
    && reprs.first().is_none_or(|x| x.as_str() == "matrix_to_atom")
32
    {
33
2790
        let repr = symbols
34
2790
            .get_representation(name, &["matrix_to_atom"])
35
2790
            .unwrap()[0]
36
2790
            .clone();
37

            
38
        // resolve index domains so that we can enumerate them later
39
2790
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
40

            
41
        // ensure that the subject has a matrix domain.
42
2790
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
43
            return Err(RuleNotApplicable);
44
        };
45

            
46
        // checks are all ok: do the actual rewrite!
47

            
48
        // 1. indices are constant -> find the element being indexed and only return that variable.
49
        // 2. indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
50

            
51
        // are the indices constant?
52
2790
        let mut indices_are_const = true;
53
2790
        let mut indices_as_lits: Vec<Literal> = vec![];
54

            
55
3300
        for index in indices {
56
3300
            let Some(index) = index.clone().into_literal() else {
57
1440
                indices_are_const = false;
58
1440
                break;
59
            };
60
1860
            indices_as_lits.push(index);
61
        }
62

            
63
2790
        if indices_are_const {
64
            // indices are constant -> find the element being indexed and only return that variable.
65
            //
66
1350
            let indices_as_name = Name::Represented(Box::new((
67
1350
                name.as_ref().clone(),
68
1350
                "matrix_to_atom".into(),
69
1350
                indices_as_lits.iter().join("_").into(),
70
1350
            )));
71

            
72
1350
            let subject = repr.expression_down(symbols)?[&indices_as_name].clone();
73
1350
            Ok(Reduction::pure(subject))
74
        } else {
75
            // indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
76

            
77
            // For now, only supports matrices with index domains in the form int(n..m).
78
            //
79
            // Assuming this, to turn some x[a,b] and x[a,b,c] into x'[z]:
80
            //
81
            // z =                               + size(b) * (a-lb(a)) + 1 * (b-lb(b))  + 1 [2d matrix]
82
            // z = (size(b)*size(c))*(a−lb(a))   + size(c) * (b−lb(b)) + 1 * (c−lb(c))  + 1 [3d matrix]
83
            //
84
            // where lb(a) is the lower bound for a.
85
            //
86
            //
87
            // TODO: For other cases, we should generate table constraints that map the flat indices to
88
            // the real ones.
89

            
90
            // only need to do this for >1d matrices.
91
1440
            let n_dims = index_domains.len();
92
1440
            assert_ne!(
93
                n_dims, 0,
94
                "a matrix indexing operation should have atleast one index"
95
            );
96
1440
            if n_dims == 1 {
97
                // only apply this rule if the index is a decision variable
98
270
                if indices[0].category_of() != Category::Decision {
99
                    return Err(RuleNotApplicable);
100
270
                }
101
270
                let represented_expressions = repr
102
270
                    .expression_down(symbols)
103
270
                    .map_err(|_| RuleNotApplicable)?;
104
                // for some m[x], return [m1,m2,m3...mn][x]
105
270
                let new_subject = into_matrix_expr!(
106
270
                    matrix::enumerate_indices(index_domains.clone())
107
                        // for each index in the matrix, create the name that that index will have as
108
                        // an atom
109
1458
                        .map(|xs| {
110
1458
                            Name::Represented(Box::new((
111
1458
                                name.as_ref().clone(),
112
1458
                                "matrix_to_atom".into(),
113
1458
                                xs.into_iter().join("_").into(),
114
1458
                            )))
115
1458
                        })
116
1458
                        .map(|x| represented_expressions[&x].clone())
117
270
                        .collect_vec()
118
                );
119

            
120
270
                let old_index_domain = &index_domains[0];
121

            
122
270
                let GroundDomain::Int(ranges) = old_index_domain.as_ref() else {
123
                    return Err(RuleNotApplicable);
124
                };
125

            
126
270
                let &[Range::Bounded(from, _)] = &ranges[..] else {
127
                    return Err(RuleNotApplicable);
128
                };
129

            
130
270
                let offset = Expr::Atomic(Metadata::new(), Literal::Int(from - 1).into());
131
270
                let old_index = &indices[0].clone();
132

            
133
270
                return Ok(Reduction::pure(Expr::SafeIndex(
134
270
                    Metadata::new(),
135
270
                    Moo::new(new_subject),
136
270
                    vec![essence_expr!(&old_index - &offset)],
137
270
                )));
138
1170
            }
139

            
140
            // some intermediate values we need to do the above..
141

            
142
            // [(lb(a),ub(a)),(lb(b),ub(b)),(lb(c),ub(c),...]
143
1170
            let bounds = index_domains
144
1170
                .iter()
145
2340
                .map(|dom| {
146
2340
                    let GroundDomain::Int(ranges) = dom.as_ref() else {
147
                        return Err(RuleNotApplicable);
148
                    };
149

            
150
2340
                    let &[Range::Bounded(from, to)] = &ranges[..] else {
151
                        return Err(RuleNotApplicable);
152
                    };
153

            
154
2340
                    Ok((from, to))
155
2340
                })
156
1170
                .process_results(|it| it.collect_vec())?;
157

            
158
            // [size(a),size(b),size(c),..]
159
1170
            let sizes = bounds
160
1170
                .iter()
161
2340
                .map(|(from, to)| (to - from) + 1)
162
1170
                .collect_vec();
163

            
164
            // [lb(a),lb(b),lb(c),..]
165
1170
            let lower_bounds = bounds.iter().map(|(from, _)| from).collect_vec();
166

            
167
            // from the examples above:
168
            //
169
            // index = (coefficients . terms) + 1
170
            //
171
            // where coefficients = [size(b)*size(c), size(c), 1      ]
172
            //       terms =        [a-lb(a)        , b-lb(b), c-lb(c)]
173

            
174
            // building coefficients.
175
            //
176
            // starting with sizes==[size(a),size(b),size(c)]
177
            //
178
            // ~~ skip(1) ~~>
179
            //
180
            // [size(b),size(c)]
181
            //
182
            // ~~ rev ~~>
183
            //
184
            // [size(c),size(b)]
185
            //
186
            // ~~ chain!(std::iter::once(&1),...) ~~>
187
            //
188
            // [1,size(c),size(b)]
189
            //
190
            // ~~ scan * ~~>
191
            //
192
            // [1,1*size(c),1*size(c)*size(b)]
193
            //
194
            // ~~ reverse ~~>
195
            //
196
            // [size(b)*size(c),size(c),1]
197
1170
            let mut coeffs: Vec<Expr> = chain!(std::iter::once(&1), sizes.iter().skip(1).rev())
198
2340
                .scan(1, |state, &x| {
199
2340
                    *state *= x;
200
2340
                    Some(*state)
201
2340
                })
202
2340
                .map(|x| essence_expr!(&x))
203
1170
                .collect_vec();
204

            
205
1170
            coeffs.reverse();
206

            
207
            // [(a-lb(a)),b-lb(b),c-lb(c)]
208
1170
            let terms: Vec<Expr> = izip!(indices, lower_bounds)
209
2340
                .map(|(i, lbi)| essence_expr!(&i - &lbi))
210
1170
                .collect_vec();
211

            
212
            // coeffs . terms
213
1170
            let mut sum_terms: Vec<Expr> = izip!(coeffs, terms)
214
2340
                .map(|(coeff, term)| essence_expr!(&coeff * &term))
215
1170
                .collect_vec();
216

            
217
            // (coeffs . terms) + 1
218
1170
            sum_terms.push(essence_expr!(1));
219

            
220
1170
            let flat_index = Expr::Sum(Metadata::new(), Moo::new(into_matrix_expr![sum_terms]));
221

            
222
            // now lets get the flat matrix.
223

            
224
1170
            let repr_exprs = repr.expression_down(symbols)?;
225
1170
            let flat_elems = matrix::enumerate_indices(index_domains.clone())
226
10512
                .map(|xs| {
227
10512
                    Name::Represented(Box::new((
228
10512
                        name.as_ref().clone(),
229
10512
                        "matrix_to_atom".into(),
230
10512
                        xs.into_iter().join("_").into(),
231
10512
                    )))
232
10512
                })
233
10512
                .map(|x| repr_exprs[&x].clone())
234
1170
                .collect_vec();
235

            
236
1170
            let flat_matrix = into_matrix_expr![flat_elems];
237

            
238
1170
            Ok(Reduction::pure(Expr::SafeIndex(
239
1170
                Metadata::new(),
240
1170
                Moo::new(flat_matrix),
241
1170
                vec![flat_index],
242
1170
            )))
243
        }
244
    } else {
245
586095
        Err(RuleNotApplicable)
246
    }
247
588885
}
248

            
249
/// Using the `matrix_to_atom` representation rule, rewrite matrix slicing.
250
#[register_rule(("Base", 2000))]
251
112167
fn slice_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
252
112167
    let Expr::SafeSlice(_, subject, indices) = expr else {
253
110247
        return Err(RuleNotApplicable);
254
    };
255

            
256
1920
    let Expr::Atomic(_, Atom::Reference(decl)) = &**subject else {
257
        return Err(RuleNotApplicable);
258
    };
259

            
260
1920
    let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
261
144
        return Err(RuleNotApplicable);
262
    };
263
1776
    if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
264
        return Err(RuleNotApplicable);
265
1776
    }
266

            
267
1776
    let decl = symbols.lookup(name).unwrap();
268
1776
    let repr = symbols
269
1776
        .get_representation(name, &["matrix_to_atom"])
270
1776
        .unwrap()[0]
271
1776
        .clone();
272

            
273
    // resolve index domains so that we can enumerate them later
274
1776
    let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
275
1776
    let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
276
        return Err(RuleNotApplicable);
277
    };
278

            
279
1776
    let mut indices_as_lits: Vec<Option<Literal>> = vec![];
280
1776
    let mut hole_dim: i32 = -1;
281
2760
    for (i, index) in indices.iter().enumerate() {
282
2760
        match index {
283
1776
            Some(e) => {
284
1776
                let lit = e.clone().into_literal().ok_or(RuleNotApplicable)?;
285
84
                indices_as_lits.push(Some(lit.clone()));
286
            }
287
            None => {
288
984
                indices_as_lits.push(None);
289
984
                assert_eq!(hole_dim, -1);
290
984
                hole_dim = i as _;
291
            }
292
        }
293
    }
294

            
295
84
    assert_ne!(hole_dim, -1);
296

            
297
84
    let repr_values = repr.expression_down(symbols)?;
298

            
299
84
    let slice = index_domains[hole_dim as usize]
300
84
        .values()
301
84
        .expect("index domain should be finite and enumerable")
302
216
        .map(|i| {
303
216
            let mut indices_as_lits = indices_as_lits.clone();
304
216
            indices_as_lits[hole_dim as usize] = Some(i);
305
216
            let name = Name::Represented(Box::new((
306
216
                name.as_ref().clone(),
307
216
                "matrix_to_atom".into(),
308
216
                indices_as_lits
309
216
                    .into_iter()
310
432
                    .map(|x| x.unwrap())
311
216
                    .join("_")
312
216
                    .into(),
313
            )));
314
216
            repr_values[&name].clone()
315
216
        })
316
84
        .collect_vec();
317

            
318
84
    let new_expr = into_matrix_expr!(slice);
319

            
320
84
    Ok(Reduction::pure(new_expr))
321
112167
}
322

            
323
/// Converts a reference to a 1d-matrix not contained within an indexing or slicing expression to its atoms.
324
#[register_rule(("Base", 2000))]
325
112167
fn matrix_ref_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
326
    if let Expr::SafeSlice(_, _, _)
327
    | Expr::UnsafeSlice(_, _, _)
328
    | Expr::SafeIndex(_, _, _)
329
112167
    | Expr::UnsafeIndex(_, _, _) = expr
330
    {
331
9732
        return Err(RuleNotApplicable);
332
102435
    };
333

            
334
102435
    for (child, ctx) in expr.holes() {
335
30960
        let Expr::Atomic(_, Atom::Reference(decl)) = child else {
336
72048
            continue;
337
        };
338

            
339
22842
        let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
340
22698
            continue;
341
        };
342

            
343
144
        if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
344
54
            continue;
345
90
        }
346

            
347
90
        let decl = symbols.lookup(name.as_ref()).unwrap();
348
90
        let repr = symbols
349
90
            .get_representation(name.as_ref(), &["matrix_to_atom"])
350
90
            .unwrap()[0]
351
90
            .clone();
352

            
353
        // resolve index domains so that we can enumerate them later
354
90
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
355
90
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
356
            continue;
357
        };
358

            
359
90
        if index_domains.len() > 1 {
360
            continue;
361
90
        }
362

            
363
90
        let Ok(matrix_values) = repr.expression_down(symbols) else {
364
            continue;
365
        };
366

            
367
90
        let flat_values = matrix::enumerate_indices(index_domains.clone())
368
396
            .map(|i| {
369
396
                matrix_values[&Name::Represented(Box::new((
370
396
                    name.as_ref().clone(),
371
396
                    "matrix_to_atom".into(),
372
396
                    i.iter().join("_").into(),
373
396
                )))]
374
396
                    .clone()
375
396
            })
376
90
            .collect_vec();
377
90
        return Ok(Reduction::pure(ctx(into_matrix_expr![flat_values])));
378
    }
379

            
380
102345
    Err(RuleNotApplicable)
381
112167
}