1
use conjure_cp::ast::categories::{Category, CategoryOf};
2
use conjure_cp::ast::{
3
    Atom, Expression as Expr, GroundDomain, Literal, Metadata, Moo, Name, Range, SymbolTable,
4
    matrix,
5
};
6
use conjure_cp::essence_expr;
7
use conjure_cp::into_matrix_expr;
8
use conjure_cp::rule_engine::{
9
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
10
};
11
use itertools::{Itertools, chain, izip};
12
use uniplate::Uniplate;
13

            
14
use crate::bottom_up_adaptor::as_bottom_up;
15

            
16
/// Using the `matrix_to_atom`  representation rule, rewrite matrix indexing.
17
#[register_rule(("Base", 5000))]
18
606726
fn index_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
19
606726
    (as_bottom_up(index_matrix_to_atom_impl))(expr, symbols)
20
606726
}
21

            
22
599421
fn index_matrix_to_atom_impl(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
23
    // is this an indexing operation?
24
599421
    if let Expr::SafeIndex(_, subject, indices) = expr
25

            
26
    // ensure that we are indexing a decision variable with the representation "matrix_to_atom"
27
    // selected for it.
28
    //
29
25551
    && let Expr::Atomic(_, Atom::Reference(decl)) = &**subject
30
15444
    && let Name::WithRepresentation(name, reprs) =  &decl.name() as &Name
31
3168
    && reprs.first().is_none_or(|x| x.as_str() == "matrix_to_atom")
32
    {
33
2718
        let repr = symbols
34
2718
            .get_representation(name, &["matrix_to_atom"])
35
2718
            .unwrap()[0]
36
2718
            .clone();
37

            
38
        // resolve index domains so that we can enumerate them later
39
2718
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
40

            
41
        // ensure that the subject has a matrix domain.
42
2718
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
43
            return Err(RuleNotApplicable);
44
        };
45

            
46
        // checks are all ok: do the actual rewrite!
47

            
48
        // 1. indices are constant -> find the element being indexed and only return that variable.
49
        // 2. indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
50

            
51
        // are the indices constant?
52
2718
        let mut indices_are_const = true;
53
2718
        let mut indices_as_lits: Vec<Literal> = vec![];
54

            
55
3105
        for index in indices {
56
3105
            let Some(index) = index.clone().into_literal() else {
57
909
                indices_are_const = false;
58
909
                break;
59
            };
60
2196
            indices_as_lits.push(index);
61
        }
62

            
63
2718
        if indices_are_const {
64
            // indices are constant -> find the element being indexed and only return that variable.
65
            //
66
1809
            let indices_as_name = Name::Represented(Box::new((
67
1809
                name.as_ref().clone(),
68
1809
                "matrix_to_atom".into(),
69
1809
                indices_as_lits.iter().join("_").into(),
70
1809
            )));
71

            
72
1809
            let subject = repr.expression_down(symbols)?[&indices_as_name].clone();
73
1809
            Ok(Reduction::pure(subject))
74
        } else {
75
            // indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
76

            
77
            // For now, only supports matrices with index domains in the form int(n..m).
78
            //
79
            // Assuming this, to turn some x[a,b] and x[a,b,c] into x'[z]:
80
            //
81
            // z =                               + size(b) * (a-lb(a)) + 1 * (b-lb(b))  + 1 [2d matrix]
82
            // z = (size(b)*size(c))*(a−lb(a))   + size(c) * (b−lb(b)) + 1 * (c−lb(c))  + 1 [3d matrix]
83
            //
84
            // where lb(a) is the lower bound for a.
85
            //
86
            //
87
            // TODO: For other cases, we should generate table constraints that map the flat indices to
88
            // the real ones.
89

            
90
            // only need to do this for >1d matrices.
91
909
            let n_dims = index_domains.len();
92
909
            assert_ne!(
93
                n_dims, 0,
94
                "a matrix indexing operation should have atleast one index"
95
            );
96
909
            if n_dims == 1 {
97
                // only apply this rule if the index is a decision variable
98
387
                if indices[0].category_of() != Category::Decision {
99
                    return Err(RuleNotApplicable);
100
387
                }
101
387
                let represented_expressions = repr
102
387
                    .expression_down(symbols)
103
387
                    .map_err(|_| RuleNotApplicable)?;
104
                // for some m[x], return [m1,m2,m3...mn][x]
105
387
                let new_subject = into_matrix_expr!(
106
387
                    matrix::enumerate_indices(index_domains.clone())
107
                        // for each index in the matrix, create the name that that index will have as
108
                        // an atom
109
2097
                        .map(|xs| {
110
2097
                            Name::Represented(Box::new((
111
2097
                                name.as_ref().clone(),
112
2097
                                "matrix_to_atom".into(),
113
2097
                                xs.into_iter().join("_").into(),
114
2097
                            )))
115
2097
                        })
116
2097
                        .map(|x| represented_expressions[&x].clone())
117
387
                        .collect_vec()
118
                );
119

            
120
387
                let old_index_domain = &index_domains[0];
121

            
122
387
                let GroundDomain::Int(ranges) = old_index_domain.as_ref() else {
123
                    return Err(RuleNotApplicable);
124
                };
125

            
126
387
                let &[Range::Bounded(from, _)] = &ranges[..] else {
127
                    return Err(RuleNotApplicable);
128
                };
129

            
130
387
                let offset = Expr::Atomic(Metadata::new(), Literal::Int(from - 1).into());
131
387
                let old_index = &indices[0].clone();
132

            
133
387
                return Ok(Reduction::pure(Expr::SafeIndex(
134
387
                    Metadata::new(),
135
387
                    Moo::new(new_subject),
136
387
                    vec![essence_expr!(&old_index - &offset)],
137
387
                )));
138
522
            }
139

            
140
            // some intermediate values we need to do the above..
141

            
142
            // [(lb(a),ub(a)),(lb(b),ub(b)),(lb(c),ub(c),...]
143
522
            let bounds = index_domains
144
522
                .iter()
145
1044
                .map(|dom| {
146
1044
                    let GroundDomain::Int(ranges) = dom.as_ref() else {
147
                        return Err(RuleNotApplicable);
148
                    };
149

            
150
1044
                    let &[Range::Bounded(from, to)] = &ranges[..] else {
151
                        return Err(RuleNotApplicable);
152
                    };
153

            
154
1044
                    Ok((from, to))
155
1044
                })
156
522
                .process_results(|it| it.collect_vec())?;
157

            
158
            // [size(a),size(b),size(c),..]
159
522
            let sizes = bounds
160
522
                .iter()
161
1044
                .map(|(from, to)| (to - from) + 1)
162
522
                .collect_vec();
163

            
164
            // [lb(a),lb(b),lb(c),..]
165
522
            let lower_bounds = bounds.iter().map(|(from, _)| from).collect_vec();
166

            
167
            // from the examples above:
168
            //
169
            // index = (coefficients . terms) + 1
170
            //
171
            // where coefficients = [size(b)*size(c), size(c), 1      ]
172
            //       terms =        [a-lb(a)        , b-lb(b), c-lb(c)]
173

            
174
            // building coefficients.
175
            //
176
            // starting with sizes==[size(a),size(b),size(c)]
177
            //
178
            // ~~ skip(1) ~~>
179
            //
180
            // [size(b),size(c)]
181
            //
182
            // ~~ rev ~~>
183
            //
184
            // [size(c),size(b)]
185
            //
186
            // ~~ chain!(std::iter::once(&1),...) ~~>
187
            //
188
            // [1,size(c),size(b)]
189
            //
190
            // ~~ scan * ~~>
191
            //
192
            // [1,1*size(c),1*size(c)*size(b)]
193
            //
194
            // ~~ reverse ~~>
195
            //
196
            // [size(b)*size(c),size(c),1]
197
522
            let mut coeffs: Vec<Expr> = chain!(std::iter::once(&1), sizes.iter().skip(1).rev())
198
1044
                .scan(1, |state, &x| {
199
1044
                    *state *= x;
200
1044
                    Some(*state)
201
1044
                })
202
1044
                .map(|x| essence_expr!(&x))
203
522
                .collect_vec();
204

            
205
522
            coeffs.reverse();
206

            
207
            // [(a-lb(a)),b-lb(b),c-lb(c)]
208
522
            let terms: Vec<Expr> = izip!(indices, lower_bounds)
209
1044
                .map(|(i, lbi)| essence_expr!(&i - &lbi))
210
522
                .collect_vec();
211

            
212
            // coeffs . terms
213
522
            let mut sum_terms: Vec<Expr> = izip!(coeffs, terms)
214
1044
                .map(|(coeff, term)| essence_expr!(&coeff * &term))
215
522
                .collect_vec();
216

            
217
            // (coeffs . terms) + 1
218
522
            sum_terms.push(essence_expr!(1));
219

            
220
522
            let flat_index = Expr::Sum(Metadata::new(), Moo::new(into_matrix_expr![sum_terms]));
221

            
222
            // now lets get the flat matrix.
223

            
224
522
            let repr_exprs = repr.expression_down(symbols)?;
225
522
            let flat_elems = matrix::enumerate_indices(index_domains.clone())
226
4347
                .map(|xs| {
227
4347
                    Name::Represented(Box::new((
228
4347
                        name.as_ref().clone(),
229
4347
                        "matrix_to_atom".into(),
230
4347
                        xs.into_iter().join("_").into(),
231
4347
                    )))
232
4347
                })
233
4347
                .map(|x| repr_exprs[&x].clone())
234
522
                .collect_vec();
235

            
236
522
            let flat_matrix = into_matrix_expr![flat_elems];
237

            
238
522
            Ok(Reduction::pure(Expr::SafeIndex(
239
522
                Metadata::new(),
240
522
                Moo::new(flat_matrix),
241
522
                vec![flat_index],
242
522
            )))
243
        }
244
    } else {
245
596703
        Err(RuleNotApplicable)
246
    }
