1
use conjure_cp::ast::categories::{Category, CategoryOf};
2
use conjure_cp::ast::{
3
    Atom, Expression as Expr, GroundDomain, Literal, Metadata, Moo, Name, Range, SymbolTable,
4
    matrix,
5
};
6
use conjure_cp::essence_expr;
7
use conjure_cp::into_matrix_expr;
8
use conjure_cp::rule_engine::{
9
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
10
};
11
use itertools::{Itertools, chain, izip};
12
use uniplate::Uniplate;
13

            
14
use crate::bottom_up_adaptor::as_bottom_up;
15

            
16
/// Using the `matrix_to_atom`  representation rule, rewrite matrix indexing.
17
#[register_rule(("Base", 5000))]
18
75536
fn index_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
19
75536
    (as_bottom_up(index_matrix_to_atom_impl))(expr, symbols)
20
75536
}
21

            
22
74122
fn index_matrix_to_atom_impl(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
23
    // is this an indexing operation?
24
74122
    if let Expr::SafeIndex(_, subject, indices) = expr
25

            
26
    // ensure that we are indexing a decision variable with the representation "matrix_to_atom"
27
    // selected for it.
28
    //
29
4935
    && let Expr::Atomic(_, Atom::Reference(decl)) = &**subject
30
4539
    && let Name::WithRepresentation(name, reprs) =  &decl.name() as &Name
31
906
    && reprs.first().is_none_or(|x| x.as_str() == "matrix_to_atom")
32
    {
33
756
        let repr = symbols
34
756
            .get_representation(name, &["matrix_to_atom"])
35
756
            .unwrap()[0]
36
756
            .clone();
37

            
38
        // resolve index domains so that we can enumerate them later
39
756
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
40

            
41
        // ensure that the subject has a matrix domain.
42
756
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
43
            return Err(RuleNotApplicable);
44
        };
45

            
46
        // checks are all ok: do the actual rewrite!
47

            
48
        // 1. indices are constant -> find the element being indexed and only return that variable.
49
        // 2. indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
50

            
51
        // are the indices constant?
52
756
        let mut indices_are_const = true;
53
756
        let mut indices_as_lits: Vec<Literal> = vec![];
54

            
55
813
        for index in indices {
56
813
            let Some(index) = index.clone().into_literal() else {
57
339
                indices_are_const = false;
58
339
                break;
59
            };
60
474
            indices_as_lits.push(index);
61
        }
62

            
63
756
        if indices_are_const {
64
            // indices are constant -> find the element being indexed and only return that variable.
65
            //
66
417
            let indices_as_name = Name::Represented(Box::new((
67
417
                name.as_ref().clone(),
68
417
                "matrix_to_atom".into(),
69
417
                indices_as_lits.iter().join("_").into(),
70
417
            )));
71

            
72
417
            let subject = repr.expression_down(symbols)?[&indices_as_name].clone();
73
417
            Ok(Reduction::pure(subject))
74
        } else {
75
            // indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
76

            
77
            // For now, only supports matrices with index domains in the form int(n..m).
78
            //
79
            // Assuming this, to turn some x[a,b] and x[a,b,c] into x'[z]:
80
            //
81
            // z =                               + size(b) * (a-lb(a)) + 1 * (b-lb(b))  + 1 [2d matrix]
82
            // z = (size(b)*size(c))*(a−lb(a))   + size(c) * (b−lb(b)) + 1 * (c−lb(c))  + 1 [3d matrix]
83
            //
84
            // where lb(a) is the lower bound for a.
85
            //
86
            //
87
            // TODO: For other cases, we should generate table constraints that map the flat indices to
88
            // the real ones.
89

            
90
            // only need to do this for >1d matrices.
91
339
            let n_dims = index_domains.len();
92
339
            assert_ne!(
93
                n_dims, 0,
94
                "a matrix indexing operation should have atleast one index"
95
            );
96
339
            if n_dims == 1 {
97
                // only apply this rule if the index is a decision variable
98
333
                if indices[0].category_of() != Category::Decision {
99
306
                    return Err(RuleNotApplicable);
100
27
                }
101
27
                let represented_expressions = repr
102
27
                    .expression_down(symbols)
103
27
                    .map_err(|_| RuleNotApplicable)?;
104
                // for some m[x], return [m1,m2,m3...mn][x]
105
27
                let new_subject = into_matrix_expr!(
106
27
                    matrix::enumerate_indices(index_domains.clone())
107
                        // for each index in the matrix, create the name that that index will have as
108
                        // an atom
109
75
                        .map(|xs| {
110
75
                            Name::Represented(Box::new((
111
75
                                name.as_ref().clone(),
112
75
                                "matrix_to_atom".into(),
113
75
                                xs.into_iter().join("_").into(),
114
75
                            )))
115
75
                        })
116
75
                        .map(|x| represented_expressions[&x].clone())
117
27
                        .collect_vec()
118
                );
119

            
120
27
                let old_index_domain = &index_domains[0];
121

            
122
27
                let GroundDomain::Int(ranges) = old_index_domain.as_ref() else {
123
                    return Err(RuleNotApplicable);
124
                };
125

            
126
27
                let &[Range::Bounded(from, _)] = &ranges[..] else {
127
                    return Err(RuleNotApplicable);
128
                };
129

            
130
27
                let offset = Expr::Atomic(Metadata::new(), Literal::Int(from - 1).into());
131
27
                let old_index = &indices[0].clone();
132

            
133
27
                return Ok(Reduction::pure(Expr::SafeIndex(
134
27
                    Metadata::new(),
135
27
                    Moo::new(new_subject),
136
27
                    vec![essence_expr!(&old_index - &offset)],
137
27
                )));
138
6
            }
139

            
140
            // some intermediate values we need to do the above..
141

            
142
            // [(lb(a),ub(a)),(lb(b),ub(b)),(lb(c),ub(c),...]
143
6
            let bounds = index_domains
144
6
                .iter()
145
12
                .map(|dom| {
146
12
                    let GroundDomain::Int(ranges) = dom.as_ref() else {
147
                        return Err(RuleNotApplicable);
148
                    };
149

            
150
12
                    let &[Range::Bounded(from, to)] = &ranges[..] else {
151
                        return Err(RuleNotApplicable);
152
                    };
153

            
154
12
                    Ok((from, to))
155
12
                })
156
6
                .process_results(|it| it.collect_vec())?;
157

            
158
            // [size(a),size(b),size(c),..]
159
6
            let sizes = bounds
160
6
                .iter()
161
12
                .map(|(from, to)| (to - from) + 1)
162
6
                .collect_vec();
163

            
164
            // [lb(a),lb(b),lb(c),..]
165
6
            let lower_bounds = bounds.iter().map(|(from, _)| from).collect_vec();
166

            
167
            // from the examples above:
168
            //
169
            // index = (coefficients . terms) + 1
170
            //
171
            // where coefficients = [size(b)*size(c), size(c), 1      ]
172
            //       terms =        [a-lb(a)        , b-lb(b), c-lb(c)]
173

            
174
            // building coefficients.
175
            //
176
            // starting with sizes==[size(a),size(b),size(c)]
177
            //
178
            // ~~ skip(1) ~~>
179
            //
180
            // [size(b),size(c)]
181
            //
182
            // ~~ rev ~~>
183
            //
184
            // [size(c),size(b)]
185
            //
186
            // ~~ chain!(std::iter::once(&1),...) ~~>
187
            //
188
            // [1,size(c),size(b)]
189
            //
190
            // ~~ scan * ~~>
191
            //
192
            // [1,1*size(c),1*size(c)*size(b)]
193
            //
194
            // ~~ reverse ~~>
195
            //
196
            // [size(b)*size(c),size(c),1]
197
6
            let mut coeffs: Vec<Expr> = chain!(std::iter::once(&1), sizes.iter().skip(1).rev())
198
12
                .scan(1, |state, &x| {
199
12
                    *state *= x;
200
12
                    Some(*state)
201
12
                })
202
12
                .map(|x| essence_expr!(&x))
203
6
                .collect_vec();
204

            
205
6
            coeffs.reverse();
206

            
207
            // [(a-lb(a)),b-lb(b),c-lb(c)]
208
6
            let terms: Vec<Expr> = izip!(indices, lower_bounds)
209
12
                .map(|(i, lbi)| essence_expr!(&i - &lbi))
210
6
                .collect_vec();
211

            
212
            // coeffs . terms
213
6
            let mut sum_terms: Vec<Expr> = izip!(coeffs, terms)
214
12
                .map(|(coeff, term)| essence_expr!(&coeff * &term))
215
6
                .collect_vec();
216

            
217
            // (coeffs . terms) + 1
218
6
            sum_terms.push(essence_expr!(1));
219

            
220
6
            let flat_index = Expr::Sum(Metadata::new(), Moo::new(into_matrix_expr![sum_terms]));
221

            
222
            // now lets get the flat matrix.
223

            
224
6
            let repr_exprs = repr.expression_down(symbols)?;
225
6
            let flat_elems = matrix::enumerate_indices(index_domains.clone())
226
45
                .map(|xs| {
227
45
                    Name::Represented(Box::new((
228
45
                        name.as_ref().clone(),
229
45
                        "matrix_to_atom".into(),
230
45
                        xs.into_iter().join("_").into(),
231
45
                    )))
232
45
                })
233
45
                .map(|x| repr_exprs[&x].clone())
234
6
                .collect_vec();
235

            
236
6
            let flat_matrix = into_matrix_expr![flat_elems];
237

            
238
6
            Ok(Reduction::pure(Expr::SafeIndex(
239
6
                Metadata::new(),
240
6
                Moo::new(flat_matrix),
241
6
                vec![flat_index],
242
6
            )))
243
        }
244
    } else {
245
73366
        Err(RuleNotApplicable)
246
    }
