1
use conjure_cp::ast::{Atom, Expression as Expr, Literal};
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
6
};
7

            
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use super::boolean::{tseytin_and, tseytin_iff, tseytin_not, tseytin_or, tseytin_xor};
13

            
14
use conjure_cp::ast::CnfClause;
15
/// Converts an integer literal to SATInt form
16
///
17
/// ```text
18
///  3
19
///  ~~>
20
///  SATInt([true;int(1..), (3, 3)])
21
///
22
/// ```
23
#[register_rule("SAT_Direct", 9500, [Atomic])]
24
1398801
fn literal_sat_direct_int(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
25
5202
    let value = {
26
53463
        if let Expr::Atomic(_, Atom::Literal(Literal::Int(value))) = expr {
27
5202
            *value
28
        } else {
29
1393599
            return Err(RuleNotApplicable);
30
        }
31
    };
32

            
33
5202
    Ok(Reduction::pure(Expr::SATInt(
34
5202
        Metadata::new(),
35
5202
        SATIntEncoding::Direct,
36
5202
        Moo::new(into_matrix_expr!(vec![Expr::Atomic(
37
5202
            Metadata::new(),
38
5202
            Atom::Literal(Literal::Bool(true)),
39
5202
        )])),
40
5202
        (value, value),
41
5202
    )))
42
1398801
}
43

            
44
/// This function confirms that all of the input expressions are direct SATInts, and returns vectors for each input of their bits
45
/// This function also normalizes direct SATInt operands to a common value range by zero-padding.
46
58482
pub fn validate_direct_int_operands(
47
58482
    exprs: Vec<Expr>,
48
58482
) -> Result<(Vec<Vec<Expr>>, i32, i32), ApplicationError> {
49
    // TODO: In the future it may be possible to optimize operations between integers with different bit sizes
50
    // Collect inner bit vectors from each SATInt
51

            
52
    // Iterate over all inputs
53
    // Check they are direct and calulate a lower and upper bound
54
58482
    let mut global_min: i32 = i32::MAX;
55
58482
    let mut global_max: i32 = i32::MIN;
56

            
57
70338
    for operand in &exprs {
58
55392
        let Expr::SATInt(_, SATIntEncoding::Direct, _, (local_min, local_max)) = operand else {
59
51594
            return Err(RuleNotApplicable);
60
        };
61
18744
        global_min = global_min.min(*local_min);
62
18744
        global_max = global_max.max(*local_max);
63
    }
64

            
65
    // build out by iterating over each operand and expanding it to match the new bounds
66

            
67
6888
    let out: Vec<Vec<Expr>> = exprs
68
6888
        .into_iter()
69
13632
        .map(|expr| {
70
13632
            let Expr::SATInt(_, SATIntEncoding::Direct, inner, (local_min, local_max)) = expr
71
            else {
72
                return Err(RuleNotApplicable);
73
            };
74

            
75
13632
            let Some(v) = inner.as_ref().clone().unwrap_list() else {
76
                return Err(RuleNotApplicable);
77
            };
78

            
79
            // calulcate how many zeroes to prepend/append
80
13632
            let prefix_len = (local_min - global_min) as usize;
81
13632
            let postfix_len = (global_max - local_max) as usize;
82

            
83
13632
            let mut bits = Vec::with_capacity(v.len() + prefix_len + postfix_len);
84

            
85
            // add 0s to start
86
13632
            bits.extend(std::iter::repeat_n(
87
13632
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
88
13632
                prefix_len,
89
            ));
90

            
91
13632
            bits.extend(v);
92

            
93
            // add 0s to end
94
13632
            bits.extend(std::iter::repeat_n(
95
13632
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
96
13632
                postfix_len,
97
            ));
98

            
99
13632
            Ok(bits)
100
13632
        })
101
6888
        .collect::<Result<_, _>>()?;
102

            
103
6888
    Ok((out, global_min, global_max))
104
58482
}
105

            
106
/// Converts a = expression between two direct SATInts to a boolean expression in cnf
107
///
108
/// ```text
109
/// SATInt(a) = SATInt(b) ~> Bool
110
/// ```
111
/// NOTE: This rule reduces to AND_i (a[i] ≡ b[i]) and does not enforce one-hotness.
112
#[register_rule("SAT_Direct", 9100, [Eq])]
113
493488
fn eq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
114
    // TODO: this could be optimized by just going over the sections of both vectors where the ranges intersect
115
    // this does require enforcing structure separately
116
493488
    let Expr::Eq(_, lhs, rhs) = expr else {
117
485322
        return Err(RuleNotApplicable);
118
    };
