1
use conjure_cp::ast::{Atom, Expression as Expr, Literal};
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
6
};
7

            
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use super::boolean::{tseytin_and, tseytin_iff, tseytin_not, tseytin_or, tseytin_xor};
13

            
14
use conjure_cp::ast::CnfClause;
15
/// Converts an integer literal to SATInt form
16
///
17
/// ```text
18
///  3
19
///  ~~>
20
///  SATInt([true;int(1..), (3, 3)])
21
///
22
/// ```
23
#[register_rule("SAT_Direct", 9500, [Atomic])]
24
1374903
fn literal_sat_direct_int(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
25
5070
    let value = {
26
52245
        if let Expr::Atomic(_, Atom::Literal(Literal::Int(value))) = expr {
27
5070
            *value
28
        } else {
29
1369833
            return Err(RuleNotApplicable);
30
        }
31
    };
32

            
33
5070
    Ok(Reduction::pure(Expr::SATInt(
34
5070
        Metadata::new(),
35
5070
        SATIntEncoding::Direct,
36
5070
        Moo::new(into_matrix_expr!(vec![Expr::Atomic(
37
5070
            Metadata::new(),
38
5070
            Atom::Literal(Literal::Bool(true)),
39
5070
        )])),
40
5070
        (value, value),
41
5070
    )))
42
1374903
}
43

            
44
/// This function confirms that all of the input expressions are direct SATInts, and returns vectors for each input of their bits
45
/// This function also normalizes direct SATInt operands to a common value range by zero-padding.
46
57042
pub fn validate_direct_int_operands(
47
57042
    exprs: Vec<Expr>,
48
57042
) -> Result<(Vec<Vec<Expr>>, i32, i32), ApplicationError> {
49
    // TODO: In the future it may be possible to optimize operations between integers with different bit sizes
50
    // Collect inner bit vectors from each SATInt
51

            
52
    // Iterate over all inputs
53
    // Check they are direct and calulate a lower and upper bound
54
57042
    let mut global_min: i32 = i32::MAX;
55
57042
    let mut global_max: i32 = i32::MIN;
56

            
57
68712
    for operand in &exprs {
58
53796
        let Expr::SATInt(_, SATIntEncoding::Direct, _, (local_min, local_max)) = operand else {
59
50340
            return Err(RuleNotApplicable);
60
        };
61
18372
        global_min = global_min.min(*local_min);
62
18372
        global_max = global_max.max(*local_max);
63
    }
64

            
65
    // build out by iterating over each operand and expanding it to match the new bounds
66

            
67
6702
    let out: Vec<Vec<Expr>> = exprs
68
6702
        .into_iter()
69
13290
        .map(|expr| {
70
13290
            let Expr::SATInt(_, SATIntEncoding::Direct, inner, (local_min, local_max)) = expr
71
            else {
72
                return Err(RuleNotApplicable);
73
            };
74

            
75
13290
            let Some(v) = inner.as_ref().clone().unwrap_list() else {
76
                return Err(RuleNotApplicable);
77
            };
78

            
79
            // calulcate how many zeroes to prepend/append
80
13290
            let prefix_len = (local_min - global_min) as usize;
81
13290
            let postfix_len = (global_max - local_max) as usize;
82

            
83
13290
            let mut bits = Vec::with_capacity(v.len() + prefix_len + postfix_len);
84

            
85
            // add 0s to start
86
13290
            bits.extend(std::iter::repeat_n(
87
13290
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
88
13290
                prefix_len,
89
            ));
90

            
91
13290
            bits.extend(v);
92

            
93
            // add 0s to end
94
13290
            bits.extend(std::iter::repeat_n(
95
13290
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
96
13290
                postfix_len,
97
            ));
98

            
99
13290
            Ok(bits)
100
13290
        })
101
6702
        .collect::<Result<_, _>>()?;
102

            
103
6702
    Ok((out, global_min, global_max))
104
57042
}
105

            
106
/// Converts a = expression between two direct SATInts to a boolean expression in cnf
107
///
108
/// ```text
109
/// SATInt(a) = SATInt(b) ~> Bool
110
/// ```
111
/// NOTE: This rule reduces to AND_i (a[i] ≡ b[i]) and does not enforce one-hotness.
112
#[register_rule("SAT_Direct", 9100, [Eq])]
113
489816
fn eq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
114
    // TODO: this could be optimized by just going over the sections of both vectors where the ranges intersect
