use conjure_cp::ast::Metadata;
use conjure_cp::ast::Moo;
use conjure_cp::ast::{Atom, Expression as Expr, Literal};
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
use conjure_cp::into_matrix_expr;
use conjure_cp::rule_engine::ApplicationError;
use conjure_cp::rule_engine::{
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
};

use crate::sat::boolean::{tseytin_and, tseytin_iff, tseytin_not, tseytin_or};

/// This function confirms that all of the input expressions are order SATInts, and returns vectors for each input of their bits
14
/// This function also normalizes order SATInt operands to a common value range.
15
1848
pub fn validate_order_int_operands(
16
1848
    exprs: Vec<Expr>,
17
1848
) -> Result<(Vec<Vec<Expr>>, i32, i32), ApplicationError> {
18
    // Iterate over all inputs
19
    // Check they are order and calulate a lower and upper bound
20
1848
    let mut global_min: i32 = i32::MAX;
21
1848
    let mut global_max: i32 = i32::MIN;
22

            
23
3360
    for operand in &exprs {
24
2412
        let Expr::SATInt(_, SATIntEncoding::Order, _, (local_min, local_max)) = operand else {
25
948
            return Err(RuleNotApplicable);
26
        };
27
2412
        global_min = global_min.min(*local_min);
28
2412
        global_max = global_max.max(*local_max);
29
    }
30

            
31
    // build out by iterating over each operand and expanding it to match the new bounds
32
900
    let out: Vec<Vec<Expr>> = exprs
33
900
        .into_iter()
34
1770
        .map(|expr| {
35
1770
            let Expr::SATInt(_, SATIntEncoding::Order, inner, (local_min, local_max)) = expr else {
36
                return Err(RuleNotApplicable);
37
            };
38

            
39
1770
            let Some(v) = inner.as_ref().clone().unwrap_list() else {
40
                return Err(RuleNotApplicable);
41
            };
42

            
43
            // calulcate how many trues/falses to prepend/append
44
1770
            let prefix_len = (local_min - global_min) as usize;
45
1770
            let postfix_len = (global_max - local_max) as usize;
46

            
47
1770
            let mut bits = Vec::with_capacity(v.len() + prefix_len + postfix_len);
48

            
49
            // add `true`s to start
50
1770
            bits.extend(std::iter::repeat_n(
51
1770
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
52
1770
                prefix_len,
53
            ));
54

            
55
1770
            bits.extend(v);
56

            
57
            // add `false`s to end
58
1770
            bits.extend(std::iter::repeat_n(
59
1770
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
60
1770
                postfix_len,
61
            ));
62

            
63
1770
            Ok(bits)
64
1770
        })
65
900
        .collect::<Result<_, _>>()?;
66

            
67
900
    Ok((out, global_min, global_max))
68
1848
}
69

            
70
/// Encodes a < b for order integers.
71
///
72
/// `x < y` iff `exists i . (NOT x_i AND y_i)`
73
690
fn sat_order_lt(
74
690
    a_bits: Vec<Expr>,
75
690
    b_bits: Vec<Expr>,
76
690
    clauses: &mut Vec<conjure_cp::ast::CnfClause>,
77
690
    symbols: &mut SymbolTable,
78
690
) -> Expr {
79
690
    let mut result = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
80

            
81
4686
    for (a_i, b_i) in a_bits.iter().zip(b_bits.iter()) {
82
        // (NOT a_i AND b_i)
83
4686
        let not_a_i = tseytin_not(a_i.clone(), clauses, symbols);
84
4686
        let current_term = tseytin_and(&vec![not_a_i, b_i.clone()], clauses, symbols);
85
4686

            
86
        // accumulate (NOT a_i AND b_i) into OR term
87
4686
        result = tseytin_or(&vec![result, current_term], clauses, symbols);
88
4686
    }
89
690
    result
90
690
}
91

            
92
/// Converts an integer literal to SATInt form
93
///
94
/// ```text
95
///  3
96
///  ~~>
97
///  SATInt([true;int(1..), (3, 3)])
98
///
99
/// ```
100
#[register_rule(("SAT_Order", 9500))]
101
127956
fn literal_sat_order_int(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
102
612
    let value = {
103
5862
        if let Expr::Atomic(_, Atom::Literal(Literal::Int(value))) = expr {
104
612
            *value
105
        } else {
106
127344
            return Err(RuleNotApplicable);
107
        }
108
    };
109

            
110
612
    Ok(Reduction::pure(Expr::SATInt(
111
612
        Metadata::new(),
112
612
        SATIntEncoding::Order,
113
612
        Moo::new(into_matrix_expr!(vec![Expr::Atomic(
114
612
            Metadata::new(),
115
612
            Atom::Literal(Literal::Bool(true)),
116
612
        )])),
117
612
        (value, value),
118
612
    )))
119
127956
}
120

            
121
/// Builds CNF for equality between two order SATInt bit-vectors.
122
/// This function is used by both eq and neq rules, with the output negated for neq.
123
/// Returns (expr, clauses, symbols).
124
180
fn sat_order_eq_expr(
125
180
    lhs_bits: &[Expr],
126
180
    rhs_bits: &[Expr],
127
180
    symbols: &SymbolTable,
128
180
) -> (Expr, Vec<conjure_cp::ast::CnfClause>, SymbolTable) {
129
180
    let bit_count = lhs_bits.len();
130

            
131
180
    let mut output = true.into();
132
180
    let mut new_symbols = symbols.clone();
133
180
    let mut new_clauses = vec![];
134

            
135
1026
    for i in 0..bit_count {
136
1026
        let comparison = tseytin_iff(
137
1026
            lhs_bits[i].clone(),
138
1026
            rhs_bits[i].clone(),
139
1026
            &mut new_clauses,
140
1026
            &mut new_symbols,
141
1026
        );
142
1026
        output = tseytin_and(
143
1026
            &vec![comparison, output],
144
1026
            &mut new_clauses,
145
1026
            &mut new_symbols,
146
1026
        );
147
1026
    }
148

            
149
180
    (output, new_clauses, new_symbols)
150
180
}
151

            
152
/// Converts a = expression between two order SATInts to a boolean expression in cnf
153
///
154
/// ```text
155
/// SATInt(a) = SATInt(b) ~> Bool
156
/// ```
157
#[register_rule(("SAT_Order", 9100))]
158
34818
fn eq_sat_order(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
159
34818
    let Expr::Eq(_, lhs, rhs) = expr else {
160
34626
        return Err(RuleNotApplicable);
161
    };
162

            
163
156
    let (binding, _, _) =
164
192
        validate_order_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
165
156
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
166
        return Err(RuleNotApplicable);
167
    };