247
599421
}
248

            
249
/// Using the `matrix_to_atom` representation rule, rewrite matrix slicing.
250
#[register_rule(("Base", 2000))]
251
105669
fn slice_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
252
105669
    let Expr::SafeSlice(_, subject, indices) = expr else {
253
104085
        return Err(RuleNotApplicable);
254
    };
255

            
256
1584
    let Expr::Atomic(_, Atom::Reference(decl)) = &**subject else {
257
        return Err(RuleNotApplicable);
258
    };
259

            
260
1584
    let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
261
216
        return Err(RuleNotApplicable);
262
    };
263
1368
    if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
264
        return Err(RuleNotApplicable);
265
1368
    }
266

            
267
1368
    let decl = symbols.lookup(name).unwrap();
268
1368
    let repr = symbols
269
1368
        .get_representation(name, &["matrix_to_atom"])
270
1368
        .unwrap()[0]
271
1368
        .clone();
272

            
273
    // resolve index domains so that we can enumerate them later
274
1368
    let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
275
1368
    let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
276
        return Err(RuleNotApplicable);
277
    };
278

            
279
1368
    let mut indices_as_lits: Vec<Option<Literal>> = vec![];
280
1368
    let mut hole_dim: i32 = -1;
281
2196
    for (i, index) in indices.iter().enumerate() {
282
2196
        match index {
283
1368
            Some(e) => {
284
1368
                let lit = e.clone().into_literal().ok_or(RuleNotApplicable)?;
285
126
                indices_as_lits.push(Some(lit.clone()));
286
            }
287
            None => {
288
828
                indices_as_lits.push(None);
289
828
                assert_eq!(hole_dim, -1);
290
828
                hole_dim = i as _;
291
            }
292
        }
293
    }
