1
use conjure_cp::ast::{Atom, Expression as Expr, Literal};
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
6
};
7

            
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use super::boolean::{tseytin_and, tseytin_iff, tseytin_not, tseytin_or, tseytin_xor};
13

            
14
use conjure_cp::ast::CnfClause;
15
/// Converts an integer literal to SATInt form
16
///
17
/// ```text
18
///  3
19
///  ~~>
20
///  SATInt([true;int(1..), (3, 3)])
21
///
22
/// ```
23
#[register_rule(("SAT_Direct", 9500))]
24
608004
fn literal_sat_direct_int(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
25
2061
    let value = {
26
18873
        if let Expr::Atomic(_, Atom::Literal(Literal::Int(value))) = expr {
27
2061
            *value
28
        } else {
29
605943
            return Err(RuleNotApplicable);
30
        }
31
    };
32

            
33
2061
    Ok(Reduction::pure(Expr::SATInt(
34
2061
        Metadata::new(),
35
2061
        SATIntEncoding::Direct,
36
2061
        Moo::new(into_matrix_expr!(vec![Expr::Atomic(
37
2061
            Metadata::new(),
38
2061
            Atom::Literal(Literal::Bool(true)),
39
2061
        )])),
40
2061
        (value, value),
41
2061
    )))
42
608004
}
43

            
44
/// This function confirms that all of the input expressions are direct SATInts, and returns vectors for each input of their bits
45
/// This function also normalizes direct SATInt operands to a common value range by zero-padding.
46
29442
pub fn validate_direct_int_operands(
47
29442
    exprs: Vec<Expr>,
48
29442
) -> Result<(Vec<Vec<Expr>>, i32, i32), ApplicationError> {
49
    // TODO: In the future it may be possible to optimize operations between integers with different bit sizes
50
    // Collect inner bit vectors from each SATInt
51

            
52
    // Iterate over all inputs
53
    // Check they are direct and calulate a lower and upper bound
54
29442
    let mut global_min: i32 = i32::MAX;
55
29442
    let mut global_max: i32 = i32::MIN;
56

            
57
33375
    for operand in &exprs {
58
28866
        let Expr::SATInt(_, SATIntEncoding::Direct, _, (local_min, local_max)) = operand else {
59
26814
            return Err(RuleNotApplicable);
60
        };
61
6561
        global_min = global_min.min(*local_min);
62
6561
        global_max = global_max.max(*local_max);
63
    }
64

            
65
    // build out by iterating over each operand and expanding it to match the new bounds
66

            
67
2628
    let out: Vec<Vec<Expr>> = exprs
68
2628
        .into_iter()
69
5202
        .map(|expr| {
70
5202
            let Expr::SATInt(_, SATIntEncoding::Direct, inner, (local_min, local_max)) = expr
71
            else {
72
                return Err(RuleNotApplicable);
73
            };
74

            
75
5202
            let Some(v) = inner.as_ref().clone().unwrap_list() else {
76
                return Err(RuleNotApplicable);
77
            };
78

            
79
            // calulcate how many zeroes to prepend/append
80
5202
            let prefix_len = (local_min - global_min) as usize;
81
5202
            let postfix_len = (global_max - local_max) as usize;
82

            
83
5202
            let mut bits = Vec::with_capacity(v.len() + prefix_len + postfix_len);
84

            
85
            // add 0s to start
86
5202
            bits.extend(std::iter::repeat_n(
87
5202
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
88
5202
                prefix_len,
89
            ));
90

            
91
5202
            bits.extend(v);
92

            
93
            // add 0s to end
94
5202
            bits.extend(std::iter::repeat_n(
95
5202
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
96
5202
                postfix_len,
97
            ));
98

            
99
5202
            Ok(bits)
100
5202
        })
101
2628
        .collect::<Result<_, _>>()?;
102

            
103
2628
    Ok((out, global_min, global_max))
104
29442
}
105

            
106
/// Converts a = expression between two direct SATInts to a boolean expression in cnf
107
///
108
/// ```text
109
/// SATInt(a) = SATInt(b) ~> Bool
110
/// ```
111
/// NOTE: This rule reduces to AND_i (a[i] ≡ b[i]) and does not enforce one-hotness.
112
#[register_rule(("SAT_Direct", 9100))]
113
199890
fn eq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
114
    // TODO: this could be optimized by just going over the sections of both vectors where the ranges intersect
