1
use conjure_cp::ast::categories::{Category, CategoryOf};
2
use conjure_cp::ast::{
3
    Atom, Expression as Expr, GroundDomain, Literal, Metadata, Moo, Name, Range, SymbolTable,
4
    matrix,
5
};
6
use conjure_cp::essence_expr;
7
use conjure_cp::into_matrix_expr;
8
use conjure_cp::rule_engine::{
9
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
10
};
11
use itertools::{Itertools, chain, izip};
12
use uniplate::Uniplate;
13

            
14
use crate::bottom_up_adaptor::as_bottom_up;
15

            
16
/// Using the `matrix_to_atom`  representation rule, rewrite matrix indexing.
17
#[register_rule(("Base", 5000))]
18
fn index_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
19
    (as_bottom_up(index_matrix_to_atom_impl))(expr, symbols)
20
}
21

            
22
fn index_matrix_to_atom_impl(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
23
    // is this an indexing operation?
24
    if let Expr::SafeIndex(_, subject, indices) = expr
25

            
26
    // ensure that we are indexing a decision variable with the representation "matrix_to_atom"
27
    // selected for it.
28
    //
29
    && let Expr::Atomic(_, Atom::Reference(decl)) = &**subject
30
    && let Name::WithRepresentation(name, reprs) =  &decl.name() as &Name
31
    && reprs.first().is_none_or(|x| x.as_str() == "matrix_to_atom")
32
    {
33
        let repr = symbols
34
            .get_representation(name, &["matrix_to_atom"])
35
            .unwrap()[0]
36
            .clone();
37

            
38
        // resolve index domains so that we can enumerate them later
39
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
40

            
41
        // ensure that the subject has a matrix domain.
42
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
43
            return Err(RuleNotApplicable);
44
        };
45

            
46
        // checks are all ok: do the actual rewrite!
47

            
48
        // 1. indices are constant -> find the element being indexed and only return that variable.
49
        // 2. indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
50

            
51
        // are the indices constant?
52
        let mut indices_are_const = true;
53
        let mut indices_as_lits: Vec<Literal> = vec![];
54

            
55
        for index in indices {
56
            let Some(index) = index.clone().into_literal() else {
57
                indices_are_const = false;
58
                break;
59
            };
60
            indices_as_lits.push(index);
61
        }
62

            
63
        if indices_are_const {
64
            // indices are constant -> find the element being indexed and only return that variable.
65
            //
66
            let indices_as_name = Name::Represented(Box::new((
67
                name.as_ref().clone(),
68
                "matrix_to_atom".into(),
69
                indices_as_lits.iter().join("_").into(),
70
            )));
71

            
72
            let subject = repr.expression_down(symbols)?[&indices_as_name].clone();
73
            Ok(Reduction::pure(subject))
74
        } else {
75
            // indices are not constant -> flatten matrix and return [flattened_matrix][flattened_index_expr]
76

            
77
            // For now, only supports matrices with index domains in the form int(n..m).
78
            //
79
            // Assuming this, to turn some x[a,b] and x[a,b,c] into x'[z]:
80
            //
81
            // z =                               + size(b) * (a-lb(a)) + 1 * (b-lb(b))  + 1 [2d matrix]
82
            // z = (size(b)*size(c))*(a−lb(a))   + size(c) * (b−lb(b)) + 1 * (c−lb(c))  + 1 [3d matrix]
83
            //
84
            // where lb(a) is the lower bound for a.
85
            //
86
            //
87
            // TODO: For other cases, we should generate table constraints that map the flat indices to
88
            // the real ones.
89

            
90
            // only need to do this for >1d matrices.
91
            let n_dims = index_domains.len();
92
            assert_ne!(
93
                n_dims, 0,
94
                "a matrix indexing operation should have atleast one index"
95
            );
96
            if n_dims == 1 {
97
                // only apply this rule if the index is a decision variable
98
                if indices[0].category_of() != Category::Decision {
99
                    return Err(RuleNotApplicable);
100
                }
101
                let represented_expressions = repr
102
                    .expression_down(symbols)
103
                    .map_err(|_| RuleNotApplicable)?;
104
                // for some m[x], return [m1,m2,m3...mn][x]
105
                let new_subject = into_matrix_expr!(
106
                    matrix::enumerate_indices(index_domains.clone())
107
                        // for each index in the matrix, create the name that that index will have as
108
                        // an atom
109
                        .map(|xs| {
110
                            Name::Represented(Box::new((
111
                                name.as_ref().clone(),
112
                                "matrix_to_atom".into(),
113
                                xs.into_iter().join("_").into(),
114
                            )))
115
                        })
116
                        .map(|x| represented_expressions[&x].clone())
117
                        .collect_vec()
118
                );
119

            
120
                let old_index_domain = &index_domains[0];
121

            
122
                let GroundDomain::Int(ranges) = old_index_domain.as_ref() else {
123
                    return Err(RuleNotApplicable);
124
                };
125

            
126
                let &[Range::Bounded(from, _)] = &ranges[..] else {
127
                    return Err(RuleNotApplicable);
128
                };
129

            
130
                let offset = Expr::Atomic(Metadata::new(), Literal::Int(from - 1).into());
131
                let old_index = &indices[0].clone();
132

            
133
                return Ok(Reduction::pure(Expr::SafeIndex(
134
                    Metadata::new(),
135
                    Moo::new(new_subject),
136
                    vec![essence_expr!(&old_index - &offset)],
137
                )));
138
            }
139

            
140
            // some intermediate values we need to do the above..
141

            
142
            // [(lb(a),ub(a)),(lb(b),ub(b)),(lb(c),ub(c),...]
143
            let bounds = index_domains
144
                .iter()
145
                .map(|dom| {
146
                    let GroundDomain::Int(ranges) = dom.as_ref() else {
147
                        return Err(RuleNotApplicable);
148
                    };
149

            
150
                    let &[Range::Bounded(from, to)] = &ranges[..] else {
151
                        return Err(RuleNotApplicable);
152
                    };
153

            
154
                    Ok((from, to))
155
                })
156
                .process_results(|it| it.collect_vec())?;
157

            
158
            // [size(a),size(b),size(c),..]
159
            let sizes = bounds
160
                .iter()
161
                .map(|(from, to)| (to - from) + 1)
162
                .collect_vec();
163

            
164
            // [lb(a),lb(b),lb(c),..]
165
            let lower_bounds = bounds.iter().map(|(from, _)| from).collect_vec();
166

            
167
            // from the examples above:
168
            //
169
            // index = (coefficients . terms) + 1
170
            //
171
            // where coefficients = [size(b)*size(c), size(c), 1      ]
172
            //       terms =        [a-lb(a)        , b-lb(b), c-lb(c)]
173

            
174
            // building coefficients.
175
            //
176
            // starting with sizes==[size(a),size(b),size(c)]
177
            //
178
            // ~~ skip(1) ~~>
179
            //
180
            // [size(b),size(c)]
181
            //
182
            // ~~ rev ~~>
183
            //
184
            // [size(c),size(b)]
185
            //
186
            // ~~ chain!(std::iter::once(&1),...) ~~>
187
            //
188
            // [1,size(c),size(b)]
189
            //
190
            // ~~ scan * ~~>
191
            //
192
            // [1,1*size(c),1*size(c)*size(b)]
193
            //
194
            // ~~ reverse ~~>
195
            //
196
            // [size(b)*size(c),size(c),1]
197
            let mut coeffs: Vec<Expr> = chain!(std::iter::once(&1), sizes.iter().skip(1).rev())
198
                .scan(1, |state, &x| {
199
                    *state *= x;
200
                    Some(*state)
201
                })
202
                .map(|x| essence_expr!(&x))
203
                .collect_vec();
204

            
205
            coeffs.reverse();
206

            
207
            // [(a-lb(a)),b-lb(b),c-lb(c)]
208
            let terms: Vec<Expr> = izip!(indices, lower_bounds)
209
                .map(|(i, lbi)| essence_expr!(&i - &lbi))
210
                .collect_vec();
211

            
212
            // coeffs . terms
213
            let mut sum_terms: Vec<Expr> = izip!(coeffs, terms)
214
                .map(|(coeff, term)| essence_expr!(&coeff * &term))
215
                .collect_vec();
216

            
217
            // (coeffs . terms) + 1
218
            sum_terms.push(essence_expr!(1));
219

            
220
            let flat_index = Expr::Sum(Metadata::new(), Moo::new(into_matrix_expr![sum_terms]));
221

            
222
            // now lets get the flat matrix.
223

            
224
            let repr_exprs = repr.expression_down(symbols)?;
225
            let flat_elems = matrix::enumerate_indices(index_domains.clone())
226
                .map(|xs| {
227
                    Name::Represented(Box::new((
228
                        name.as_ref().clone(),
229
                        "matrix_to_atom".into(),
230
                        xs.into_iter().join("_").into(),
231
                    )))
232
                })
233
                .map(|x| repr_exprs[&x].clone())
234
                .collect_vec();
235

            
236
            let flat_matrix = into_matrix_expr![flat_elems];
237

            
238
            Ok(Reduction::pure(Expr::SafeIndex(
239
                Metadata::new(),
240
                Moo::new(flat_matrix),
241
                vec![flat_index],
242
            )))
243
        }
244
    } else {
245
        Err(RuleNotApplicable)
246
    }
