1
use conjure_cp::ast::Expression as Expr;
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
5
};
6

            
7
use conjure_cp::ast::AbstractLiteral::Matrix;
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use itertools::Itertools;
13

            
14
use super::boolean::{
15
    tseytin_and, tseytin_iff, tseytin_imply, tseytin_mux, tseytin_not, tseytin_or, tseytin_xor,
16
};
17
use super::integer_repr::{bit_magnitude, match_bits_length, validate_log_int_operands};
18

            
19
use conjure_cp::ast::CnfClause;
20

            
21
use std::cmp;
22

            
23
/// Converts an inequality expression between two SATInts to a boolean expression in cnf.
24
///
25
/// ```text
26
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
27
///
28
/// ```
29
#[register_rule("SAT_Log", 4100, [Lt, Gt, Leq, Geq])]
30
7776
fn cnf_int_ineq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
31
7776
    let (lhs, rhs, strict) = match expr {
32
66
        Expr::Lt(_, x, y) => (y, x, true),
33
72
        Expr::Gt(_, x, y) => (x, y, true),
34
1104
        Expr::Leq(_, x, y) => (y, x, false),
35
1200
        Expr::Geq(_, x, y) => (x, y, false),
36
5334
        _ => return Err(RuleNotApplicable),
37
    };
38

            
39
2160
    let binding =
40
2442
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
41
2160
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
42
        return Err(RuleNotApplicable);
43
    };
44

            
45
2160
    let mut new_symbols = symbols.clone();
46
2160
    let mut new_clauses = vec![];
47

            
48
2160
    let output = inequality_boolean(
49
2160
        lhs_bits.clone(),
50
2160
        rhs_bits.clone(),
51
2160
        strict,
52
2160
        &mut new_clauses,
53
2160
        &mut new_symbols,
54
    );
55
2160
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
56
7776
}
57

            
58
/// Converts a = expression between two SATInts to a boolean expression in cnf
59
///
60
/// ```text
61
/// SATInt(a) = SATInt(b) ~> Bool
62
///
63
/// ```
64
#[register_rule("SAT_Log", 9100, [Eq])]
65
410100
fn cnf_int_eq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
66
410100
    let Expr::Eq(_, lhs, rhs) = expr else {
67
409053
        return Err(RuleNotApplicable);
68
    };
69

            
70
288
    let binding =
71
1047
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
72
288
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
73
        return Err(RuleNotApplicable);
74
    };
75

            
76
288
    let bit_count = lhs_bits.len();
77

            
78
288
    let mut output = true.into();
79
288
    let mut new_symbols = symbols.clone();
80
288
    let mut new_clauses = vec![];
81
    let mut comparison;
82

            
83
1731
    for i in 0..bit_count {
84
1731
        comparison = tseytin_iff(
85
1731
            lhs_bits[i].clone(),
86
1731
            rhs_bits[i].clone(),
87
1731
            &mut new_clauses,
88
1731
            &mut new_symbols,
89
1731
        );
90
1731
        output = tseytin_and(
91
1731
            &vec![comparison, output],
92
1731
            &mut new_clauses,
93
1731
            &mut new_symbols,
94
1731
        );
95
1731
    }
96

            
97
288
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
98
410100
}
99

            
100
/// Converts a != expression between two SATInts to a boolean expression in cnf
101
///
102
/// ```text
103
/// SATInt(a) != SATInt(b) ~> Bool
104
///
105
/// ```
106
#[register_rule("SAT_Log", 4100, [Neq])]
107
7776
fn cnf_int_neq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
108
7776
    let Expr::Neq(_, lhs, rhs) = expr else {
109
7596
        return Err(RuleNotApplicable);
110
    };
111

            
112
180
    let binding =
113
180
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
114
180
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
115
        return Err(RuleNotApplicable);
116
    };
117

            
118
180
    let bit_count = lhs_bits.len();
119

            
120
180
    let mut output = false.into();
121
180
    let mut new_symbols = symbols.clone();
122
180
    let mut new_clauses = vec![];
123
    let mut comparison;
124

            
125
576
    for i in 0..bit_count {
126
576
        comparison = tseytin_xor(
127
576
            lhs_bits[i].clone(),
128
576
            rhs_bits[i].clone(),
129
576
            &mut new_clauses,
130
576
            &mut new_symbols,
131
576
        );
132
576
        output = tseytin_or(
133
576
            &vec![comparison, output],
134
576
            &mut new_clauses,
135
576
            &mut new_symbols,
136
576
        );
137
576
    }