115
    // this does require enforcing structure separately
116
489816
    let Expr::Eq(_, lhs, rhs) = expr else {
117
481740
        return Err(RuleNotApplicable);
118
    };
119

            
120
990
    let (binding, _, _) =
121
8076
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
122
990
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
123
        return Err(RuleNotApplicable);
124
    };
125

            
126
990
    let bit_count = lhs_bits.len();
127

            
128
990
    let mut output = true.into();
129
990
    let mut new_symbols = symbols.clone();
130
990
    let mut new_clauses = vec![];
131
    let mut comparison;
132

            
133
18786
    for i in 0..bit_count {
134
18786
        comparison = tseytin_iff(
135
18786
            lhs_bits[i].clone(),
136
18786
            rhs_bits[i].clone(),
137
18786
            &mut new_clauses,
138
18786
            &mut new_symbols,
139
18786
        );
140
18786
        output = tseytin_and(
141
18786
            &vec![comparison, output],
142
18786
            &mut new_clauses,
143
18786
            &mut new_symbols,
144
18786
        );
145
18786
    }
146

            
147
990
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
148
489816
}
149

            
150
/// Converts a != expression between two direct SATInts to a boolean expression in cnf
151
///
152
/// ```text
153
/// SATInt(a) != SATInt(b) ~> Bool
154
///
155
/// ```
156
///
157
/// True iff at least one value position differs.
158
#[register_rule("SAT_Direct", 9100, [Neq])]
159
489816
fn neq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
160
489816
    let Expr::Neq(_, lhs, rhs) = expr else {
161
488862
        return Err(RuleNotApplicable);
162
    };
163

            
164
306
    let (binding, _, _) =
165
954
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
166
306
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
167
        return Err(RuleNotApplicable);
168
    };
169

            
170
306
    let bit_count = lhs_bits.len();
171

            
172
306
    let mut output = false.into();
173
306
    let mut new_symbols = symbols.clone();
174
306
    let mut new_clauses = vec![];
175
    let mut comparison;
176

            
177
2142
    for i in 0..bit_count {
178
2142
        comparison = tseytin_xor(
179
2142
            lhs_bits[i].clone(),
180
2142
            rhs_bits[i].clone(),
181
2142
            &mut new_clauses,
182
2142
            &mut new_symbols,
183
2142
        );
184
2142
        output = tseytin_or(
185
2142
            &vec![comparison, output],
186
2142
            &mut new_clauses,
187
2142
            &mut new_symbols,
188
2142
        );
189
2142
    }
190

            
191
306
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
192
489816
}
193

            
194
/// Converts a </>/<=/>= expression between two direct SATInts to a boolean expression in cnf
195
///
196
/// ```text
197
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
198
///
199
/// ```
200
/// Note: < and <= are rewritten by swapping operands to reuse lt logic.
201
#[register_rule("SAT", 9100, [Lt, Gt, Leq, Geq])]
202
973779
fn ineq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
203
973779
    let (lhs, rhs, negate) = match expr {
204
        // A < B -> sat_direct_lt(A, B)
205
576
        Expr::Lt(_, x, y) => (x, y, false),
206
        // A > B -> sat_direct_lt(B, A)
207
978
        Expr::Gt(_, x, y) => (y, x, false),
208
        // A <= B -> NOT (B < A)
209
23538
        Expr::Leq(_, x, y) => (y, x, true),
210
        // A >= B -> NOT (A < B)
211
20778
        Expr::Geq(_, x, y) => (x, y, true),
212
927909
        _ => return Err(RuleNotApplicable),
213
    };