247
74122
}
248

            
249
/// Using the `matrix_to_atom` representation rule, rewrite matrix slicing.
250
#[register_rule(("Base", 2000))]
251
24944
fn slice_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
252
24944
    let Expr::SafeSlice(_, subject, indices) = expr else {
253
24818
        return Err(RuleNotApplicable);
254
    };
255

            
256
126
    let Expr::Atomic(_, Atom::Reference(decl)) = &**subject else {
257
        return Err(RuleNotApplicable);
258
    };
259

            
260
126
    let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
261
96
        return Err(RuleNotApplicable);
262
    };
263
30
    if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
264
        return Err(RuleNotApplicable);
265
30
    }
266

            
267
30
    let decl = symbols.lookup(name).unwrap();
268
30
    let repr = symbols
269
30
        .get_representation(name, &["matrix_to_atom"])
270
30
        .unwrap()[0]
271
30
        .clone();
272

            
273
    // resolve index domains so that we can enumerate them later
274
30
    let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
275
30
    let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
276
        return Err(RuleNotApplicable);
277
    };
278

            
279
30
    let mut indices_as_lits: Vec<Option<Literal>> = vec![];
280
30
    let mut hole_dim: i32 = -1;
281
60
    for (i, index) in indices.iter().enumerate() {
282
60
        match index {
283
30
            Some(e) => {
284
30
                let lit = e.clone().into_literal().ok_or(RuleNotApplicable)?;
285
30
                indices_as_lits.push(Some(lit.clone()));
286
            }
287
            None => {
288
30
                indices_as_lits.push(None);
289
30
                assert_eq!(hole_dim, -1);
290
30
                hole_dim = i as _;
291
            }
292
        }
293
    }
