1
use conjure_cp::ast::matrix::safe_index_optimised;
2
use conjure_cp::ast::{
3
    AbstractLiteral, Expression as Expr, GroundDomain, Metadata, Moo, SymbolTable,
4
};
5
use conjure_cp::essence_expr;
6
use conjure_cp::rule_engine::{
7
    ApplicationError::{DomainError, RuleNotApplicable},
8
    ApplicationResult, Reduction, register_rule, register_rule_set,
9
};
10
use conjure_cp::settings::SolverFamily;
11
use conjure_cp::solver::adaptors::smt::TheoryConfig;
12

            
13
// Only applicable when unwrap_alldiff is enabled in the SMT adaptor
14
register_rule_set!("SmtUnwrapAllDiff", ("Base"), |f: &SolverFamily| matches!(
15
982
    f,
16
    SolverFamily::Smt(TheoryConfig {
17
        unwrap_alldiff: true,
18
        ..
19
    })
20
));
21

            
22
#[register_rule(("SmtUnwrapAllDiff", 1000))]
23
9
fn unwrap_alldiff(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
24
9
    let Expr::AllDiff(_, m) = expr else {
25
9
        return Err(RuleNotApplicable);
26
    };
27

            
28
    let dom = m.domain_of().ok_or(RuleNotApplicable)?;
29
    let Some(GroundDomain::Matrix(val_domain, index_domains)) =
30
        dom.resolve().map(Moo::unwrap_or_clone)
31
    else {
32
        return Err(RuleNotApplicable);
33
    };
34
    let [idx_domain] = index_domains.as_slice() else {
35
        return Err(DomainError);
36
    };
37

            
38
    let val_iter = val_domain.values().map_err(|_| DomainError)?;
39
    let clauses = val_iter
40
        .map(|lit| {
41
            let idx_iter = idx_domain.values().map_err(|_| DomainError)?;
42
            let occurences = idx_iter
43
                .map(|idx| {
44
                    let elem = safe_index_optimised(m.as_ref().clone(), idx).ok_or(DomainError)?;
45
                    Ok(essence_expr!("toInt(&elem = &lit)"))
46
                })
47
                .collect::<Result<Vec<_>, _>>()?;
48
            let occurences_list = Expr::AbstractLiteral(
49
                Metadata::new(),
50
                AbstractLiteral::matrix_implied_indices(occurences),
51
            );
52
            Ok(essence_expr!("sum(&occurences_list) <= 1"))
53
        })
54
        .collect::<Result<Vec<_>, _>>()?;
55
    let clauses_list = Expr::AbstractLiteral(
56
        Metadata::new(),
57
        AbstractLiteral::matrix_implied_indices(clauses),
58
    );
59

            
60
    Ok(Reduction::pure(essence_expr!("and(&clauses_list)")))
61
9
}