115
    // this does require enforcing structure separately
116
199890
    let Expr::Eq(_, lhs, rhs) = expr else {
117
197442
        return Err(RuleNotApplicable);
118
    };
119

            
120
450
    let (binding, _, _) =
121
2448
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
122
450
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
123
        return Err(RuleNotApplicable);
124
    };
125

            
126
450
    let bit_count = lhs_bits.len();
127

            
128
450
    let mut output = true.into();
129
450
    let mut new_symbols = symbols.clone();
130
450
    let mut new_clauses = vec![];
131
    let mut comparison;
132

            
133
9639
    for i in 0..bit_count {
134
9639
        comparison = tseytin_iff(
135
9639
            lhs_bits[i].clone(),
136
9639
            rhs_bits[i].clone(),
137
9639
            &mut new_clauses,
138
9639
            &mut new_symbols,
139
9639
        );
140
9639
        output = tseytin_and(
141
9639
            &vec![comparison, output],
142
9639
            &mut new_clauses,
143
9639
            &mut new_symbols,
144
9639
        );
145
9639
    }
146

            
147
450
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
148
199890
}
149

            
150
/// Converts a != expression between two direct SATInts to a boolean expression in cnf
151
///
152
/// ```text
153
/// SATInt(a) != SATInt(b) ~> Bool
154
///
155
/// ```
156
///
157
/// True iff at least one value position differs.
158
#[register_rule(("SAT_Direct", 9100))]
159
199890
fn neq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
160
199890
    let Expr::Neq(_, lhs, rhs) = expr else {
161
199179
        return Err(RuleNotApplicable);
162
    };
163

            
164
225
    let (binding, _, _) =
165
711
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
166
225
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
167
        return Err(RuleNotApplicable);
168
    };
169

            
170
225
    let bit_count = lhs_bits.len();
171

            
172
225
    let mut output = false.into();
173
225
    let mut new_symbols = symbols.clone();
174
225
    let mut new_clauses = vec![];
175
    let mut comparison;
176

            
177
1737
    for i in 0..bit_count {
178
1737
        comparison = tseytin_xor(
179
1737
            lhs_bits[i].clone(),
180
1737
            rhs_bits[i].clone(),
181
1737
            &mut new_clauses,
182
1737
            &mut new_symbols,
183
1737
        );
184
1737
        output = tseytin_or(
185
1737
            &vec![comparison, output],
186
1737
            &mut new_clauses,
187
1737
            &mut new_symbols,
188
1737
        );
189
1737
    }
190

            
191
225
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
192
199890
}
193

            
194
/// Converts a </>/<=/>= expression between two direct SATInts to a boolean expression in cnf
195
///
196
/// ```text
197
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
198
///
199
/// ```
200
/// Note: < and <= are rewritten by swapping operands to reuse lt logic.
201
#[register_rule(("SAT", 9100))]
202
499503
fn ineq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
203
499503
    let (lhs, rhs, negate) = match expr {
204
        // A < B -> sat_direct_lt(A, B)
205
423
        Expr::Lt(_, x, y) => (x, y, false),
206
        // A > B -> sat_direct_lt(B, A)
207
558
        Expr::Gt(_, x, y) => (y, x, false),
208
        // A <= B -> NOT (B < A)
209
13776
        Expr::Leq(_, x, y) => (y, x, true),
210
        // A >= B -> NOT (A < B)
211
11463
        Expr::Geq(_, x, y) => (x, y, true),
212
473283
        _ => return Err(RuleNotApplicable),
213
    };
214

            
215
1899
    let (binding, _, _) =
