1
use conjure_cp::ast::{Atom, Expression as Expr, GroundDomain, Literal, Moo, Name, SymbolTable};
2
use conjure_cp::rule_engine::{
3
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
4
};
5
use itertools::Itertools;
6

            
7
#[register_rule(("Base", 8001))]
8
fn indexed_flatten_matrix(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
9
    match expr {
10
        Expr::SafeIndex(_, subject, index) | Expr::UnsafeIndex(_, subject, index) => {
11
            if let Expr::Flatten(_, n, matrix) = subject.as_ref() {
12
                if n.is_some() {
13
                    // TODO handle flatten with n dimension option
14
                    return Err(RuleNotApplicable);
15
                }
16

            
17
                if index.len() > 1 {
18
                    return Err(RuleNotApplicable);
19
                }
20

            
21
                // get the actual number of the index
22
                let Expr::Atomic(_, Atom::Literal(Literal::Int(index))) = index[0] else {
23
                    return Err(RuleNotApplicable);
24
                };
25

            
26
                let Expr::Atomic(_, Atom::Reference(decl)) = matrix.as_ref() else {
27
                    return Err(RuleNotApplicable);
28
                };
29

            
30
                let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
31
                    return Err(RuleNotApplicable);
32
                };
33

            
34
                if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
35
                    return Err(RuleNotApplicable);
36
                }
37

            
38
                let decl = symbols.lookup(name.as_ref()).unwrap();
39
                let repr = symbols
40
                    .get_representation(name.as_ref(), &["matrix_to_atom"])
41
                    .unwrap()[0]
42
                    .clone();
43

            
44
                // resolve index domains so that we can enumerate them later
45
                let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
46
                let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
47
                    return Err(RuleNotApplicable);
48
                };
49

            
50
                let Ok(matrix_values) = repr.expression_down(symbols) else {
51
                    return Err(RuleNotApplicable);
52
                };
53

            
54
                let flat_index = ndim_to_flat_index(index_domains.clone(), index as usize - 1);
55
                println!("{}", flat_index.iter().join("_"));
56

            
57
                let flat_value = matrix_values[&Name::Represented(Box::new((
58
                    name.as_ref().clone(),
59
                    "matrix_to_atom".into(),
60
                    flat_index.iter().join("_").into(),
61
                )))]
62
                    .clone();
63

            
64
                return Ok(Reduction::pure(flat_value));
65
            }
66

            
67
            Err(RuleNotApplicable)
68
        }
69
        _ => Err(RuleNotApplicable),
70
    }
71
}
72

            
73
// Given index domains for a multi-dimensional matrix and the nth index in the flattened matrix, find the coordinates in the original matrix
74
fn ndim_to_flat_index(index_domains: Vec<Moo<GroundDomain>>, index: usize) -> Vec<usize> {
75
    let mut remaining = index;
76
    let mut multipliers = vec![1; index_domains.len()];
77

            
78
    for i in (0..index_domains.len() - 1).rev() {
79
        multipliers[i] = multipliers[i + 1] * index_domains[i + 1].as_ref().length().unwrap();
80
    }
81

            
82
    let mut coords = vec![0; index_domains.len()];
83
    for i in 0..index_domains.len() {
84
        coords[i] = remaining / multipliers[i] as usize;
85
        remaining %= multipliers[i] as usize;
86
    }
87

            
88
    // adjust for 1-based indexing
89
    for coord in coords.iter_mut() {
90
        *coord += 1;
91
    }
92
    coords
93
}