214

            
215
4704
    let (binding, _, _) =
216
45870
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
217
4704
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
218
        return Err(RuleNotApplicable);
219
    };
220

            
221
4704
    let mut new_symbols = symbols.clone();
222
4704
    let mut new_clauses = vec![];
223

            
224
4704
    let mut output = sat_direct_lt(
225
4704
        lhs_bits.clone(),
226
4704
        rhs_bits.clone(),
227
4704
        &mut new_clauses,
228
4704
        &mut new_symbols,
229
    );
230

            
231
4704
    if negate {
232
4560
        output = tseytin_not(output, &mut new_clauses, &mut new_symbols);
233
4560
    }
234

            
235
4704
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
236
973779
}
237

            
238
/// Encodes a < b for one-hot direct integers using prefix OR logic.
239
4704
fn sat_direct_lt(
240
4704
    a: Vec<Expr>,
241
4704
    b: Vec<Expr>,
242
4704
    clauses: &mut Vec<CnfClause>,
243
4704
    symbols: &mut SymbolTable,
244
4704
) -> Expr {
245
4704
    let mut b_or = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
246
4704
    let mut cum_result = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
247

            
248
37578
    for (a_i, b_i) in a.iter().zip(b.iter()) {
249
        // b_or is prefix_or of b up to index i: B_i = b_0 | ... | b_i
250
37578
        b_or = tseytin_or(&vec![b_or, b_i.clone()], clauses, symbols);
251
37578

            
252
        // a < b if there exists i such that a=i and b > i.
253
        // b > i is equivalent to NOT(B_i) assuming one-hotness.
254
37578
        let not_b_or = tseytin_not(b_or.clone(), clauses, symbols);
255
37578
        let a_i_and_not_b_i = tseytin_and(&vec![a_i.clone(), not_b_or], clauses, symbols);
256
37578

            
257
37578
        cum_result = tseytin_or(&vec![cum_result, a_i_and_not_b_i], clauses, symbols);
258
37578
    }
259

            
260
4704
    cum_result
261
4704
}
262

            
263
/// Converts a - expression for a SATInt to a new SATInt
264
///
265
/// ```text
266
/// -SATInt(a) ~> SATInt(b)
267
///
268
/// ```
269
#[register_rule("SAT_Direct", 9100, [Neg])]
270
489816
fn neg_sat_direct(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
271
489816
    let Expr::Neg(_, value) = expr else {
272
489198
        return Err(RuleNotApplicable);
273
    };
274

            
275
618
    let (binding, old_min, old_max) = validate_direct_int_operands(vec![value.as_ref().clone()])?;
276
180
    let [val_bits] = binding.as_slice() else {
277
        return Err(RuleNotApplicable);
278
    };
279

            
280
180
    let new_min = -old_max;
281
180
    let new_max = -old_min;
282

            
283
180
    let mut out = val_bits.clone();
284
180
    out.reverse();
285

            
286
180
    Ok(Reduction::pure(Expr::SATInt(
287
180
        Metadata::new(),
288
180
        SATIntEncoding::Direct,
289
180
        Moo::new(into_matrix_expr!(out)),
290
180
        (new_min, new_max),
291
180
    )))
292
489816
}
293

            
294
17052
/// Integer division of `a` by `b`, rounded towards negative infinity.
///
/// Rust's `/` truncates towards zero; when the remainder is non-zero and its sign differs
/// from the divisor's, the truncated quotient is one too large and is corrected downwards.
///
/// # Panics
/// Panics if `b == 0` (division by zero), as `/` does.
fn floor_div(a: i32, b: i32) -> i32 {
    let quotient = a / b;
    let remainder = a % b;
    let needs_adjust = remainder != 0 && (remainder < 0) != (b < 0);
    if needs_adjust { quotient - 1 } else { quotient }
}
302

            
303
/// Converts a / expression between two direct SATInts to a new direct SATInt
304
/// using the "lookup table" method.
305
///
306
/// ```text
307
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
308
///
309
/// ```
310
#[register_rule("SAT_Direct", 9100, [SafeDiv])]
311
489816
fn safediv_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
312
489816
    let Expr::SafeDiv(_, numer_expr, denom_expr) = expr else {
313
489588
        return Err(RuleNotApplicable);
314
    };