138

            
139
180
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
140
7776
}
141

            
142
// Creates a boolean expression for > or >=
143
// a > b or a >= b
144
// This can also be used for < and <= by reversing the order of the inputs
145
// Returns result, new symbol table, new clauses
146
2256
fn inequality_boolean(
147
2256
    a: Vec<Expr>,
148
2256
    b: Vec<Expr>,
149
2256
    strict: bool,
150
2256
    clauses: &mut Vec<CnfClause>,
151
2256
    symbols: &mut SymbolTable,
152
2256
) -> Expr {
153
    let mut notb;
154
    let mut output;
155

            
156
2256
    if strict {
157
216
        notb = tseytin_not(b[0].clone(), clauses, symbols);
158
216
        output = tseytin_and(&vec![a[0].clone(), notb], clauses, symbols);
159
2040
    } else {
160
2040
        output = tseytin_imply(b[0].clone(), a[0].clone(), clauses, symbols);
161
2040
    }
162

            
163
    //TODO: There may be room for simplification, and constant optimization
164

            
165
2256
    let bit_count = a.len();
166

            
167
    let mut lhs;
168
    let mut rhs;
169
    let mut iff;
170
4260
    for n in 1..(bit_count - 1) {
171
4260
        notb = tseytin_not(b[n].clone(), clauses, symbols);
172
4260
        lhs = tseytin_and(&vec![a[n].clone(), notb.clone()], clauses, symbols);
173
4260
        iff = tseytin_iff(a[n].clone(), b[n].clone(), clauses, symbols);
174
4260
        rhs = tseytin_and(&vec![iff.clone(), output.clone()], clauses, symbols);
175
4260
        output = tseytin_or(&vec![lhs.clone(), rhs.clone()], clauses, symbols);
176
4260
    }
177

            
178
    // final bool is the sign bit and should be handled inversely
179
2256
    let nota = tseytin_not(a[bit_count - 1].clone(), clauses, symbols);
180
2256
    lhs = tseytin_and(&vec![nota, b[bit_count - 1].clone()], clauses, symbols);
181
2256
    iff = tseytin_iff(
182
2256
        a[bit_count - 1].clone(),
183
2256
        b[bit_count - 1].clone(),
184
2256
        clauses,
185
2256
        symbols,
186
    );
187
2256
    rhs = tseytin_and(&vec![iff, output.clone()], clauses, symbols);
188
2256
    output = tseytin_or(&vec![lhs, rhs], clauses, symbols);
189

            
190
2256
    output
191
2256
}
192

            
193
/// Converts sum of SATInts to a single SATInt
194
///
195
/// ```text
196
/// Sum(SATInt(a), SATInt(b), ...) ~> SATInt(c)
197
///
198
/// ```
199
#[register_rule("SAT_Log", 4100, [Sum])]
200
7776
fn cnf_int_sum(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
201
7776
    let Expr::Sum(_, exprs) = expr else {
202
7536
        return Err(RuleNotApplicable);
203
    };
204

            
205
240
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
206
        return Err(RuleNotApplicable);
207
    };
208

            
209
240
    let ranges: Result<Vec<_>, _> = exprs_list
210
240
        .iter()
211
480
        .map(|e| match e {
212
420
            Expr::SATInt(_, _, _, x) => Ok(x),
213
60
            _ => Err(RuleNotApplicable),
214
480
        })
215
240
        .collect();
216

            
217
240
    let ranges = ranges?;
218

            
219
180
    let min = ranges.iter().map(|(a, _)| *a).sum();
220
180
    let max = ranges.iter().map(|(_, a)| *a).sum();
221

            
222
180
    let output_size = cmp::max(bit_magnitude(min), bit_magnitude(max));
223

            
224
    // Check operands are valid log ints
225
180
    let mut exprs_bits =
226
180
        validate_log_int_operands(exprs_list.clone(), Some(output_size.try_into().unwrap()))?;
227

            
228
180
    let mut new_symbols = symbols.clone();
229
    let mut values;
230
180
    let mut new_clauses = vec![];
231

            
232
372
    while exprs_bits.len() > 1 {
233
192
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
234
192
        let mut iter = exprs_bits.into_iter();
235

            
236
396
        while let Some(a) = iter.next() {
237
204
            if let Some(b) = iter.next() {
238
192
                values = tseytin_int_adder(&a, &b, output_size, &mut new_clauses, &mut new_symbols);
239
192
                next.push(values);
240
192
            } else {
241
12
                next.push(a);
242
12
            }
243
        }
244

            
245
192
        exprs_bits = next;
246
    }