119

            
120
1050
    let (binding, _, _) =
121
8166
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
122
1050
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
123
        return Err(RuleNotApplicable);
124
    };
125

            
126
1050
    let bit_count = lhs_bits.len();
127

            
128
1050
    let mut output = true.into();
129
1050
    let mut new_symbols = symbols.clone();
130
1050
    let mut new_clauses = vec![];
131
    let mut comparison;
132

            
133
19182
    for i in 0..bit_count {
134
19182
        comparison = tseytin_iff(
135
19182
            lhs_bits[i].clone(),
136
19182
            rhs_bits[i].clone(),
137
19182
            &mut new_clauses,
138
19182
            &mut new_symbols,
139
19182
        );
140
19182
        output = tseytin_and(
141
19182
            &vec![comparison, output],
142
19182
            &mut new_clauses,
143
19182
            &mut new_symbols,
144
19182
        );
145
19182
    }
146

            
147
1050
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
148
493488
}
149

            
150
/// Converts a != expression between two direct SATInts to a boolean expression in cnf
151
///
152
/// ```text
153
/// SATInt(a) != SATInt(b) ~> Bool
154
///
155
/// ```
156
///
157
/// True iff at least one value position differs.
158
#[register_rule("SAT_Direct", 9100, [Neq])]
159
493488
fn neq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
160
493488
    let Expr::Neq(_, lhs, rhs) = expr else {
161
492534
        return Err(RuleNotApplicable);
162
    };
163

            
164
306
    let (binding, _, _) =
165
954
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
166
306
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
167
        return Err(RuleNotApplicable);
168
    };
169

            
170
306
    let bit_count = lhs_bits.len();
171

            
172
306
    let mut output = false.into();
173
306
    let mut new_symbols = symbols.clone();
174
306
    let mut new_clauses = vec![];
175
    let mut comparison;
176

            
177
2142
    for i in 0..bit_count {
178
2142
        comparison = tseytin_xor(
179
2142
            lhs_bits[i].clone(),
180
2142
            rhs_bits[i].clone(),
181
2142
            &mut new_clauses,
182
2142
            &mut new_symbols,
183
2142
        );
184
2142
        output = tseytin_or(
185
2142
            &vec![comparison, output],
186
2142
            &mut new_clauses,
187
2142
            &mut new_symbols,
188
2142
        );
189
2142
    }
190

            
191
306
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
192
493488
}
193

            
194
/// Converts a </>/<=/>= expression between two direct SATInts to a boolean expression in cnf
195
///
196
/// ```text
197
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
198
///
199
/// ```
200
/// Note: < and <= are rewritten by swapping operands to reuse lt logic.
201
#[register_rule("SAT", 9100, [Lt, Gt, Leq, Geq])]
202
993354
fn ineq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
203
993354
    let (lhs, rhs, negate) = match expr {
204
        // A < B -> sat_direct_lt(A, B)
205
576
        Expr::Lt(_, x, y) => (x, y, false),
206
        // A > B -> sat_direct_lt(B, A)
207
978
        Expr::Gt(_, x, y) => (y, x, false),
208
        // A <= B -> NOT (B < A)
209
24246
        Expr::Leq(_, x, y) => (y, x, true),
210
        // A >= B -> NOT (A < B)
211
21390
        Expr::Geq(_, x, y) => (x, y, true),
212
946164
        _ => return Err(RuleNotApplicable),
213
    };
214

            
215
4800
    let (binding, _, _) =
216
47190
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
217
4800
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
218
        return Err(RuleNotApplicable);
219
    };
220

            
221
4800
    let mut new_symbols = symbols.clone();
222
4800
    let mut new_clauses = vec![];
223

            
224
4800
    let mut output = sat_direct_lt(
225
4800
        lhs_bits.clone(),
226
4800
        rhs_bits.clone(),
227
4800
        &mut new_clauses,
228
4800
        &mut new_symbols,
229
    );
230

            
231
4800
    if negate {
232
4656
        output = tseytin_not(output, &mut new_clauses, &mut new_symbols);
233
4656
    }