168

            
169
156
    let (output, new_clauses, new_symbols) = sat_order_eq_expr(lhs_bits, rhs_bits, symbols);
170

            
171
156
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
172
34818
}
173

            
174
/// Converts a != expression between two order SATInts to a boolean expression in cnf
175
#[register_rule(("SAT_Order", 9100))]
176
34818
fn neq_sat_order(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
177
34818
    let Expr::Neq(_, lhs, rhs) = expr else {
178
34788
        return Err(RuleNotApplicable);
179
    }; // considered covered
180

            
181
24
    let (binding, _, _) =
182
30
        validate_order_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
183
24
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
184
        return Err(RuleNotApplicable); // consider covered
185
    };
186

            
187
24
    let (mut output, mut new_clauses, mut new_symbols) =
188
24
        sat_order_eq_expr(lhs_bits, rhs_bits, symbols);
189

            
190
24
    output = tseytin_not(output, &mut new_clauses, &mut new_symbols);
191

            
192
24
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
193
34818
}
194

            
195
/// Converts a </>/<=/>= expression between two order SATInts to a boolean expression in cnf
196
///
197
/// ```text
198
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
199
///
200
/// ```
201
/// Note: < and <= are rewritten by swapping operands to reuse lt logic.
202
#[register_rule(("SAT_Order", 9100))]
203
34818
fn ineq_sat_order(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
204
34818
    let (lhs, rhs, negate) = match expr {
205
        // A < B -> sat_order_lt(A, B)
206
18
        Expr::Lt(_, x, y) => (x, y, false),
207
        // A > B -> sat_order_lt(B, A)
208
        Expr::Gt(_, x, y) => (y, x, false),
209
        // A <= B -> NOT (B < A)
210
978
        Expr::Leq(_, x, y) => (y, x, true),
211
        // A >= B -> NOT (A < B)
212
594
        Expr::Geq(_, x, y) => (x, y, true),
213
33228
        _ => return Err(RuleNotApplicable),
214
    };
215

            
216
690
    let (binding, _, _) =
217
1590
        validate_order_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
218
690
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
219
        return Err(RuleNotApplicable);
220
    };
221

            
222
690
    let mut new_symbols = symbols.clone();
223
690
    let mut new_clauses = vec![];
224

            
225
690
    let mut output = sat_order_lt(
226
690
        lhs_bits.clone(),
227
690
        rhs_bits.clone(),
228
690
        &mut new_clauses,
229
690
        &mut new_symbols,
230
    );
231

            
232
690
    if negate {
233
672
        output = tseytin_not(output, &mut new_clauses, &mut new_symbols);
234
672
    }
235

            
236
690
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
237
34818
}
238

            
239
/// Converts a - expression for a SATInt to a new SATInt
240
///
241
/// ```text
242
/// -SATInt(a) ~> SATInt(b)
243
///
244
/// ```
245
#[register_rule(("SAT_Order", 9100))]
246
34818
fn neg_sat_order(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
247
34818
    let Expr::Neg(_, value) = expr else {
248
34782
        return Err(RuleNotApplicable);
249
    };
250

            
251
36
    let (binding, old_min, old_max) = validate_order_int_operands(vec![value.as_ref().clone()])?;
252
30
    let [val_bits] = binding.as_slice() else {
253
        return Err(RuleNotApplicable); // consider covered
254
    };
255

            
256
30
    let new_min = -old_max;
257
30
    let new_max = -old_min;
258

            
259
30
    let n = val_bits.len();
260
30
    let mut out: Vec<Expr> = Vec::with_capacity(n);
261

            
262
30
    let mut new_symbols = symbols.clone();
263
30
    let mut new_clauses = vec![];
264

            
265
30
    let ff = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
266
192
    for i in 0..n {
267
192
        let idx = n - i;
268
192
        let src = if idx == n {
269
30
            ff.clone()
270
        } else {
271
162
            val_bits[idx].clone()
272
        };
273

            
274
192
        let neg_bit = tseytin_not(src, &mut new_clauses, &mut new_symbols);
275
192
        out.push(neg_bit);
276
    }
277

            
278
30
    Ok(Reduction::cnf(
279
30
        Expr::SATInt(
280
30
            Metadata::new(),
281
30
            SATIntEncoding::Order,
282
30
            Moo::new(into_matrix_expr!(out)),
283
30
            (new_min, new_max),
284
30
        ),
285
30
        new_clauses,
286
30
        new_symbols,
287
30
    ))
288
34818
}