247

            
248
180
    let result = exprs_bits.pop().unwrap();
249

            
250
180
    Ok(Reduction::cnf(
251
180
        Expr::SATInt(
252
180
            Metadata::new(),
253
180
            SATIntEncoding::Log,
254
180
            Moo::new(into_matrix_expr!(result)),
255
180
            (min, max),
256
180
        ),
257
180
        new_clauses,
258
180
        new_symbols,
259
180
    ))
260
7776
}
261

            
262
/// Returns result, new symbol table, new clauses
263
/// This function expects bits to match the lengths of x and y
264
252
fn tseytin_int_adder(
265
252
    x: &[Expr],
266
252
    y: &[Expr],
267
252
    bits: usize,
268
252
    clauses: &mut Vec<CnfClause>,
269
252
    symbols: &mut SymbolTable,
270
252
) -> Vec<Expr> {
271
    //TODO: Optimizing for constants
272
252
    let (mut result, mut carry) = tseytin_half_adder(x[0].clone(), y[0].clone(), clauses, symbols);
273

            
274
252
    let mut output = vec![result];
275
1380
    for i in 1..bits {
276
1380
        (result, carry) =
277
1380
            tseytin_full_adder(x[i].clone(), y[i].clone(), carry.clone(), clauses, symbols);
278
1380
        output.push(result);
279
1380
    }
280

            
281
252
    output
282
252
}
283

            
284
/// This function adds two booleans and a carry boolean using the full-adder logic circuit, it is intended for use in a binary adder.
285
1380
fn tseytin_full_adder(
286
1380
    a: Expr,
287
1380
    b: Expr,
288
1380
    carry: Expr,
289
1380
    clauses: &mut Vec<CnfClause>,
290
1380
    symbols: &mut SymbolTable,
291
1380
) -> (Expr, Expr) {
292
1380
    let axorb = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
293
1380
    let result = tseytin_xor(axorb.clone(), carry.clone(), clauses, symbols);
294
1380
    let aandb = tseytin_and(&vec![a, b], clauses, symbols);
295
1380
    let carryandaxorb = tseytin_and(&vec![carry, axorb], clauses, symbols);
296
1380
    let carryout = tseytin_or(&vec![aandb, carryandaxorb], clauses, symbols);
297

            
298
1380
    (result, carryout)
299
1380
}
300

            
301
/// This function adds two booleans using the half-adder logic circuit, it is intended for use in a binary adder.
302
252
fn tseytin_half_adder(
303
252
    a: Expr,
304
252
    b: Expr,
305
252
    clauses: &mut Vec<CnfClause>,
306
252
    symbols: &mut SymbolTable,
307
252
) -> (Expr, Expr) {
308
252
    let result = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
309
252
    let carry = tseytin_and(&vec![a, b], clauses, symbols);
310

            
311
252
    (result, carry)
312
252
}
313

            
314
/// this function is for specifically adding a power of two constant to a cnf int.
315
66
fn tseytin_add_two_power(
316
66
    expr: &[Expr],
317
66
    exponent: usize,
318
66
    bits: usize,
319
66
    clauses: &mut Vec<CnfClause>,
320
66
    symbols: &mut SymbolTable,
321
66
) -> Vec<Expr> {
322
66
    let mut result = vec![];
323
66
    let mut product = expr[exponent].clone();
324

            
325
66
    for item in expr.iter().take(exponent) {
326
        result.push(item.clone());
327
    }
328

            
329
66
    result.push(tseytin_not(expr[exponent].clone(), clauses, symbols));
330

            
331
174
    for item in expr.iter().take(bits).skip(exponent + 1) {
332
174
        result.push(tseytin_xor(product.clone(), item.clone(), clauses, symbols));
333
174
        product = tseytin_and(&vec![product, item.clone()], clauses, symbols);
334
174
    }
335

            
336
66
    result
337
66
}
338

            
339
/// This function multiplies two binary values using the shift-add multiplication algorithm.
340
12
fn cnf_shift_add_multiply(
341
12
    x: &[Expr],
342
12
    y: &[Expr],
343
12
    bits: usize,
344
12
    clauses: &mut Vec<CnfClause>,
345
12
    symbols: &mut SymbolTable,
346
12
) -> Vec<Expr> {
347
12
    let mut x = x.to_owned();
348
12
    let mut y = y.to_owned();
349

            
350
    //TODO Optimizing for constants
351
    //TODO Optimize addition for i left shifted values - skip first i bits
352

            
353
    // extend sign bits of operands to 2*`bits`
354
12
    x.extend(std::iter::repeat_n(x[bits - 1].clone(), bits));
355
12
    y.extend(std::iter::repeat_n(y[bits - 1].clone(), bits));
356

            
357
12
    let mut s: Vec<Expr> = vec![];
358
    let mut x_0andy_i;
359

            
360
144
    for bit in &y {
361
144
        x_0andy_i = tseytin_and(&vec![x[0].clone(), bit.clone()], clauses, symbols);
362
144
        s.push(x_0andy_i);
363
144
    }
364

            
365
    let mut sum;
366
    let mut if_true;
367
    let mut not_x_n;
368
    let mut if_false;
369

            
370
60
    for item in x.iter().take(bits).skip(1) {
371
        // y << 1
372
660
        for i in (1..bits * 2).rev() {
373
660
            y[i] = y[i - 1].clone();
374
660
        }
375
60
        y[0] = false.into();
376

            
377
        // TODO switch to multiplexer
378
        // TODO Add negatives support once MUX is added
379
60
        sum = tseytin_int_adder(&s, &y, bits * 2, clauses, symbols);
380
60
        not_x_n = tseytin_not(item.clone(), clauses, symbols);
381

            
382
720
        for i in 0..(bits * 2) {
383
720
            if_true = tseytin_and(&vec![item.clone(), sum[i].clone()], clauses, symbols);
384
720
            if_false = tseytin_and(&vec![not_x_n.clone(), s[i].clone()], clauses, symbols);
385
720
            s[i] = tseytin_or(&vec![if_true.clone(), if_false.clone()], clauses, symbols);
386
720
        }
387
    }
388

            
389
12
    s
390
12
}
391

            
392
/// Computes the interval containing every product `x_1 * x_2 * ... * x_n`, where each `x_k`
/// is drawn independently from the corresponding inclusive range.
///
/// E.g.
/// a : [2, 5], b : [-1, 2], c : [-10, -6], d : [0, 3]
/// a * b * c *d : [-300, 150]
///
/// An empty list yields `(1, 1)`, the range of the empty product.
fn product_of_ranges(ranges: Vec<&(i32, i32)>) -> (i32, i32) {
    let mut remaining = ranges.into_iter();
    let Some(&first) = remaining.next() else {
        return (1, 1);
    };

    remaining.fold(first, |(lo, hi), &(a, b)| {
        // Multiplying two intervals: the extremes are always among the four corner products.
        let corners = [lo * a, lo * b, hi * a, hi * b];
        corners
            .iter()
            .fold((i32::MAX, i32::MIN), |(low, high), &c| (low.min(c), high.max(c)))
    })
}
411

            
412
/// Converts product of SATInts to a single SATInt
413
///
414
/// ```text
415
/// Product(SATInt(a), SATInt(b), ...) ~> SATInt(c)
416
///
417
/// ```
418
#[register_rule("SAT_Log", 9000, [Product])]
419
292899
fn cnf_int_product(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
420
292899
    let Expr::Product(_, exprs) = expr else {
421
292887
        return Err(RuleNotApplicable);
422
    };
423

            
424
12
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
425
        return Err(RuleNotApplicable);
426
    };
