1
use conjure_cp::ast::{Atom, Expression as Expr, Literal, Metadata, SymbolTable, matrix};
2
use conjure_cp::rule_engine::{
3
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
4
};
5

            
6
/// Turn an index into a flattened matrix expression directly into the fully qualified index.
7
///
8
/// E.g. instead of transforming flatten(m)[1] ~> [m[1,1],m[1,2],..][1],
9
///                          do: flatten(m)[1] ~> m[1,1]
10
#[register_rule("Base", 8001, [SafeIndex, UnsafeIndex])]
11
1411751
fn indexed_flatten_matrix(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
12
1411751
    let (subject, index) = match expr {
13
89534
        Expr::SafeIndex(_, subj, idx) | Expr::UnsafeIndex(_, subj, idx) => (subj, idx),
14
1322217
        _ => return Err(RuleNotApplicable),
15
    };
16
89534
    let Expr::Flatten(_, n, matrix) = subject.as_ref() else {
17
89438
        return Err(RuleNotApplicable);
18
    };
19

            
20
96
    if n.is_some() || index.len() != 1 {
21
        // TODO handle flatten with n dimension option
22
        return Err(RuleNotApplicable);
23
96
    }
24

            
25
    // get the actual number of the index
26
96
    let Expr::Atomic(_, Atom::Literal(Literal::Int(index))) = index[0] else {
27
        return Err(RuleNotApplicable);
28
    };
29

            
30
    // resolve index domains so that we can enumerate them later
31
96
    let index_domains =
32
96
        matrix::bound_index_domains_of_expr(matrix.as_ref()).ok_or(RuleNotApplicable)?;
33
192
    if index_domains.iter().any(|domain| domain.length().is_err()) {
34
        return Err(RuleNotApplicable);
35
96
    }
36

            
37
96
    let flat_index = matrix::flat_index_to_full_index(&index_domains, (index - 1) as u64);
38
96
    let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
39

            
40
    // This must be unsafe since we are using a possibly unsafe flat index.
41
    // TODO: this can be made safe if matrix::flat_index_to_full_index fails out of bounds
42
96
    let new_expr = Expr::UnsafeIndex(Metadata::new(), matrix.clone(), flat_index);
43
96
    Ok(Reduction::pure(new_expr))
44
1411751
}