1
use conjure_cp::ast::{Atom, Expression as Expr, Literal, Metadata, Moo, SymbolTable};
2
use conjure_cp::essence_expr;
3
use conjure_cp::rule_engine::{ApplicationError, ApplicationResult, Reduction, register_rule};
4

            
5
use ApplicationError::{DomainError, RuleNotApplicable};
6

            
7
use itertools::Itertools as _;
8

            
9
#[register_rule(("Base", 9000))]
10
fn normalise_lex_gt_geq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
11
    match expr {
12
        Expr::LexGt(metadata, a, b) => Ok(Reduction::pure(Expr::LexLt(
13
            metadata.clone_dirty(),
14
            b.clone(),
15
            a.clone(),
16
        ))),
17
        Expr::LexGeq(metadata, a, b) => Ok(Reduction::pure(Expr::LexLeq(
18
            metadata.clone_dirty(),
19
            b.clone(),
20
            a.clone(),
21
        ))),
22
        _ => Err(RuleNotApplicable),
23
    }
24
}
25

            
26
/// Turn lexicographical less-than into flat Minion constraints.
27
///
28
/// Minion does not support different-length lists being compared, so we need to truncate the longer.
29
/// Luckily, we can use the fact that [1,1,1] < [1,1,1,x] for any x, e.g. "cat" <lex "cats".
30
///
31
/// - [a,b,c,d] <=lex [e,f,g] <-> [a,b,c] <lex [d,e,f]
32
/// - [a,b,c] <lex [d,e,f,g] <-> [a,b,c] <=lex [d,e,f]
33
/// - Everything else stays the same, with the longer matrix being chopped off
34
#[register_rule(("Minion", 2000))]
35
fn flatten_lex_lt_leq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
36
    let (a, b) = match expr {
37
        Expr::LexLt(_, a, b) | Expr::LexLeq(_, a, b) => (
38
            Moo::unwrap_or_clone(a.clone())
39
                .unwrap_list()
40
                .ok_or(RuleNotApplicable)?,
41
            Moo::unwrap_or_clone(b.clone())
42
                .unwrap_list()
43
                .ok_or(RuleNotApplicable)?,
44
        ),
45
        _ => return Err(RuleNotApplicable),
46
    };
47

            
48
    let mut atoms_a: Vec<Atom> = a
49
        .into_iter()
50
        .map(|e| e.try_into().map_err(|_| RuleNotApplicable))
51
        .collect::<Result<Vec<_>, ApplicationError>>()?;
52
    let mut atoms_b: Vec<Atom> = b
53
        .into_iter()
54
        .map(|e| e.try_into().map_err(|_| RuleNotApplicable))
55
        .collect::<Result<Vec<_>, ApplicationError>>()?;
56

            
57
    let new_expr = if atoms_a.len() == atoms_b.len() {
58
        // Same length, keep the same comparator
59
        match expr {
60
            Expr::LexLt(..) => Expr::FlatLexLt(Metadata::new(), atoms_a, atoms_b),
61
            Expr::LexLeq(..) => Expr::FlatLexLeq(Metadata::new(), atoms_a, atoms_b),
62
            _ => unreachable!(),
63
        }
64
    } else {
65
        // Different lengths; might need to use a different comparator
66
        // Doing out the 4 cases (which longer * original comparator), it can be determined from
67
        // whether the first matrix is longer
68
        let first_longer = atoms_a.len() > atoms_b.len();
69

            
70
        let min_len = atoms_a.len().min(atoms_b.len());
71
        atoms_a.truncate(min_len);
72
        atoms_b.truncate(min_len);
73

            
74
        match first_longer {
75
            true => Expr::FlatLexLt(Metadata::new(), atoms_a, atoms_b),
76
            false => Expr::FlatLexLeq(Metadata::new(), atoms_a, atoms_b),
77
        }
78
    };
79

            
80
    Ok(Reduction::pure(new_expr))
81
}
82

            
83
/// Expand lexicographical lt/leq into a "recursive or" form
84
/// a <lex b ~> a[1] < b[1] \/ (a[1] = b[1] /\ (a[2] < b[2] \/ ( ... )))
85
///
86
/// If the matrices are different lengths, they can never be equal.
87
/// E.g. if |a| > |b| then a > b if they are equal for the length of b
88
///
89
/// If they are the same length, then the strictness of the comparison comes into effect.
90
///
91
/// Must be applied before matrix_to_list since this enumerates over operand indices.
92
#[register_rule(("Smt", 2001))]
93
fn expand_lex_lt_leq(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
94
    let (a, b) = match expr {
95
        Expr::LexLt(_, a, b) | Expr::LexLeq(_, a, b) => (a, b),
96
        _ => return Err(RuleNotApplicable),
97
    };
98

            
99
    let dom_a = a.domain_of().ok_or(RuleNotApplicable)?;
100
    let dom_b = b.domain_of().ok_or(RuleNotApplicable)?;
101

            
102
    let (Some((_, a_idx_domains)), Some((_, b_idx_domains))) =
103
        (dom_a.as_matrix_ground(), dom_b.as_matrix_ground())
104
    else {
105
        return Err(RuleNotApplicable);
106
    };
107

            
108
    if a_idx_domains.len() != 1 || b_idx_domains.len() != 1 {
109
        return Err(RuleNotApplicable);
110
    }
111

            
112
    let (a_idxs, b_idxs) = (
113
        a_idx_domains[0]
114
            .values()
115
            .map_err(|_| DomainError)?
116
            .collect_vec(),
117
        b_idx_domains[0]
118
            .values()
119
            .map_err(|_| DomainError)?
120
            .collect_vec(),
121
    );
122

            
123
    // If strict, then the base case where they are equal
124
    let or_eq = matches!(expr, Expr::LexLeq(..));
125
    let new_expr = lex_lt_to_recursive_or(a, b, &a_idxs, &b_idxs, or_eq);
126
    Ok(Reduction::pure(new_expr))
127
}
128

            
129
fn lex_lt_to_recursive_or(
130
    a: &Expr,
131
    b: &Expr,
132
    a_idxs: &[Literal],
133
    b_idxs: &[Literal],
134
    allow_eq: bool,
135
) -> Expr {
136
    match (a_idxs, b_idxs) {
137
        ([], []) => allow_eq.into(), // Base case: same length
138
        ([..], []) => false.into(),  // Base case: b is shorter
139
        ([], [..]) => true.into(),   // Base case: a is shorter
140

            
141
        ([a_idx, a_tail @ ..], [b_idx, b_tail @ ..]) => {
142
            let (a_at_idx, b_at_idx) = (
143
                Expr::SafeIndex(
144
                    Metadata::new(),
145
                    Moo::new(a.clone()),
146
                    vec![a_idx.clone().into()],
147
                ),
148
                Expr::SafeIndex(
149
                    Metadata::new(),
150
                    Moo::new(b.clone()),
151
                    vec![b_idx.clone().into()],
152
                ),
153
            );
154

            
155
            let tail = lex_lt_to_recursive_or(a, b, a_tail, b_tail, allow_eq);
156
            essence_expr!(r"&a_at_idx < &b_at_idx \/ (&a_at_idx = &b_at_idx /\ &tail)")
157
        }
158
    }
159
}