427

            
428
12
    let ranges: Result<Vec<_>, _> = exprs_list
429
12
        .iter()
430
24
        .map(|e| match e {
431
24
            Expr::SATInt(_, _, _, x) => Ok(x),
432
            _ => Err(RuleNotApplicable),
433
24
        })
434
12
        .collect();
435

            
436
12
    let ranges = ranges?; // propagate error if any
437

            
438
12
    let (min, max) = product_of_ranges(ranges.clone());
439

            
440
12
    let exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
441

            
442
12
    let mut new_symbols = symbols.clone();
443
12
    let mut new_clauses = vec![];
444

            
445
12
    let (result, _) = exprs_bits
446
12
        .iter()
447
12
        .cloned()
448
12
        .zip(ranges.into_iter().copied())
449
12
        .reduce(|lhs, rhs| {
450
            // Make both bit vectors the same length
451
12
            let (lhs_bits, rhs_bits) = match_bits_length(lhs.0.clone(), rhs.0.clone());
452

            
453
            // Multiply operands
454
12
            let mut values = cnf_shift_add_multiply(
455
12
                &lhs_bits,
456
12
                &rhs_bits,
457
12
                lhs_bits.len(),
458
12
                &mut new_clauses,
459
12
                &mut new_symbols,
460
            );
461

            
462
            // Determine new range of result
463
12
            let (mut cum_min, mut cum_max) = lhs.1;
464
12
            let candidates = [
465
12
                cum_min * rhs.1.0,
466
12
                cum_min * rhs.1.1,
467
12
                cum_max * rhs.1.0,
468
12
                cum_max * rhs.1.1,
469
12
            ];
470
12
            cum_min = *candidates.iter().min().unwrap();
471
12
            cum_max = *candidates.iter().max().unwrap();
472

            
473
12
            let new_bit_count = bit_magnitude(cum_min).max(bit_magnitude(cum_max));
474
12
            values.truncate(new_bit_count);
475

            
476
12
            (values, (cum_min, cum_max))
477
12
        })