216
26220
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
217
1899
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
218
        return Err(RuleNotApplicable);
219
    };
220

            
221
1899
    let mut new_symbols = symbols.clone();
222
1899
    let mut new_clauses = vec![];
223

            
224
1899
    let mut output = sat_direct_lt(
225
1899
        lhs_bits.clone(),
226
1899
        rhs_bits.clone(),
227
1899
        &mut new_clauses,
228
1899
        &mut new_symbols,
229
    );
230

            
231
1899
    if negate {
232
1818
        output = tseytin_not(output, &mut new_clauses, &mut new_symbols);
233
1818
    }
234

            
235
1899
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
236
499503
}
237

            
238
/// Encodes a < b for one-hot direct integers using prefix OR logic.
239
1899
fn sat_direct_lt(
240
1899
    a: Vec<Expr>,
241
1899
    b: Vec<Expr>,
242
1899
    clauses: &mut Vec<CnfClause>,
243
1899
    symbols: &mut SymbolTable,
244
1899
) -> Expr {
245
1899
    let mut b_or = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
246
1899
    let mut cum_result = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
247

            
248
19647
    for (a_i, b_i) in a.iter().zip(b.iter()) {
249
        // b_or is prefix_or of b up to index i: B_i = b_0 | ... | b_i
250
19647
        b_or = tseytin_or(&vec![b_or, b_i.clone()], clauses, symbols);
251
19647

            
252
        // a < b if there exists i such that a=i and b > i.
253
        // b > i is equivalent to NOT(B_i) assuming one-hotness.
254
19647
        let not_b_or = tseytin_not(b_or.clone(), clauses, symbols);
255
19647
        let a_i_and_not_b_i = tseytin_and(&vec![a_i.clone(), not_b_or], clauses, symbols);
256
19647

            
257
19647
        cum_result = tseytin_or(&vec![cum_result, a_i_and_not_b_i], clauses, symbols);
258
19647
    }
259

            
260
1899
    cum_result
261
1899
}
262

            
263
/// Converts a - expression for a SATInt to a new SATInt
264
///
265
/// ```text
266
/// -SATInt(a) ~> SATInt(b)
267
///
268
/// ```
269
#[register_rule(("SAT_Direct", 9100))]
270
199890
fn neg_sat_direct(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
271
199890
    let Expr::Neg(_, value) = expr else {
272
199827
        return Err(RuleNotApplicable);
273
    };
274

            
275
63
    let (binding, old_min, old_max) = validate_direct_int_operands(vec![value.as_ref().clone()])?;
276
54
    let [val_bits] = binding.as_slice() else {
277
        return Err(RuleNotApplicable);
278
    };
279

            
280
54
    let new_min = -old_max;
281
54
    let new_max = -old_min;
282

            
283
54
    let mut out = val_bits.clone();
284
54
    out.reverse();
285

            
286
54
    Ok(Reduction::pure(Expr::SATInt(
287
54
        Metadata::new(),
288
54
        SATIntEncoding::Direct,
289
54
        Moo::new(into_matrix_expr!(out)),
290
54
        (new_min, new_max),
291
54
    )))
292
199890
}
293

            
294
19764
/// Integer division rounding toward negative infinity.
///
/// Rust's `/` truncates toward zero. When the remainder is non-zero and its sign differs
/// from the divisor's, the truncated quotient is one too large, so it is decremented.
///
/// # Panics
/// Panics when `b == 0`, or on `i32::MIN / -1` overflow, exactly like `/`.
fn floor_div(a: i32, b: i32) -> i32 {
    let truncated = a / b;
    let remainder = a % b;
    let signs_disagree = (remainder < 0) != (b < 0);
    if remainder != 0 && signs_disagree {
        truncated - 1
    } else {
        truncated
    }
}
302

            
303
/// Converts a / expression between two direct SATInts to a new direct SATInt
304
/// using the "lookup table" method.
305
///
306
/// ```text
307
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
308
///
309
/// ```
310
#[register_rule(("SAT_Direct", 9100))]
311
199890
fn safediv_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
312
199890
    let Expr::SafeDiv(_, numer_expr, denom_expr) = expr else {
313
199755
        return Err(RuleNotApplicable);
314
    };