294

            
295
30
    assert_ne!(hole_dim, -1);
296

            
297
30
    let repr_values = repr.expression_down(symbols)?;
298

            
299
30
    let slice = index_domains[hole_dim as usize]
300
30
        .values()
301
30
        .expect("index domain should be finite and enumerable")
302
72
        .map(|i| {
303
72
            let mut indices_as_lits = indices_as_lits.clone();
304
72
            indices_as_lits[hole_dim as usize] = Some(i);
305
72
            let name = Name::Represented(Box::new((
306
72
                name.as_ref().clone(),
307
72
                "matrix_to_atom".into(),
308
72
                indices_as_lits
309
72
                    .into_iter()
310
144
                    .map(|x| x.unwrap())
311
72
                    .join("_")
312
72
                    .into(),
313
            )));
314
72
            repr_values[&name].clone()
315
72
        })
316
30
        .collect_vec();
317

            
318
30
    let new_expr = into_matrix_expr!(slice);
319

            
320
30
    Ok(Reduction::pure(new_expr))
321
24944
}
322

            
323
/// Converts a reference to a 1d-matrix not contained within an indexing or slicing expression to its atoms.
324
#[register_rule(("Base", 2000))]
325
24944
fn matrix_ref_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
326
    if let Expr::SafeSlice(_, _, _)
327
    | Expr::UnsafeSlice(_, _, _)
328
    | Expr::SafeIndex(_, _, _)
329
24944
    | Expr::UnsafeIndex(_, _, _) = expr
330
    {
331
2709
        return Err(RuleNotApplicable);
332
22235
    };
333

            
334
22235
    for (child, ctx) in expr.holes() {
335
4009
        let Expr::Atomic(_, Atom::Reference(decl)) = child else {
336
15739
            continue;
337
        };
338

            
339
2259
        let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
340
2187
            continue;
341
        };
342

            
343
72
        if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
344
27
            continue;
345
45
        }
346

            
347
45
        let decl = symbols.lookup(name.as_ref()).unwrap();
348
45
        let repr = symbols
349
45
            .get_representation(name.as_ref(), &["matrix_to_atom"])
350
45
            .unwrap()[0]
351
45
            .clone();
352

            
353
        // resolve index domains so that we can enumerate them later
354
45
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
355
45
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
356
            continue;
357
        };
358

            
359
45
        if index_domains.len() > 1 {
360
            continue;
361
45
        }
362

            
363
45
        let Ok(matrix_values) = repr.expression_down(symbols) else {
364
            continue;
365
        };
366

            
367
45
        let flat_values = matrix::enumerate_indices(index_domains.clone())
368
198
            .map(|i| {
369
198
                matrix_values[&Name::Represented(Box::new((
370
198
                    name.as_ref().clone(),
371
198
                    "matrix_to_atom".into(),
372
198
                    i.iter().join("_").into(),
373
198
                )))]
374
198
                    .clone()
375
198
            })
376
45
            .collect_vec();
377
45
        return Ok(Reduction::pure(ctx(into_matrix_expr![flat_values])));
378
    }
379

            
380
22190
    Err(RuleNotApplicable)
381
24944
}