478
12
        .unwrap();
479

            
480
12
    Ok(Reduction::cnf(
481
12
        Expr::SATInt(
482
12
            Metadata::new(),
483
12
            SATIntEncoding::Log,
484
12
            Moo::new(into_matrix_expr!(result)),
485
12
            (min, max),
486
12
        ),
487
12
        new_clauses,
488
12
        new_symbols,
489
12
    ))
490
292899
}
491

            
492
/// Converts negation of a SATInt to a SATInt
493
///
494
/// ```text
495
/// -SATInt(a) ~> SATInt(b)
496
///
497
/// ```
498
#[register_rule("SAT_Log", 4100, [Neg])]
499
7776
fn cnf_int_neg(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
500
7776
    let Expr::Neg(_, expr) = expr else {
501
7734
        return Err(RuleNotApplicable);
502
    };
503

            
504
42
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
505
        return Err(RuleNotApplicable);
506
    };
507

            
508
42
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
509
42
    let [bits] = binding.as_slice() else {
510
        return Err(RuleNotApplicable);
511
    };
512

            
513
42
    let mut new_clauses = vec![];
514
42
    let mut new_symbols = symbols.clone();
515

            
516
42
    let result = tseytin_negate(bits, bits.len(), &mut new_clauses, &mut new_symbols);
517

            
518
42
    Ok(Reduction::cnf(
519
42
        Expr::SATInt(
520
42
            Metadata::new(),
521
42
            SATIntEncoding::Log,
522
42
            Moo::new(into_matrix_expr!(result)),
523
42
            (-max, -min),
524
42
        ),
525
42
        new_clauses,
526
42
        new_symbols,
527
42
    ))
528
7776
}
529

            
530
42
fn tseytin_negate(
531
42
    expr: &Vec<Expr>,
532
42
    bits: usize,
533
42
    clauses: &mut Vec<CnfClause>,
534
42
    symbols: &mut SymbolTable,
535
42
) -> Vec<Expr> {
536
42
    let mut result = vec![];
537
    // invert bits
538
120
    for bit in expr {
539
120
        result.push(tseytin_not(bit.clone(), clauses, symbols));
540
120
    }
541

            
542
    // add one
543
42
    result = tseytin_add_two_power(&result, 0, bits, clauses, symbols);
544

            
545
42
    result
546
42
}
547

            
548
/// Converts min of SATInts to a single SATInt
549
///
550
/// ```text
551
/// Min(SATInt(a), SATInt(b), ...) ~> SATInt(c)
552
///
553
/// ```
554
#[register_rule("SAT_Log", 4100, [Min])]
555
7776
fn cnf_int_min(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
556
7776
    let Expr::Min(_, exprs) = expr else {
557
7728
        return Err(RuleNotApplicable);
558
    };
559

            
560
48
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
561
        return Err(RuleNotApplicable);
562
    };
563

            
564
48
    let ranges: Result<Vec<_>, _> = exprs_list
565
48
        .iter()
566
96
        .map(|e| match e {
567
96
            Expr::SATInt(_, _, _, x) => Ok(x),
568
            _ => Err(RuleNotApplicable),
569
96
        })
570
48
        .collect();
571

            
572
48
    let ranges = ranges?; // propagate error if any
573

            
574
    // Is this optimal?
575
48
    let min = ranges.iter().map(|(a, _)| *a).min().unwrap();
576
48
    let max = ranges.iter().map(|(_, b)| *b).min().unwrap();
577

            
578
48
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
579

            
580
48
    let mut new_symbols = symbols.clone();
581
    let mut values;
582
48
    let mut new_clauses = vec![];