315

            
316
186
    let Expr::SATInt(_, SATIntEncoding::Direct, numer_inner, (numer_min, numer_max)) =
317
228
        numer_expr.as_ref()
318
    else {
319
42
        return Err(RuleNotApplicable);
320
    };
321
186
    let Some(numer_bits) = numer_inner.as_ref().clone().unwrap_list() else {
322
        return Err(RuleNotApplicable);
323
    };
324

            
325
186
    let Expr::SATInt(_, SATIntEncoding::Direct, denom_inner, (denom_min, denom_max)) =
326
186
        denom_expr.as_ref()
327
    else {
328
        return Err(RuleNotApplicable);
329
    };
330

            
331
186
    let Some(denom_bits) = denom_inner.as_ref().clone().unwrap_list() else {
332
        return Err(RuleNotApplicable);
333
    };
334

            
335
186
    let mut quot_min = i32::MAX;
336
186
    let mut quot_max = i32::MIN;
337

            
338
1758
    for i in *numer_min..=*numer_max {
339
18366
        for j in *denom_min..=*denom_max {
340
18366
            let k = if j == 0 { 0 } else { i / j };
341
18366
            quot_min = quot_min.min(k);
342
18366
            quot_max = quot_max.max(k);
343
        }
344
    }
345

            
346
186
    let mut new_symbols = symbols.clone();
347
186
    let mut quot_bits = Vec::new();
348

            
349
    // generate boolean variables for all possible quotients
350
1836
    for _ in quot_min..=quot_max {
351
1836
        let decl = new_symbols.gen_find(&conjure_cp::ast::Domain::bool());
352
1836
        quot_bits.push(Expr::Atomic(
353
1836
            Metadata::new(),
354
1836
            Atom::Reference(conjure_cp::ast::Reference::new(decl)),
355
1836
        ));
356
1836
    }
357

            
358
186
    let mut new_clauses = vec![];
359

            
360
    // generate the lookup table clauses: (n_i AND d_j) => q_k
361
1758
    for i in *numer_min..=*numer_max {
362
1758
        let numer_bit = &numer_bits[(i - numer_min) as usize];
363
18366
        for j in *denom_min..=*denom_max {
364
18366
            let denom_bit = &denom_bits[(j - denom_min) as usize];
365

            
366
18366
            let k = if j == 0 { 0 } else { floor_div(i, j) };
367

            
368
18366
            let quot_bit = &quot_bits[(k - quot_min) as usize];
369

            
370
18366
            new_clauses.push(CnfClause::new(vec![
371
18366
                Expr::Not(Metadata::new(), Moo::new(numer_bit.clone())),
372
18366
                Expr::Not(Metadata::new(), Moo::new(denom_bit.clone())),
373
18366
                quot_bit.clone(),
374
            ]));
375
        }
376
    }
377

            
378
    // the quotient cannot take more than one value simultaneously.
379
1836
    for a in 0..quot_bits.len() {
380
14712
        for b in (a + 1)..quot_bits.len() {
381
14712
            new_clauses.push(CnfClause::new(vec![
382
14712
                Expr::Not(Metadata::new(), Moo::new(quot_bits[a].clone())),
383
14712
                Expr::Not(Metadata::new(), Moo::new(quot_bits[b].clone())),
384
14712
            ]));
385
14712
        }
386
    }