247
}
248

            
249
/// Using the `matrix_to_atom` representation rule, rewrite matrix slicing.
250
#[register_rule(("Base", 2000))]
251
fn slice_matrix_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
252
    let Expr::SafeSlice(_, subject, indices) = expr else {
253
        return Err(RuleNotApplicable);
254
    };
255

            
256
    let Expr::Atomic(_, Atom::Reference(decl)) = &**subject else {
257
        return Err(RuleNotApplicable);
258
    };
259

            
260
    let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
261
        return Err(RuleNotApplicable);
262
    };
263
    if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
264
        return Err(RuleNotApplicable);
265
    }
266

            
267
    let decl = symbols.lookup(name).unwrap();
268
    let repr = symbols
269
        .get_representation(name, &["matrix_to_atom"])
270
        .unwrap()[0]
271
        .clone();
272

            
273
    // resolve index domains so that we can enumerate them later
274
    let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
275
    let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
276
        return Err(RuleNotApplicable);
277
    };
278

            
279
    let mut indices_as_lits: Vec<Option<Literal>> = vec![];
280
    let mut hole_dim: i32 = -1;
281
    for (i, index) in indices.iter().enumerate() {
282
        match index {
283
            Some(e) => {
284
                let lit = e.clone().into_literal().ok_or(RuleNotApplicable)?;
285
                indices_as_lits.push(Some(lit.clone()));
286
            }
287
            None => {
288
                indices_as_lits.push(None);
289
                assert_eq!(hole_dim, -1);
290
                hole_dim = i as _;
291
            }
292
        }
293
    }