583

            
584
96
    while exprs_bits.len() > 1 {
585
48
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
586
48
        let mut iter = exprs_bits.into_iter();
587

            
588
96
        while let Some(a) = iter.next() {
589
48
            if let Some(b) = iter.next() {
590
48
                values = tseytin_binary_min_max(&a, &b, true, &mut new_clauses, &mut new_symbols);
591
48
                next.push(values);
592
48
            } else {
593
                next.push(a);
594
            }
595
        }
596

            
597
48
        exprs_bits = next;
598
    }
599

            
600
48
    let result = exprs_bits.pop().unwrap();
601

            
602
48
    Ok(Reduction::cnf(
603
48
        Expr::SATInt(
604
48
            Metadata::new(),
605
48
            SATIntEncoding::Log,
606
48
            Moo::new(into_matrix_expr!(result)),
607
48
            (min, max),
608
48
        ),
609
48
        new_clauses,
610
48
        new_symbols,
611
48
    ))
612
7776
}
613

            
614
/// General function for getting the min or max of two log integers.
615
96
fn tseytin_binary_min_max(
616
96
    x: &[Expr],
617
96
    y: &[Expr],
618
96
    min: bool,
619
96
    clauses: &mut Vec<CnfClause>,
620
96
    symbols: &mut SymbolTable,
621
96
) -> Vec<Expr> {
622
96
    let mask = if min {
623
        // mask is 1 if x > y
624
48
        inequality_boolean(x.to_owned(), y.to_owned(), true, clauses, symbols)
625
    } else {
626
        // flip the args if getting maximum x < y -> 1
627
48
        inequality_boolean(y.to_owned(), x.to_owned(), true, clauses, symbols)
628
    };
629

            
630
96
    tseytin_select_array(mask, x, y, clauses, symbols)
631
96
}
632

            
633
// Selects between two boolean vectors depending on a condition (both vectors must be the same length)
634
/// cond ? b : a
635
///
636
/// cond = 1 => b
637
/// cond = 0 => a
638
96
fn tseytin_select_array(
639
96
    cond: Expr,
640
96
    a: &[Expr],
641
96
    b: &[Expr],
642
96
    clauses: &mut Vec<CnfClause>,
643
96
    symbols: &mut SymbolTable,
644
96
) -> Vec<Expr> {
645
96
    assert_eq!(
646
96
        a.len(),
647
96
        b.len(),
648
        "Input vectors 'a' and 'b' must have the same length"
649
    );
650

            
651
96
    let mut out = vec![];
652

            
653
96
    let bit_count = a.len();
654

            
655
360
    for i in 0..bit_count {
656
360
        out.push(tseytin_mux(
657
360
            cond.clone(),
658
360
            a[i].clone(),
659
360
            b[i].clone(),
660
360
            clauses,
661
360
            symbols,
662
360
        ));
663
360
    }
664

            
665
96
    out
666
96
}
667

            
668
/// Converts max of SATInts to a single SATInt
669
///
670
/// ```text
671
/// Max(SATInt(a), SATInt(b), ...) ~> SATInt(c)
672
///
673
/// ```
674
#[register_rule("SAT_Log", 4100, [Max])]
675
7776
fn cnf_int_max(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
676
7776
    let Expr::Max(_, exprs) = expr else {
677
7728
        return Err(RuleNotApplicable);
678
    };
679

            
680
48
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
681
        return Err(RuleNotApplicable);
682
    };
683

            
684
48
    let ranges: Result<Vec<_>, _> = exprs_list
685
48
        .iter()
686
96
        .map(|e| match e {
687
96
            Expr::SATInt(_, _, _, x) => Ok(x),
688
            _ => Err(RuleNotApplicable),
689
96
        })
690
48
        .collect();
691

            
692
48
    let ranges = ranges?; // propagate error if any
693

            
694
    // Is this optimal?
695
48
    let min = ranges.iter().map(|(a, _)| *a).max().unwrap();
696
48
    let max = ranges.iter().map(|(_, b)| *b).max().unwrap();
697

            
698
48
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
699

            
700
48
    let mut new_symbols = symbols.clone();
701
    let mut values;
702
48
    let mut new_clauses = vec![];
703

            
704
96
    while exprs_bits.len() > 1 {
705
48
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
706
48
        let mut iter = exprs_bits.into_iter();
707

            
708
96
        while let Some(a) = iter.next() {
709
48
            if let Some(b) = iter.next() {
710
48
                values = tseytin_binary_min_max(&a, &b, false, &mut new_clauses, &mut new_symbols);
711
48
                next.push(values);
712
48
            } else {
713
                next.push(a);
714
            }
715
        }
716

            
717
48
        exprs_bits = next;
718
    }
