1
use conjure_cp::ast::matrix::safe_index_optimised;
2
use conjure_cp::ast::{Atom, Expression as Expr, Literal, Metadata, Moo, SymbolTable};
3
use conjure_cp::essence_expr;
4
use conjure_cp::rule_engine::{ApplicationError, ApplicationResult, Reduction, register_rule};
5

            
6
use ApplicationError::{DomainError, RuleNotApplicable};
7

            
8
use itertools::Itertools as _;
9

            
10
#[register_rule(("Base", 9000))]
11
131234
fn normalise_lex_gt_geq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
12
131234
    match expr {
13
        Expr::LexGt(metadata, a, b) => Ok(Reduction::pure(Expr::LexLt(
14
            metadata.clone_dirty(),
15
            b.clone(),
16
            a.clone(),
17
        ))),
18
        Expr::LexGeq(metadata, a, b) => Ok(Reduction::pure(Expr::LexLeq(
19
            metadata.clone_dirty(),
20
            b.clone(),
21
            a.clone(),
22
        ))),
23
131234
        _ => Err(RuleNotApplicable),
24
    }
25
131234
}
26

            
27
/// Turn lexicographical less-than into flat Minion constraints.
28
///
29
/// Minion does not support different-length lists being compared, so we need to truncate the longer.
30
/// Luckily, we can use the fact that [1,1,1] < [1,1,1,x] for any x, e.g. "cat" <lex "cats".
31
///
32
/// - [a,b,c,d] <=lex [e,f,g] <-> [a,b,c] <lex [d,e,f]
33
/// - [a,b,c] <lex [d,e,f,g] <-> [a,b,c] <=lex [d,e,f]
34
/// - Everything else stays the same, with the longer matrix being chopped off
35
#[register_rule(("Minion", 2000))]
36
13226
fn flatten_lex_lt_leq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
37
13226
    let (a, b) = match expr {
38
36
        Expr::LexLt(_, a, b) | Expr::LexLeq(_, a, b) => (
39
54
            Moo::unwrap_or_clone(a.clone())
40
54
                .unwrap_list()
41
54
                .ok_or(RuleNotApplicable)?,
42
24
            Moo::unwrap_or_clone(b.clone())
43
24
                .unwrap_list()
44
24
                .ok_or(RuleNotApplicable)?,
45
        ),
46
13172
        _ => return Err(RuleNotApplicable),
47
    };
48

            
49
12
    let mut atoms_a: Vec<Atom> = a
50
12
        .into_iter()
51
30
        .map(|e| e.try_into().map_err(|_| RuleNotApplicable))
52
12
        .collect::<Result<Vec<_>, ApplicationError>>()?;
53
12
    let mut atoms_b: Vec<Atom> = b
54
12
        .into_iter()
55
30
        .map(|e| e.try_into().map_err(|_| RuleNotApplicable))
56
12
        .collect::<Result<Vec<_>, ApplicationError>>()?;
57

            
58
12
    let new_expr = if atoms_a.len() == atoms_b.len() {
59
        // Same length, keep the same comparator
60
        match expr {
61
            Expr::LexLt(..) => Expr::FlatLexLt(Metadata::new(), atoms_a, atoms_b),
62
            Expr::LexLeq(..) => Expr::FlatLexLeq(Metadata::new(), atoms_a, atoms_b),
63
            _ => unreachable!(),
64
        }
65
    } else {
66
        // Different lengths; might need to use a different comparator
67
        // Doing out the 4 cases (which longer * original comparator), it can be determined from
68
        // whether the first matrix is longer
69
12
        let first_longer = atoms_a.len() > atoms_b.len();
70

            
71
12
        let min_len = atoms_a.len().min(atoms_b.len());
72
12
        atoms_a.truncate(min_len);
73
12
        atoms_b.truncate(min_len);
74

            
75
12
        match first_longer {
76
6
            true => Expr::FlatLexLt(Metadata::new(), atoms_a, atoms_b),
77
6
            false => Expr::FlatLexLeq(Metadata::new(), atoms_a, atoms_b),
78
        }
79
    };
80

            
81
12
    Ok(Reduction::pure(new_expr))
82
13226
}
83

            
84
/// Expand lexicographical lt/leq into a "recursive or" form
85
/// a <lex b ~> a[1] < b[1] \/ (a[1] = b[1] /\ (a[2] < b[2] \/ ( ... )))
86
///
87
/// If the matrices are different lengths, they can never be equal.
88
/// E.g. if |a| > |b| then a > b if they are equal for the length of b
89
///
90
/// If they are the same length, then the strictness of the comparison comes into effect.
91
///
92
/// Must be applied before matrix_to_list since this enumerates over operand indices.
93
#[register_rule(("Smt", 2001))]
94
14484
fn expand_lex_lt_leq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
95
14484
    let (a, b) = match expr {
96
30
        Expr::LexLt(_, a, b) | Expr::LexLeq(_, a, b) => (a, b),
97
14454
        _ => return Err(RuleNotApplicable),
98
    };
99

            
100
30
    let dom_a = a.domain_of().ok_or(RuleNotApplicable)?;
101
30
    let dom_b = b.domain_of().ok_or(RuleNotApplicable)?;
102

            
103
30
    let (Some((_, a_idx_domains)), Some((_, b_idx_domains))) =
104
30
        (dom_a.as_matrix_ground(), dom_b.as_matrix_ground())
105
    else {
106
        return Err(RuleNotApplicable);
107
    };
108

            
109
30
    if a_idx_domains.len() != 1 || b_idx_domains.len() != 1 {
110
        return Err(RuleNotApplicable);
111
30
    }
112

            
113
30
    let (a_idxs, b_idxs) = (
114
30
        a_idx_domains[0]
115
30
            .values()
116
30
            .map_err(|_| DomainError)?
117
30
            .collect_vec(),
118
30
        b_idx_domains[0]
119
30
            .values()
120
30
            .map_err(|_| DomainError)?
121
30
            .collect_vec(),
122
    );
123

            
124
    // If strict, then the base case where they are equal
125
30
    let or_eq = matches!(expr, Expr::LexLeq(..));
126
30
    let new_expr = lex_lt_to_recursive_or(a, b, &a_idxs, &b_idxs, or_eq);
127
30
    Ok(Reduction::pure(new_expr))
128
14484
}
129

            
130
102
fn lex_lt_to_recursive_or(
131
102
    a: &Expr,
132
102
    b: &Expr,
133
102
    a_idxs: &[Literal],
134
102
    b_idxs: &[Literal],
135
102
    allow_eq: bool,
136
102
) -> Expr {
137
102
    match (a_idxs, b_idxs) {
138
102
        ([], []) => allow_eq.into(), // Base case: same length
139
72
        ([..], []) => false.into(),  // Base case: b is shorter
140
72
        ([], [..]) => true.into(),   // Base case: a is shorter
141

            
142
72
        ([a_idx, a_tail @ ..], [b_idx, b_tail @ ..]) => {
143
72
            let a_at_idx = safe_index_optimised(a.clone(), a_idx.clone()).unwrap();
144
72
            let b_at_idx = safe_index_optimised(b.clone(), b_idx.clone()).unwrap();
145
72
            let tail = lex_lt_to_recursive_or(a, b, a_tail, b_tail, allow_eq);
146

            
147
72
            essence_expr!(r"&a_at_idx < &b_at_idx \/ (&a_at_idx = &b_at_idx /\ &tail)")
148
        }
149
    }
150
102
}