1
use conjure_cp::ast::{Atom, Expression as Expr, GroundDomain, Name, SymbolTable, matrix};
2
use conjure_cp::into_matrix_expr;
3
use conjure_cp::rule_engine::{
4
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
5
};
6
use itertools::Itertools;
7

            
8
#[register_rule(("Base", 8000))]
9
fn flatten_matrix(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
10
    if let Expr::Flatten(_, n, matrix) = expr {
11
        if n.is_some() {
12
            // TODO handle flatten with n dimension option
13
            return Err(RuleNotApplicable);
14
        }
15

            
16
        let Expr::Atomic(_, Atom::Reference(decl)) = matrix.as_ref() else {
17
            return Err(RuleNotApplicable);
18
        };
19

            
20
        let Name::WithRepresentation(name, reprs) = &decl.name() as &Name else {
21
            return Err(RuleNotApplicable);
22
        };
23

            
24
        if reprs.first().is_none_or(|x| x.as_str() != "matrix_to_atom") {
25
            return Err(RuleNotApplicable);
26
        }
27

            
28
        let decl = symbols.lookup(name.as_ref()).unwrap();
29
        let repr = symbols
30
            .get_representation(name.as_ref(), &["matrix_to_atom"])
31
            .unwrap()[0]
32
            .clone();
33

            
34
        // resolve index domains so that we can enumerate them later
35
        let dom = decl.resolved_domain().ok_or(RuleNotApplicable)?;
36
        let GroundDomain::Matrix(_, index_domains) = dom.as_ref() else {
37
            return Err(RuleNotApplicable);
38
        };
39

            
40
        let Ok(matrix_values) = repr.expression_down(symbols) else {
41
            return Err(RuleNotApplicable);
42
        };
43

            
44
        let flat_values = matrix::enumerate_indices(index_domains.clone())
45
            .map(|i| {
46
                matrix_values[&Name::Represented(Box::new((
47
                    name.as_ref().clone(),
48
                    "matrix_to_atom".into(),
49
                    i.iter().join("_").into(),
50
                )))]
51
                    .clone()
52
            })
53
            .collect_vec();
54
        return Ok(Reduction::pure(into_matrix_expr![flat_values]));
55
    }
56

            
57
    Err(RuleNotApplicable)
58
}