719

            
720
48
    let result = exprs_bits.pop().unwrap();
721

            
722
48
    Ok(Reduction::cnf(
723
48
        Expr::SATInt(
724
48
            Metadata::new(),
725
48
            SATIntEncoding::Log,
726
48
            Moo::new(into_matrix_expr!(result)),
727
48
            (min, max),
728
48
        ),
729
48
        new_clauses,
730
48
        new_symbols,
731
48
    ))
732
7776
}
733

            
734
/// Converts Abs of a SATInt to a SATInt
735
///
736
/// ```text
737
/// |SATInt(a)| ~> SATInt(b)
738
///
739
/// ```
740
#[register_rule("SAT_Log", 4100, [Abs])]
741
7776
fn cnf_int_abs(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
742
7776
    let Expr::Abs(_, expr) = expr else {
743
7728
        return Err(RuleNotApplicable);
744
    };
745

            
746
48
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
747
24
        return Err(RuleNotApplicable);
748
    };
749

            
750
24
    let range = (
751
24
        cmp::max(0, cmp::max(*min, -*max)),
752
24
        cmp::max(min.abs(), max.abs()),
753
24
    );
754

            
755
24
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
756
24
    let [bits] = binding.as_slice() else {
757
        return Err(RuleNotApplicable);
758
    };
759

            
760
24
    let mut new_clauses = vec![];
761
24
    let mut new_symbols = symbols.clone();
762

            
763
24
    let mut result = vec![];
764

            
765
    // How does this handle negatives edge cases: -(-8) = 8, an extra bit is needed
766

            
767
    // invert bits
768
120
    for bit in bits {
769
120
        result.push(tseytin_not(bit.clone(), &mut new_clauses, &mut new_symbols));
770
120
    }
771

            
772
24
    let bit_count = result.len();
773

            
774
    // add one
775
24
    result = tseytin_add_two_power(&result, 0, bit_count, &mut new_clauses, &mut new_symbols);
776

            
777
120
    for i in 0..bit_count {
778
120
        result[i] = tseytin_mux(
779
120
            bits[bit_count - 1].clone(),
780
120
            bits[i].clone(),
781
120
            result[i].clone(),
782
120
            &mut new_clauses,
783
120
            &mut new_symbols,
784
120
        )
785
    }
786

            
787
24
    Ok(Reduction::cnf(
788
24
        Expr::SATInt(
789
24
            Metadata::new(),
790
24
            SATIntEncoding::Log,
791
24
            Moo::new(into_matrix_expr!(result)),
792
24
            range,
793
24
        ),
794
24
        new_clauses,
795
24
        new_symbols,
796
24
    ))