294

            
295
    assert_ne!(hole_dim, -1);
296

            
297
    let repr_values = repr.expression_down(symbols)?;
298

            
299
    let slice = index_domains[hole_dim as usize]
300
        .values()
301
        .expect("index domain should be finite and enumerable")
302
        .map(|i| {
303
            let mut indices_as_lits = indices_as_lits.clone();
304
            indices_as_lits[hole_dim as usize] = Some(i);
305
            let name = Name::Represented(Box::new((
306
                name.as_ref().clone(),
307
                "matrix_to_atom".into(),
308
                indices_as_lits
309
                    .into_iter()
310
                    .map(|x| x.unwrap())
311
                    .join("_")
312
                    .into(),
313
            )));
314
            repr_values[&name].clone()
315
        })
316
        .collect_vec();
317

            
318
    let new_expr = into_matrix_expr!(slice);
319

            
320
    Ok(Reduction::pure(new_expr))
321
}
322

            
323
/// Converts a reference to a 1d-matrix not contained within an indexing or slicing expression to its atoms.
324
#[register_rule(("Base", 2000))]
325
fn matrix_ref_to_atom(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
326
    if let Expr::SafeSlice(_, _, _)
327
    | Expr::UnsafeSlice(_, _, _)
328
    | Expr::SafeIndex(_, _, _)
329
    | Expr::UnsafeIndex(_, _, _) = expr
330
    {
331
        return Err(RuleNotApplicable);
332
    };
333

            
334
    for (child, ctx) in expr.holes() {
335
        let Expr::Atomic(_, Atom::Reference(decl)) = child else {
336
            continue;
337
        };
338

            
339
        let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
340
            continue;
341
        };
342

            
343
        if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
344
            continue;
345
        }
346

            
347
        let decl = symbols.lookup(name.as_ref()).unwrap();
348
        let repr = symbols
349
            .get_representation(name.as_ref(), &["matrix_to_atom"])
350
            .unwrap()[0]
351
            .clone();
352

            
353
        // resolve index domains so that we can enumerate them later
354
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
355
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
356
            continue;
357
        };
358

            
359
        if index_domains.len() > 1 {
360
            continue;
361
        }
362

            
363
        let Ok(matrix_values) = repr.expression_down(symbols) else {
364
            continue;
365
        };
366

            
367
        let flat_values = matrix::enumerate_indices(index_domains.clone())
368
            .map(|i| {
369
                matrix_values[&Name::Represented(Box::new((
370
                    name.as_ref().clone(),
371
                    "matrix_to_atom".into(),
372
                    i.iter().join("_").into(),
373
                )))]
374
                    .clone()
375
            })
376
            .collect_vec();
377
        return Ok(Reduction::pure(ctx(into_matrix_expr![flat_values])));
378
    }
379

            
380
    Err(RuleNotApplicable)
381
}