234

            
235
4800
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
236
993354
}
237

            
238
/// Encodes a < b for one-hot direct integers using prefix OR logic.
239
4800
fn sat_direct_lt(
240
4800
    a: Vec<Expr>,
241
4800
    b: Vec<Expr>,
242
4800
    clauses: &mut Vec<CnfClause>,
243
4800
    symbols: &mut SymbolTable,
244
4800
) -> Expr {
245
4800
    let mut b_or = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
246
4800
    let mut cum_result = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
247

            
248
38142
    for (a_i, b_i) in a.iter().zip(b.iter()) {
249
        // b_or is prefix_or of b up to index i: B_i = b_0 | ... | b_i
250
38142
        b_or = tseytin_or(&vec![b_or, b_i.clone()], clauses, symbols);
251
38142

            
252
        // a < b if there exists i such that a=i and b > i.
253
        // b > i is equivalent to NOT(B_i) assuming one-hotness.
254
38142
        let not_b_or = tseytin_not(b_or.clone(), clauses, symbols);
255
38142
        let a_i_and_not_b_i = tseytin_and(&vec![a_i.clone(), not_b_or], clauses, symbols);
256
38142

            
257
38142
        cum_result = tseytin_or(&vec![cum_result, a_i_and_not_b_i], clauses, symbols);
258
38142
    }
259

            
260
4800
    cum_result
261
4800
}
262

            
263
/// Converts a - expression for a SATInt to a new SATInt
264
///
265
/// ```text
266
/// -SATInt(a) ~> SATInt(b)
267
///
268
/// ```
269
#[register_rule("SAT_Direct", 9100, [Neg])]
270
493488
fn neg_sat_direct(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
271
493488
    let Expr::Neg(_, value) = expr else {
272
492864
        return Err(RuleNotApplicable);
273
    };
274

            
275
624
    let (binding, old_min, old_max) = validate_direct_int_operands(vec![value.as_ref().clone()])?;
276
186
    let [val_bits] = binding.as_slice() else {
277
        return Err(RuleNotApplicable);
278
    };
279

            
280
186
    let new_min = -old_max;
281
186
    let new_max = -old_min;
282

            
283
186
    let mut out = val_bits.clone();
284
186
    out.reverse();
285

            
286
186
    Ok(Reduction::pure(Expr::SATInt(
287
186
        Metadata::new(),
288
186
        SATIntEncoding::Direct,
289
186
        Moo::new(into_matrix_expr!(out)),
290
186
        (new_min, new_max),
291
186
    )))
292
493488
}
293

            
294
17052
/// Integer division rounding toward negative infinity (floor division).
///
/// Rust's `/` truncates toward zero. The truncated quotient is one too large
/// exactly when the division is inexact (non-zero remainder) and the remainder
/// and divisor have opposite signs.
///
/// # Panics
/// Panics if `b == 0`, or on `i32::MIN / -1` overflow, exactly like `/`.
fn floor_div(a: i32, b: i32) -> i32 {
    let quotient = a / b;
    let remainder = a % b;
    let signs_differ = (remainder < 0) != (b < 0);
    if remainder != 0 && signs_differ {
        quotient - 1
    } else {
        quotient
    }
}
302

            
303
/// Converts a / expression between two direct SATInts to a new direct SATInt
304
/// using the "lookup table" method.
305
///
306
/// ```text
307
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
308
///
309
/// ```
310
#[register_rule("SAT_Direct", 9100, [SafeDiv])]
311
493488
fn safediv_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
312
493488
    let Expr::SafeDiv(_, numer_expr, denom_expr) = expr else {
313
493260
        return Err(RuleNotApplicable);
314
    };
315

            
316
186
    let Expr::SATInt(_, SATIntEncoding::Direct, numer_inner, (numer_min, numer_max)) =
317
228
        numer_expr.as_ref()
318
    else {
319
42
        return Err(RuleNotApplicable);
320
    };
321
186
    let Some(numer_bits) = numer_inner.as_ref().clone().unwrap_list() else {
322
        return Err(RuleNotApplicable);
323
    };
324

            
325
186
    let Expr::SATInt(_, SATIntEncoding::Direct, denom_inner, (denom_min, denom_max)) =
326
186
        denom_expr.as_ref()
327
    else {
328
        return Err(RuleNotApplicable);
329
    };
330

            
331
186
    let Some(denom_bits) = denom_inner.as_ref().clone().unwrap_list() else {
332
        return Err(RuleNotApplicable);
333
    };
334

            
335
186
    let mut quot_min = i32::MAX;
336
186
    let mut quot_max = i32::MIN;
337

            
338
1758
    for i in *numer_min..=*numer_max {
339
18366
        for j in *denom_min..=*denom_max {
340
18366
            let k = if j == 0 { 0 } else { i / j };
341
18366
            quot_min = quot_min.min(k);
342
18366
            quot_max = quot_max.max(k);
343
        }
344
    }
345

            
346
186
    let mut new_symbols = symbols.clone();