797
7776
}
798

            
799
/// Converts SafeDiv of SATInts to a single SATInt
800
///
801
/// ```text
802
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
803
///
804
/// ```
805
#[register_rule("SAT_Log", 4100, [SafeDiv])]
806
7776
fn cnf_int_safediv(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
807
    // Using "Restoring division" algorithm
808
    // https://en.wikipedia.org/wiki/Division_algorithm#Restoring_division
809
7776
    let Expr::SafeDiv(_, numer, denom) = expr else {
810
7776
        return Err(RuleNotApplicable);
811
    };
812

            
813
    let Expr::SATInt(_, _, _, (numer_min, numer_max)) = numer.as_ref() else {
814
        return Err(RuleNotApplicable);
815
    };
816

            
817
    let Expr::SATInt(_, _, _, (denom_min, denom_max)) = denom.as_ref() else {
818
        return Err(RuleNotApplicable);
819
    };
820

            
821
    let candidates = [
822
        numer_min / denom_min,
823
        numer_min / denom_max,
824
        numer_max / denom_min,
825
        numer_max / denom_max,
826
    ];
827

            
828
    let min = *candidates.iter().min().unwrap();
829
    let max = *candidates.iter().max().unwrap();
830

            
831
    let binding =
832
        validate_log_int_operands(vec![numer.as_ref().clone(), denom.as_ref().clone()], None)?;
833
    let [numer_bits, denom_bits] = binding.as_slice() else {
834
        return Err(RuleNotApplicable);
835
    };
836

            
837
    let bit_count = numer_bits.len();
838

            
839
    // TODO: Separate into division/mod function
840

            
841
    let mut new_symbols = symbols.clone();
842
    let mut new_clauses = vec![];
843
    let mut quotient = vec![false.into(); bit_count];
844

            
845
    let minus_numer = tseytin_negate(
846
        &numer_bits.clone(),
847
        bit_count,
848
        &mut new_clauses,
849
        &mut new_symbols,
850
    );
851
    let minus_denom = tseytin_negate(
852
        &denom_bits.clone(),
853
        bit_count,
854
        &mut new_clauses,
855
        &mut new_symbols,
856
    );
857

            
858
    let sign_bit = tseytin_xor(
859
        numer_bits[bit_count - 1].clone(),
860
        denom_bits[bit_count - 1].clone(),
861
        &mut new_clauses,
862
        &mut new_symbols,
863
    );
864

            
865
    let numer_bits = tseytin_select_array(
866
        numer_bits[bit_count - 1].clone(),
867
        &numer_bits.clone(),
868
        &minus_numer,
869
        &mut new_clauses,
870
        &mut new_symbols,
871
    );
872
    let denom_bits = tseytin_select_array(
873
        denom_bits[bit_count - 1].clone(),
874
        &denom_bits.clone(),
875
        &minus_denom,
876
        &mut new_clauses,
877
        &mut new_symbols,
878
    );
879

            
880
    let mut r = numer_bits;
881
    r.extend(std::iter::repeat_n(r[bit_count - 1].clone(), bit_count));
882
    let mut d = std::iter::repeat_n(false.into(), bit_count).collect_vec();
883
    d.extend(denom_bits);
884

            
885
    let minus_d = tseytin_negate(
886
        &d.clone(),
887
        2 * bit_count,
888
        &mut new_clauses,
889
        &mut new_symbols,
890
    );
891
    let mut rminusd;
892

            
893
    for i in (0..bit_count).rev() {
894
        // r << 1
895
        for j in (1..bit_count * 2).rev() {
896
            r[j] = r[j - 1].clone();
897
        }
898
        r[0] = false.into();
899

            
900
        rminusd = tseytin_int_adder(
901
            &r.clone(),
902
            &minus_d.clone(),
903
            2 * bit_count,
904
            &mut new_clauses,
905
            &mut new_symbols,
906
        );
907

            
908
        // TODO: For mod don't calculate on final iter
909
        quotient[i] = tseytin_not(
910
            // q[i] = inverse of sign bit - 1 if positive, 0 if negative
911
            rminusd[2 * bit_count - 1].clone(),
912
            &mut new_clauses,
913
            &mut new_symbols,
914
        );
915

            
916
        // TODO: For div don't calculate on final iter
917
        for j in 0..(2 * bit_count) {
918
            r[j] = tseytin_mux(
919
                quotient[i].clone(),
920
                r[j].clone(),       // use r if negative
921
                rminusd[j].clone(), // use r-d if positive
922
                &mut new_clauses,
923
                &mut new_symbols,
924
            );
925
        }
926
    }
927

            
928
    let minus_quotient = tseytin_negate(
929
        &quotient.clone(),
930
        bit_count,
931
        &mut new_clauses,
932
        &mut new_symbols,
933
    );
934

            
935
    let out = tseytin_select_array(
936
        sign_bit,
937
        &quotient,
938
        &minus_quotient,
939
        &mut new_clauses,
940
        &mut new_symbols,
941
    );
942

            
943
    Ok(Reduction::cnf(
944
        Expr::SATInt(
945
            Metadata::new(),
946
            SATIntEncoding::Log,
947
            Moo::new(into_matrix_expr!(out)),
948
            (min, max),
949
        ),
950
        new_clauses,
951
        new_symbols,
952
    ))
953
7776
}
954

            
955
/*
956
/// Converts SafeMod of SATInts to a single SATInt
957
///
958
/// ```text
959
/// SafeMod(SATInt(a), SATInt(b)) ~> SATInt(c)
960
///
961
/// ```
962
#[register_rule("SAT_Log", 4100, [SafeMod])]
963
fn cnf_int_safemod(expr: &Expr, _: &SymbolTable) -> ApplicationResult {}
964

            
965
/// Converts SafePow of SATInts to a single SATInt
966
///
967
/// ```text
968
/// SafePow(SATInt(a), SATInt(b)) ~> SATInt(c)
969
///
970
/// ```
971
#[register_rule("SAT", 4100, [SafePow])]
972
fn cnf_int_safepow(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
973
    // use 'Exponentiation by squaring'
974
}
975
*/