1
use conjure_cp::ast::matrix::safe_index_optimised;
2
use conjure_cp::ast::{Atom, Expression as Expr, Literal, Metadata, Moo, SymbolTable};
3
use conjure_cp::essence_expr;
4
use conjure_cp::rule_engine::{ApplicationError, ApplicationResult, Reduction, register_rule};
5

            
6
use ApplicationError::{DomainError, RuleNotApplicable};
7

            
8
use itertools::Itertools as _;
9

            
10
#[register_rule(("Base", 9000))]
11
1389978
fn normalise_lex_gt_geq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
12
1389978
    match expr {
13
        Expr::LexGt(metadata, a, b) => Ok(Reduction::pure(Expr::LexLt(
14
            metadata.clone_dirty(),
15
            b.clone(),
16
            a.clone(),
17
        ))),
18
        Expr::LexGeq(metadata, a, b) => Ok(Reduction::pure(Expr::LexLeq(
19
            metadata.clone_dirty(),
20
            b.clone(),
21
            a.clone(),
22
        ))),
23
1389978
        _ => Err(RuleNotApplicable),
24
    }
25
1389978
}
26

            
27
/// Turn lexicographical less-than into flat Minion constraints.
28
///
29
/// Minion does not support different-length lists being compared, so we need to truncate the longer.
30
/// Luckily, we can use the fact that [1,1,1] < [1,1,1,x] for any x, e.g. "cat" <lex "cats".
31
///
32
/// - [a,b,c,d] <=lex [e,f,g] <-> [a,b,c] <lex [d,e,f]
33
/// - [a,b,c] <lex [d,e,f,g] <-> [a,b,c] <=lex [d,e,f]
34
/// - Everything else stays the same, with the longer matrix being chopped off
35
#[register_rule(("Minion", 2000))]
36
65436
fn flatten_lex_lt_leq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
37
65436
    let (a, b) = match expr {
38
729
        Expr::LexLt(_, a, b) | Expr::LexLeq(_, a, b) => (
39
783
            Moo::unwrap_or_clone(a.clone())
40
783
                .unwrap_list()
41
783
                .ok_or(RuleNotApplicable)?,
42
108
            Moo::unwrap_or_clone(b.clone())
43
108
                .unwrap_list()
44
108
                .ok_or(RuleNotApplicable)?,
45
        ),
46
64653
        _ => return Err(RuleNotApplicable),
47
    };
48

            
49
54
    let mut atoms_a: Vec<Atom> = a
50
54
        .into_iter()
51
144
        .map(|e| e.try_into().map_err(|_| RuleNotApplicable))
52
54
        .collect::<Result<Vec<_>, ApplicationError>>()?;
53
54
    let mut atoms_b: Vec<Atom> = b
54
54
        .into_iter()
55
144
        .map(|e| e.try_into().map_err(|_| RuleNotApplicable))
56
54
        .collect::<Result<Vec<_>, ApplicationError>>()?;
57

            
58
54
    let new_expr = if atoms_a.len() == atoms_b.len() {
59
        // Same length, keep the same comparator
60
18
        match expr {
61
            Expr::LexLt(..) => Expr::FlatLexLt(Metadata::new(), atoms_a, atoms_b),
62
18
            Expr::LexLeq(..) => Expr::FlatLexLeq(Metadata::new(), atoms_a, atoms_b),
63
            _ => unreachable!(),
64
        }
65
    } else {
66
        // Different lengths; might need to use a different comparator
67
        // Doing out the 4 cases (which longer * original comparator), it can be determined from
68
        // whether the first matrix is longer
69
36
        let first_longer = atoms_a.len() > atoms_b.len();
70

            
71
36
        let min_len = atoms_a.len().min(atoms_b.len());
72
36
        atoms_a.truncate(min_len);
73
36
        atoms_b.truncate(min_len);
74

            
75
36
        match first_longer {
76
18
            true => Expr::FlatLexLt(Metadata::new(), atoms_a, atoms_b),
77
18
            false => Expr::FlatLexLeq(Metadata::new(), atoms_a, atoms_b),
78
        }
79
    };
80

            
81
54
    Ok(Reduction::pure(new_expr))
82
65436
}
83

            
84
/// Expand lexicographical lt/leq into a "recursive or" form
85
/// a <lex b ~> a[1] < b[1] \/ (a[1] = b[1] /\ (a[2] < b[2] \/ ( ... )))
86
///
87
/// If the matrices are different lengths, they can never be equal.
88
/// E.g. if |a| > |b| then a > b if they are equal for the length of b
89
///
90
/// If they are the same length, then the strictness of the comparison comes into effect.
91
///
92
/// Must be applied before matrix_to_list since this enumerates over operand indices.
93
#[register_rule(("Smt", 2001))]
94
59256
fn expand_lex_lt_leq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
95
59256
    let (a, b) = match expr {
96
90
        Expr::LexLt(_, a, b) | Expr::LexLeq(_, a, b) => (a, b),
97
59166
        _ => return Err(RuleNotApplicable),
98
    };
99

            
100
90
    let dom_a = a.domain_of().ok_or(RuleNotApplicable)?;
101
90
    let dom_b = b.domain_of().ok_or(RuleNotApplicable)?;
102

            
103
90
    let (Some((_, a_idx_domains)), Some((_, b_idx_domains))) =
104
90
        (dom_a.as_matrix_ground(), dom_b.as_matrix_ground())
105
    else {
106
        return Err(RuleNotApplicable);
107
    };
108

            
109
90
    if a_idx_domains.len() != 1 || b_idx_domains.len() != 1 {
110
        return Err(RuleNotApplicable);
111
90
    }
112

            
113
90
    let (a_idxs, b_idxs) = (
114
90
        a_idx_domains[0]
115
90
            .values()
116
90
            .map_err(|_| DomainError)?
117
90
            .collect_vec(),
118
90
        b_idx_domains[0]
119
90
            .values()
120
90
            .map_err(|_| DomainError)?
121
90
            .collect_vec(),
122
    );
123

            
124
    // If strict, then the base case where they are equal
125
90
    let or_eq = matches!(expr, Expr::LexLeq(..));
126
90
    let new_expr = lex_lt_to_recursive_or(a, b, &a_idxs, &b_idxs, or_eq);
127
90
    Ok(Reduction::pure(new_expr))
128
59256
}
129

            
130
306
fn lex_lt_to_recursive_or(
131
306
    a: &Expr,
132
306
    b: &Expr,
133
306
    a_idxs: &[Literal],
134
306
    b_idxs: &[Literal],
135
306
    allow_eq: bool,
136
306
) -> Expr {
137
306
    match (a_idxs, b_idxs) {
138
306
        ([], []) => allow_eq.into(), // Base case: same length
139
216
        ([..], []) => false.into(),  // Base case: b is shorter
140
216
        ([], [..]) => true.into(),   // Base case: a is shorter
141

            
142
216
        ([a_idx, a_tail @ ..], [b_idx, b_tail @ ..]) => {
143
216
            let a_at_idx = safe_index_optimised(a.clone(), a_idx.clone()).unwrap();
144
216
            let b_at_idx = safe_index_optimised(b.clone(), b_idx.clone()).unwrap();
145
216
            let tail = lex_lt_to_recursive_or(a, b, a_tail, b_tail, allow_eq);
146

            
147
216
            essence_expr!(r"&a_at_idx < &b_at_idx \/ (&a_at_idx = &b_at_idx /\ &tail)")
148
        }
149
    }
150
306
}