347
186
    let mut quot_bits = Vec::new();
348

            
349
    // generate boolean variables for all possible quotients
350
1836
    for _ in quot_min..=quot_max {
351
1836
        let decl = new_symbols.gen_find(&conjure_cp::ast::Domain::bool());
352
1836
        quot_bits.push(Expr::Atomic(
353
1836
            Metadata::new(),
354
1836
            Atom::Reference(conjure_cp::ast::Reference::new(decl)),
355
1836
        ));
356
1836
    }
357

            
358
186
    let mut new_clauses = vec![];
359

            
360
    // generate the lookup table clauses: (n_i AND d_j) => q_k
361
1758
    for i in *numer_min..=*numer_max {
362
1758
        let numer_bit = &numer_bits[(i - numer_min) as usize];
363
18366
        for j in *denom_min..=*denom_max {
364
18366
            let denom_bit = &denom_bits[(j - denom_min) as usize];
365

            
366
18366
            let k = if j == 0 { 0 } else { floor_div(i, j) };
367

            
368
18366
            let quot_bit = &quot_bits[(k - quot_min) as usize];
369

            
370
18366
            new_clauses.push(CnfClause::new(vec![
371
18366
                Expr::Not(Metadata::new(), Moo::new(numer_bit.clone())),
372
18366
                Expr::Not(Metadata::new(), Moo::new(denom_bit.clone())),
373
18366
                quot_bit.clone(),
374
            ]));
375
        }
376
    }
377

            
378
    // the quotient cannot take more than one value simultaneously.
379
1836
    for a in 0..quot_bits.len() {
380
14712
        for b in (a + 1)..quot_bits.len() {
381
14712
            new_clauses.push(CnfClause::new(vec![
382
14712
                Expr::Not(Metadata::new(), Moo::new(quot_bits[a].clone())),
383
14712
                Expr::Not(Metadata::new(), Moo::new(quot_bits[b].clone())),
384
14712
            ]));
385
14712
        }
386
    }
387

            
388
186
    let quot_int = Expr::SATInt(
389
186
        Metadata::new(),
390
186
        SATIntEncoding::Direct,
391
186
        Moo::new(into_matrix_expr!(quot_bits)),
392
186
        (quot_min, quot_max),
393
186
    );
394

            
395
186
    Ok(Reduction::cnf(quot_int, new_clauses, new_symbols))
396
493488
}
397

            
398
#[register_rule("SAT_Direct", 9100, [Sum])]
399
493488
fn add_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
400
493488
    let Expr::Sum(_, sum_exprs) = expr else {
401
486954
        return Err(RuleNotApplicable);
402
    };
403

            
404
6534
    let Some(exprs) = sum_exprs.as_ref().clone().unwrap_list() else {
405
5010
        return Err(RuleNotApplicable);
406
    };
407

            
408
    // There are no expressions to sum, this is a degenerate case that we can handle by returning a constant 0
409
1524
    if exprs.is_empty() {
410
        return Ok(Reduction::pure(Expr::SATInt(
411
            Metadata::new(),
412
            SATIntEncoding::Direct,
413
            Moo::new(into_matrix_expr!(vec![Expr::Atomic(
414
                Metadata::new(),
415
                Atom::Literal(Literal::Bool(true)),
416
            )])),
417
            (0, 0),
418
        )));
419
1524
    }
420

            
421
1524
    let mut new_symbols = symbols.clone();
422
1524
    let mut new_clauses: Vec<CnfClause> = vec![];
423

            
424
    // Validate all operands are direct SATInts and extract their bit vectors, also calculate a common min and max for all operands to normalize them to the same size by padding with zeroes as needed to simplify the addition logic.
425
522
    let (mut operands, common_min, common_max) =
426
1524
        validate_direct_int_operands(exprs).map_err(|_| RuleNotApplicable)?;
427

            
428
    // Addition is implemented as a series of pairwise additions. The bits of the output are defined by iterating over all possible output values, and for each output value k, ORing together ANDs for each pair of input values i,j where i+j=k. This is effectively a big disjunction of all possible ways to sum to k.
429
10338
    let false_expr = || Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
430

            
431
522
    let mut acc_bits = operands.remove(0);
432
522
    let mut acc_min = common_min;
433
522
    let mut acc_max = common_max;
