1
use conjure_cp::ast::{
2
    Atom, Expression as Expr, GroundDomain, Literal, Metadata, SymbolTable, matrix,
3
};
4
use conjure_cp::rule_engine::{
5
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
6
};
7

            
8
/// Turn an index into a flattened matrix expression directly into the fully qualified index.
9
///
10
/// E.g. instead of transforming flatten(m)[1] ~> [m[1,1],m[1,2],..][1],
11
///                          do: flatten(m)[1] ~> m[1,1]
12
#[register_rule(("Base", 8001))]
13
95513
fn indexed_flatten_matrix(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
14
95513
    let (subject, index) = match expr {
15
7230
        Expr::SafeIndex(_, subj, idx) | Expr::UnsafeIndex(_, subj, idx) => (subj, idx),
16
88283
        _ => return Err(RuleNotApplicable),
17
    };
18
7230
    let Expr::Flatten(_, n, matrix) = subject.as_ref() else {
19
7206
        return Err(RuleNotApplicable);
20
    };
21

            
22
24
    if n.is_some() || index.len() != 1 {
23
        // TODO handle flatten with n dimension option
24
        return Err(RuleNotApplicable);
25
24
    }
26

            
27
    // get the actual number of the index
28
24
    let Expr::Atomic(_, Atom::Literal(Literal::Int(index))) = index[0] else {
29
        return Err(RuleNotApplicable);
30
    };
31

            
32
    // resolve index domains so that we can enumerate them later
33
24
    let dom = matrix
34
24
        .domain_of()
35
24
        .and_then(|dom| dom.resolve())
36
24
        .ok_or(RuleNotApplicable)?;
37

            
38
24
    let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
39
        return Err(RuleNotApplicable);
40
    };
41

            
42
24
    let flat_index = matrix::flat_index_to_full_index(index_domains, (index - 1) as u64);
43
24
    let flat_index: Vec<Expr> = flat_index.into_iter().map(Into::into).collect();
44

            
45
    // This must be unsafe since we are using a possibly unsafe flat index.
46
    // TODO: this can be made safe if matrix::flat_index_to_full_index fails out of bounds
47
24
    let new_expr = Expr::UnsafeIndex(Metadata::new(), matrix.clone(), flat_index);
48
24
    Ok(Reduction::pure(new_expr))
49
95513
}