387

            
388
186
    let quot_int = Expr::SATInt(
389
186
        Metadata::new(),
390
186
        SATIntEncoding::Direct,
391
186
        Moo::new(into_matrix_expr!(quot_bits)),
392
186
        (quot_min, quot_max),
393
186
    );
394

            
395
186
    Ok(Reduction::cnf(quot_int, new_clauses, new_symbols))
396
489816
}
397

            
398
#[register_rule("SAT_Direct", 9100, [Sum])]
399
489816
fn add_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
400
489816
    let Expr::Sum(_, sum_exprs) = expr else {
401
483282
        return Err(RuleNotApplicable);
402
    };
403

            
404
6534
    let Some(exprs) = sum_exprs.as_ref().clone().unwrap_list() else {
405
5010
        return Err(RuleNotApplicable);
406
    };
407

            
408
    // There are no expressions to sum, this is a degenerate case that we can handle by returning a constant 0
409
1524
    if exprs.is_empty() {
410
        return Ok(Reduction::pure(Expr::SATInt(
411
            Metadata::new(),
412
            SATIntEncoding::Direct,
413
            Moo::new(into_matrix_expr!(vec![Expr::Atomic(
414
                Metadata::new(),
415
                Atom::Literal(Literal::Bool(true)),
416
            )])),
417
            (0, 0),
418
        )));
419
1524
    }
420

            
421
1524
    let mut new_symbols = symbols.clone();
422
1524
    let mut new_clauses: Vec<CnfClause> = vec![];
423

            
424
    // Validate all operands are direct SATInts and extract their bit vectors, also calculate a common min and max for all operands to normalize them to the same size by padding with zeroes as needed to simplify the addition logic.
425
522
    let (mut operands, common_min, common_max) =
426
1524
        validate_direct_int_operands(exprs).map_err(|_| RuleNotApplicable)?;
427

            
428
    // Addition is implemented as a series of pairwise additions. The bits of the output are defined by iterating over all possible output values, and for each output value k, ORing together ANDs for each pair of input values i,j where i+j=k. This is effectively a big disjunction of all possible ways to sum to k.
429
10338
    let false_expr = || Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
430

            
431
522
    let mut acc_bits = operands.remove(0);
432
522
    let mut acc_min = common_min;
433
522
    let mut acc_max = common_max;
434

            
435
588
    for right_bits in operands {
436
588
        let right_min = common_min;
437
588
        let right_max = common_max;
438

            
439
588
        let new_min = acc_min + right_min;
440
588
        let new_max = acc_max + right_max;
441
588
        let mut out_bits = Vec::with_capacity((new_max - new_min + 1) as usize);
442

            
443
10338
        for k in new_min..=new_max {
444
10338
            let mut sum_expr = false_expr();
445

            
446
252810
            for i in acc_min..=acc_max {
447
252810
                let j = k - i;
448
252810
                if j < right_min || j > right_max {
449
131640
                    continue;
450
121170
                }
451

            
452
121170
                let a = acc_bits[(i - acc_min) as usize].clone();
453
121170
                let b = right_bits[(j - right_min) as usize].clone();
454

            
455
121170
                let and_ab = tseytin_and(&vec![a, b], &mut new_clauses, &mut new_symbols);
456
121170
                sum_expr = tseytin_or(&vec![sum_expr, and_ab], &mut new_clauses, &mut new_symbols);
457
            }
458

            
459
10338
            out_bits.push(sum_expr);
460
        }
461

            
462
588
        acc_bits = out_bits;
463
588
        acc_min = new_min;
464
588
        acc_max = new_max;
465
    }
466

            
467
522
    Ok(Reduction::cnf(
468
522
        Expr::SATInt(
469
522
            Metadata::new(),
470
522
            SATIntEncoding::Direct,
471
522
            Moo::new(into_matrix_expr!(acc_bits)),
472
522
            (acc_min, acc_max),
473
522
        ),
474
522
        new_clauses,
475
522
        new_symbols,
476
522
    ))
477
489816
}