434

            
435
588
    for right_bits in operands {
436
588
        let right_min = common_min;
437
588
        let right_max = common_max;
438

            
439
588
        let new_min = acc_min + right_min;
440
588
        let new_max = acc_max + right_max;
441
588
        let mut out_bits = Vec::with_capacity((new_max - new_min + 1) as usize);
442

            
443
10338
        for k in new_min..=new_max {
444
10338
            let mut sum_expr = false_expr();
445

            
446
252810
            for i in acc_min..=acc_max {
447
252810
                let j = k - i;
448
252810
                if j < right_min || j > right_max {
449
131640
                    continue;
450
121170
                }
451

            
452
121170
                let a = acc_bits[(i - acc_min) as usize].clone();
453
121170
                let b = right_bits[(j - right_min) as usize].clone();
454

            
455
121170
                let and_ab = tseytin_and(&vec![a, b], &mut new_clauses, &mut new_symbols);
456
121170
                sum_expr = tseytin_or(&vec![sum_expr, and_ab], &mut new_clauses, &mut new_symbols);
457
            }
458

            
459
10338
            out_bits.push(sum_expr);
460
        }
461

            
462
588
        acc_bits = out_bits;
463
588
        acc_min = new_min;
464
588
        acc_max = new_max;
465
    }
466

            
467
522
    Ok(Reduction::cnf(
468
522
        Expr::SATInt(
469
522
            Metadata::new(),
470
522
            SATIntEncoding::Direct,
471
522
            Moo::new(into_matrix_expr!(acc_bits)),
472
522
            (acc_min, acc_max),
473
522
        ),
474
522
        new_clauses,
475
522
        new_symbols,
476
522
    ))
477
493488
}
478

            
479
/// Matches a `|SATInt|` with an absolute value operation and rewrites it to a direct-encoded absolute-value `SATInt` by grouping input indicator bits by `|value|` and OR-ing each group (named here as buckets) into the corresponding output bit.
480
#[register_rule("SAT_Direct", 9100, [Abs])]
481
493488
fn abs_value_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
482
493488
    let Expr::Abs(_, value_expr) = expr else {
483
493464
        return Err(RuleNotApplicable);
484
    };
485

            
486
24
    let (binding, old_min, old_max) =
487
24
        validate_direct_int_operands(vec![value_expr.as_ref().clone()])?;
488

            
489
24
    let [val_bits] = binding.as_slice() else {
490
        return Err(RuleNotApplicable);
491
    };
492

            
493
    // The new range is from the minimum absolute value to the maximum absolute value. The minimum absolute value is either 0 if the old range includes 0, or the smaller of the absolute values of the old min and max if the old range does not include 0. The maximum absolute value is the larger of the absolute values of the old min and max.
494
24
    let new_min = if old_min <= 0 && old_max >= 0 {
495
12
        0
496
    } else {
497
12
        old_min.abs().min(old_max.abs())
498
    };
499
24
    let new_max = old_min.abs().max(old_max.abs());
500

            
501
24
    let mut new_symbols = symbols.clone();
502
24
    let mut new_clauses = vec![];
503

            
504
24
    let bucket_count = (new_max - new_min + 1) as usize;
505
24
    let mut buckets: Vec<Vec<Expr>> = vec![Vec::new(); bucket_count];
506

            
507
    // Iterates over all possible input values, calculate their absolute value, and place them in the corresponding bucket. Each bucket corresponds to a possible output value, and contains the disjunction of all input bits that could produce that output.
508
156
    for value in old_min..=old_max {
509
156
        let input_bit = val_bits[(value - old_min) as usize].clone();
510
156
        let abs_value = value.abs();
511
156
        let bucket_idx = (abs_value - new_min) as usize;
512
156
        buckets[bucket_idx].push(input_bit);
513
156
    }
514

            
515
    // For each bucket, if it's empty then the output bit is false, if it contains one element then the output bit is that element, and if it contains multiple elements then the output bit is the OR of all elements in the bucket.
516
24
    let mut abs_bits = Vec::with_capacity(bucket_count);
517
126
    for bucket in buckets {
518
126
        let out_bit = match bucket.len() {
519
            0 => Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
520
96
            1 => bucket[0].clone(),
521
30
            _ => tseytin_or(&bucket, &mut new_clauses, &mut new_symbols),
522
        };
523

            
524
126
        abs_bits.push(out_bit);
525
    }
526

            
527
24
    let abs_int = Expr::SATInt(
528
24
        Metadata::new(),
529
24
        SATIntEncoding::Direct,
530
24
        Moo::new(into_matrix_expr!(abs_bits)),
531
24
        (new_min, new_max),
532
24
    );
533

            
534
24
    Ok(Reduction::cnf(abs_int, new_clauses, new_symbols))
535
493488
}