294

            
295
126
    assert_ne!(hole_dim, -1);
296

            
297
126
    let repr_values = repr.expression_down(symbols)?;
298

            
299
126
    let slice = index_domains[hole_dim as usize]
300
126
        .values()
301
126
        .expect("index domain should be finite and enumerable")
302
324
        .map(|i| {
303
324
            let mut indices_as_lits = indices_as_lits.clone();
304
324
            indices_as_lits[hole_dim as usize] = Some(i);
305
324
            let name = Name::Represented(Box::new((
306
324
                name.as_ref().clone(),
307
324
                "matrix_to_atom".into(),
308
324
                indices_as_lits
309
324
                    .into_iter()
310
648
                    .map(|x| x.unwrap())
311
324
                    .join("_")
312
324
                    .into(),
313
            )));
314
324
            repr_values[&name].clone()
315
324
        })
316
126
        .collect_vec();
317

            
318
126
    let new_expr = into_matrix_expr!(slice);
319

            
320
126
    Ok(Reduction::pure(new_expr))
321
105669
}
322

            
323
/// Converts a reference to a 1d-matrix not contained within an indexing or slicing expression to its atoms.
324
#[register_rule(("Base", 2000))]
325
105669
fn matrix_ref_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
326
    if let Expr::SafeSlice(_, _, _)
327
    | Expr::UnsafeSlice(_, _, _)
328
    | Expr::SafeIndex(_, _, _)
329
105669
    | Expr::UnsafeIndex(_, _, _) = expr
330
    {
331
8856
        return Err(RuleNotApplicable);
332
96813
    };
333

            
334
96813
    for (child, ctx) in expr.holes() {
335
20076
        let Expr::Atomic(_, Atom::Reference(decl)) = child else {
336
75750
            continue;
337
        };
338

            
339
10980
        let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
340
10764
            continue;
341
        };
342

            
343
216
        if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
344
81
            continue;
345
135
        }
346

            
347
135
        let decl = symbols.lookup(name.as_ref()).unwrap();
348
135
        let repr = symbols
349
135
            .get_representation(name.as_ref(), &["matrix_to_atom"])
350
135
            .unwrap()[0]
351
135
            .clone();
352

            
353
        // resolve index domains so that we can enumerate them later
354
135
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
355
135
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
356
            continue;
357
        };
358

            
359
135
        if index_domains.len() > 1 {
360
            continue;
361
135
        }
362

            
363
135
        let Ok(matrix_values) = repr.expression_down(symbols) else {
364
            continue;
365
        };
366

            
367
135
        let flat_values = matrix::enumerate_indices(index_domains.clone())
368
594
            .map(|i| {
369
594
                matrix_values[&Name::Represented(Box::new((
370
594
                    name.as_ref().clone(),
371
594
                    "matrix_to_atom".into(),
372
594
                    i.iter().join("_").into(),
373
594
                )))]
374
594
                    .clone()
375
594
            })
376
135
            .collect_vec();
377
135
        return Ok(Reduction::pure(ctx(into_matrix_expr![flat_values])));
378
    }
379

            
380
96678
    Err(RuleNotApplicable)
381
105669
}