315

            
316
135
    let Expr::SATInt(_, SATIntEncoding::Direct, numer_inner, (numer_min, numer_max)) =
317
135
        numer_expr.as_ref()
318
    else {
319
        return Err(RuleNotApplicable);
320
    };
321
135
    let Some(numer_bits) = numer_inner.as_ref().clone().unwrap_list() else {
322
        return Err(RuleNotApplicable);
323
    };
324

            
325
135
    let Expr::SATInt(_, SATIntEncoding::Direct, denom_inner, (denom_min, denom_max)) =
326
135
        denom_expr.as_ref()
327
    else {
328
        return Err(RuleNotApplicable);
329
    };
330

            
331
135
    let Some(denom_bits) = denom_inner.as_ref().clone().unwrap_list() else {
332
        return Err(RuleNotApplicable);
333
    };
334

            
335
135
    let mut quot_min = i32::MAX;
336
135
    let mut quot_max = i32::MIN;
337

            
338
1368
    for i in *numer_min..=*numer_max {
339
20934
        for j in *denom_min..=*denom_max {
340
20934
            let k = if j == 0 { 0 } else { i / j };
341
20934
            quot_min = quot_min.min(k);
342
20934
            quot_max = quot_max.max(k);
343
        }
344
    }
345

            
346
135
    let mut new_symbols = symbols.clone();
347
135
    let mut quot_bits = Vec::new();
348

            
349
    // generate boolean variables for all possible quotients
350
1503
    for _ in quot_min..=quot_max {
351
1503
        let decl = new_symbols.gensym(&conjure_cp::ast::Domain::bool());
352
1503
        quot_bits.push(Expr::Atomic(
353
1503
            Metadata::new(),
354
1503
            Atom::Reference(conjure_cp::ast::Reference::new(decl)),
355
1503
        ));
356
1503
    }
357

            
358
135
    let mut new_clauses = vec![];
359

            
360
    // generate the lookup table clauses: (n_i AND d_j) => q_k
361
1368
    for i in *numer_min..=*numer_max {
362
1368
        let numer_bit = &numer_bits[(i - numer_min) as usize];
363
20934
        for j in *denom_min..=*denom_max {
364
20934
            let denom_bit = &denom_bits[(j - denom_min) as usize];
365

            
366
20934
            let k = if j == 0 { 0 } else { floor_div(i, j) };
367

            
368
20934
            let quot_bit = &quot_bits[(k - quot_min) as usize];
369

            
370
20934
            new_clauses.push(CnfClause::new(vec![
371
20934
                Expr::Not(Metadata::new(), Moo::new(numer_bit.clone())),
372
20934
                Expr::Not(Metadata::new(), Moo::new(denom_bit.clone())),
373
20934
                quot_bit.clone(),
374
            ]));
375
        }
376
    }
377

            
378
    // the quotient cannot take more than one value simultaneously.
379
1503
    for a in 0..quot_bits.len() {
380
14562
        for b in (a + 1)..quot_bits.len() {
381
14562
            new_clauses.push(CnfClause::new(vec![
382
14562
                Expr::Not(Metadata::new(), Moo::new(quot_bits[a].clone())),
383
14562
                Expr::Not(Metadata::new(), Moo::new(quot_bits[b].clone())),
384
14562
            ]));
385
14562
        }
386
    }
387

            
388
135
    let quot_int = Expr::SATInt(
389
135
        Metadata::new(),
390
135
        SATIntEncoding::Direct,
391
135
        Moo::new(into_matrix_expr!(quot_bits)),
392
135
        (quot_min, quot_max),
393
135
    );
394

            
395
135
    Ok(Reduction::cnf(quot_int, new_clauses, new_symbols))
396
199890
}