1
use conjure_cp::ast::Expression as Expr;
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
5
};
6

            
7
use conjure_cp::ast::AbstractLiteral::Matrix;
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use itertools::Itertools;
13

            
14
use super::boolean::{
15
    tseytin_and, tseytin_iff, tseytin_imply, tseytin_mux, tseytin_not, tseytin_or, tseytin_xor,
16
};
17
use super::integer_repr::{bit_magnitude, match_bits_length, validate_log_int_operands};
18

            
19
use conjure_cp::ast::CnfClause;
20

            
21
use std::cmp;
22

            
23
/// Converts an inequality expression between two SATInts to a boolean expression in cnf.
24
///
25
/// ```text
26
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
27
///
28
/// ```
29
#[register_rule(("SAT_Log", 4100))]
30
1464
fn cnf_int_ineq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
31
1464
    let (lhs, rhs, strict) = match expr {
32
21
        Expr::Lt(_, x, y) => (y, x, true),
33
18
        Expr::Gt(_, x, y) => (x, y, true),
34
216
        Expr::Leq(_, x, y) => (y, x, false),
35
186
        Expr::Geq(_, x, y) => (x, y, false),
36
1023
        _ => return Err(RuleNotApplicable),
37
    };
38

            
39
393
    let binding =
40
441
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
41
393
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
42
        return Err(RuleNotApplicable);
43
    };
44

            
45
393
    let mut new_symbols = symbols.clone();
46
393
    let mut new_clauses = vec![];
47

            
48
393
    let output = inequality_boolean(
49
393
        lhs_bits.clone(),
50
393
        rhs_bits.clone(),
51
393
        strict,
52
393
        &mut new_clauses,
53
393
        &mut new_symbols,
54
    );
55
393
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
56
1464
}
57

            
58
/// Converts a = expression between two SATInts to a boolean expression in cnf
59
///
60
/// ```text
61
/// SATInt(a) = SATInt(b) ~> Bool
62
///
63
/// ```
64
#[register_rule(("SAT_Log", 9100))]
65
82401
fn cnf_int_eq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
66
82401
    let Expr::Eq(_, lhs, rhs) = expr else {
67
82140
        return Err(RuleNotApplicable);
68
    };
69

            
70
69
    let binding =
71
261
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
72
69
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
73
        return Err(RuleNotApplicable);
74
    };
75

            
76
69
    let bit_count = lhs_bits.len();
77

            
78
69
    let mut output = true.into();
79
69
    let mut new_symbols = symbols.clone();
80
69
    let mut new_clauses = vec![];
81
    let mut comparison;
82

            
83
420
    for i in 0..bit_count {
84
420
        comparison = tseytin_iff(
85
420
            lhs_bits[i].clone(),
86
420
            rhs_bits[i].clone(),
87
420
            &mut new_clauses,
88
420
            &mut new_symbols,
89
420
        );
90
420
        output = tseytin_and(
91
420
            &vec![comparison, output],
92
420
            &mut new_clauses,
93
420
            &mut new_symbols,
94
420
        );
95
420
    }
96

            
97
69
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
98
82401
}
99

            
100
/// Converts a != expression between two SATInts to a boolean expression in cnf
101
///
102
/// ```text
103
/// SATInt(a) != SATInt(b) ~> Bool
104
///
105
/// ```
106
#[register_rule(("SAT_Log", 4100))]
107
1464
fn cnf_int_neq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
108
1464
    let Expr::Neq(_, lhs, rhs) = expr else {
109
1419
        return Err(RuleNotApplicable);
110
    };
111

            
112
45
    let binding =
113
45
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
114
45
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
115
        return Err(RuleNotApplicable);
116
    };
117

            
118
45
    let bit_count = lhs_bits.len();
119

            
120
45
    let mut output = false.into();
121
45
    let mut new_symbols = symbols.clone();
122
45
    let mut new_clauses = vec![];
123
    let mut comparison;
124

            
125
144
    for i in 0..bit_count {
126
144
        comparison = tseytin_xor(
127
144
            lhs_bits[i].clone(),
128
144
            rhs_bits[i].clone(),
129
144
            &mut new_clauses,
130
144
            &mut new_symbols,
131
144
        );
132
144
        output = tseytin_or(
133
144
            &vec![comparison, output],
134
144
            &mut new_clauses,
135
144
            &mut new_symbols,
136
144
        );
137
144
    }
138

            
139
45
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
140
1464
}
141

            
142
// Creates a boolean expression for > or >=
143
// a > b or a >= b
144
// This can also be used for < and <= by reversing the order of the inputs
145
// Returns result, new symbol table, new clauses
146
417
fn inequality_boolean(
147
417
    a: Vec<Expr>,
148
417
    b: Vec<Expr>,
149
417
    strict: bool,
150
417
    clauses: &mut Vec<CnfClause>,
151
417
    symbols: &mut SymbolTable,
152
417
) -> Expr {
153
    let mut notb;
154
    let mut output;
155

            
156
417
    if strict {
157
54
        notb = tseytin_not(b[0].clone(), clauses, symbols);
158
54
        output = tseytin_and(&vec![a[0].clone(), notb], clauses, symbols);
159
363
    } else {
160
363
        output = tseytin_imply(b[0].clone(), a[0].clone(), clauses, symbols);
161
363
    }
162

            
163
    //TODO: There may be room for simplification, and constant optimization
164

            
165
417
    let bit_count = a.len();
166

            
167
    let mut lhs;
168
    let mut rhs;
169
    let mut iff;
170
882
    for n in 1..(bit_count - 1) {
171
882
        notb = tseytin_not(b[n].clone(), clauses, symbols);
172
882
        lhs = tseytin_and(&vec![a[n].clone(), notb.clone()], clauses, symbols);
173
882
        iff = tseytin_iff(a[n].clone(), b[n].clone(), clauses, symbols);
174
882
        rhs = tseytin_and(&vec![iff.clone(), output.clone()], clauses, symbols);
175
882
        output = tseytin_or(&vec![lhs.clone(), rhs.clone()], clauses, symbols);
176
882
    }
177

            
178
    // final bool is the sign bit and should be handled inversely
179
417
    let nota = tseytin_not(a[bit_count - 1].clone(), clauses, symbols);
180
417
    lhs = tseytin_and(&vec![nota, b[bit_count - 1].clone()], clauses, symbols);
181
417
    iff = tseytin_iff(
182
417
        a[bit_count - 1].clone(),
183
417
        b[bit_count - 1].clone(),
184
417
        clauses,
185
417
        symbols,
186
    );
187
417
    rhs = tseytin_and(&vec![iff, output.clone()], clauses, symbols);
188
417
    output = tseytin_or(&vec![lhs, rhs], clauses, symbols);
189

            
190
417
    output
191
417
}
192

            
193
/// Converts sum of SATInts to a single SATInt
194
///
195
/// ```text
196
/// Sum(SATInt(a), SATInt(b), ...) ~> SATInt(c)
197
///
198
/// ```
199
#[register_rule(("SAT_Log", 4100))]
200
1464
fn cnf_int_sum(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
201
1464
    let Expr::Sum(_, exprs) = expr else {
202
1434
        return Err(RuleNotApplicable);
203
    };
204

            
205
30
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
206
        return Err(RuleNotApplicable);
207
    };
208

            
209
30
    let ranges: Result<Vec<_>, _> = exprs_list
210
30
        .iter()
211
60
        .map(|e| match e {
212
45
            Expr::SATInt(_, _, _, x) => Ok(x),
213
15
            _ => Err(RuleNotApplicable),
214
60
        })
215
30
        .collect();
216

            
217
30
    let ranges = ranges?;
218

            
219
15
    let min = ranges.iter().map(|(a, _)| *a).sum();
220
15
    let max = ranges.iter().map(|(_, a)| *a).sum();
221

            
222
15
    let output_size = cmp::max(bit_magnitude(min), bit_magnitude(max));
223

            
224
    // Check operands are valid log ints
225
15
    let mut exprs_bits =
226
15
        validate_log_int_operands(exprs_list.clone(), Some(output_size.try_into().unwrap()))?;
227

            
228
15
    let mut new_symbols = symbols.clone();
229
    let mut values;
230
15
    let mut new_clauses = vec![];
231

            
232
33
    while exprs_bits.len() > 1 {
233
18
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
234
18
        let mut iter = exprs_bits.into_iter();
235

            
236
39
        while let Some(a) = iter.next() {
237
21
            if let Some(b) = iter.next() {
238
18
                values = tseytin_int_adder(&a, &b, output_size, &mut new_clauses, &mut new_symbols);
239
18
                next.push(values);
240
18
            } else {
241
3
                next.push(a);
242
3
            }
243
        }
244

            
245
18
        exprs_bits = next;
246
    }
247

            
248
15
    let result = exprs_bits.pop().unwrap();
249

            
250
15
    Ok(Reduction::cnf(
251
15
        Expr::SATInt(
252
15
            Metadata::new(),
253
15
            SATIntEncoding::Log,
254
15
            Moo::new(into_matrix_expr!(result)),
255
15
            (min, max),
256
15
        ),
257
15
        new_clauses,
258
15
        new_symbols,
259
15
    ))
260
1464
}
261

            
262
/// Returns result, new symbol table, new clauses
263
/// This function expects bits to match the lengths of x and y
264
33
fn tseytin_int_adder(
265
33
    x: &[Expr],
266
33
    y: &[Expr],
267
33
    bits: usize,
268
33
    clauses: &mut Vec<CnfClause>,
269
33
    symbols: &mut SymbolTable,
270
33
) -> Vec<Expr> {
271
    //TODO: Optimizing for constants
272
33
    let (mut result, mut carry) = tseytin_half_adder(x[0].clone(), y[0].clone(), clauses, symbols);
273

            
274
33
    let mut output = vec![result];
275
255
    for i in 1..bits {
276
255
        (result, carry) =
277
255
            tseytin_full_adder(x[i].clone(), y[i].clone(), carry.clone(), clauses, symbols);
278
255
        output.push(result);
279
255
    }
280

            
281
33
    output
282
33
}
283

            
284
/// This function adds two booleans and a carry boolean using the full-adder logic circuit, it is intended for use in a binary adder.
285
255
fn tseytin_full_adder(
286
255
    a: Expr,
287
255
    b: Expr,
288
255
    carry: Expr,
289
255
    clauses: &mut Vec<CnfClause>,
290
255
    symbols: &mut SymbolTable,
291
255
) -> (Expr, Expr) {
292
255
    let axorb = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
293
255
    let result = tseytin_xor(axorb.clone(), carry.clone(), clauses, symbols);
294
255
    let aandb = tseytin_and(&vec![a, b], clauses, symbols);
295
255
    let carryandaxorb = tseytin_and(&vec![carry, axorb], clauses, symbols);
296
255
    let carryout = tseytin_or(&vec![aandb, carryandaxorb], clauses, symbols);
297

            
298
255
    (result, carryout)
299
255
}
300

            
301
/// This function adds two booleans using the half-adder logic circuit, it is intended for use in a binary adder.
302
33
fn tseytin_half_adder(
303
33
    a: Expr,
304
33
    b: Expr,
305
33
    clauses: &mut Vec<CnfClause>,
306
33
    symbols: &mut SymbolTable,
307
33
) -> (Expr, Expr) {
308
33
    let result = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
309
33
    let carry = tseytin_and(&vec![a, b], clauses, symbols);
310

            
311
33
    (result, carry)
312
33
}
313

            
314
/// this function is for specifically adding a power of two constant to a cnf int.
315
24
fn tseytin_add_two_power(
316
24
    expr: &[Expr],
317
24
    exponent: usize,
318
24
    bits: usize,
319
24
    clauses: &mut Vec<CnfClause>,
320
24
    symbols: &mut SymbolTable,
321
24
) -> Vec<Expr> {
322
24
    let mut result = vec![];
323
24
    let mut product = expr[exponent].clone();
324

            
325
24
    for item in expr.iter().take(exponent) {
326
        result.push(item.clone());
327
    }
328

            
329
24
    result.push(tseytin_not(expr[exponent].clone(), clauses, symbols));
330

            
331
57
    for item in expr.iter().take(bits).skip(exponent + 1) {
332
57
        result.push(tseytin_xor(product.clone(), item.clone(), clauses, symbols));
333
57
        product = tseytin_and(&vec![product, item.clone()], clauses, symbols);
334
57
    }
335

            
336
24
    result
337
24
}
338

            
339
/// This function multiplies two binary values using the shift-add multiplication algorithm.
340
3
fn cnf_shift_add_multiply(
341
3
    x: &[Expr],
342
3
    y: &[Expr],
343
3
    bits: usize,
344
3
    clauses: &mut Vec<CnfClause>,
345
3
    symbols: &mut SymbolTable,
346
3
) -> Vec<Expr> {
347
3
    let mut x = x.to_owned();
348
3
    let mut y = y.to_owned();
349

            
350
    //TODO Optimizing for constants
351
    //TODO Optimize addition for i left shifted values - skip first i bits
352

            
353
    // extend sign bits of operands to 2*`bits`
354
3
    x.extend(std::iter::repeat_n(x[bits - 1].clone(), bits));
355
3
    y.extend(std::iter::repeat_n(y[bits - 1].clone(), bits));
356

            
357
3
    let mut s: Vec<Expr> = vec![];
358
    let mut x_0andy_i;
359

            
360
36
    for bit in &y {
361
36
        x_0andy_i = tseytin_and(&vec![x[0].clone(), bit.clone()], clauses, symbols);
362
36
        s.push(x_0andy_i);
363
36
    }
364

            
365
    let mut sum;
366
    let mut if_true;
367
    let mut not_x_n;
368
    let mut if_false;
369

            
370
15
    for item in x.iter().take(bits).skip(1) {
371
        // y << 1
372
165
        for i in (1..bits * 2).rev() {
373
165
            y[i] = y[i - 1].clone();
374
165
        }
375
15
        y[0] = false.into();
376

            
377
        // TODO switch to multiplexer
378
        // TODO Add negatives support once MUX is added
379
15
        sum = tseytin_int_adder(&s, &y, bits * 2, clauses, symbols);
380
15
        not_x_n = tseytin_not(item.clone(), clauses, symbols);
381

            
382
180
        for i in 0..(bits * 2) {
383
180
            if_true = tseytin_and(&vec![item.clone(), sum[i].clone()], clauses, symbols);
384
180
            if_false = tseytin_and(&vec![not_x_n.clone(), s[i].clone()], clauses, symbols);
385
180
            s[i] = tseytin_or(&vec![if_true.clone(), if_false.clone()], clauses, symbols);
386
180
        }
387
    }
388

            
389
3
    s
390
3
}
391

            
392
/// Computes the interval that contains every possible product of one value taken from
/// each of the given intervals.
///
/// E.g.
/// a : [2, 5], b : [-1, 2], c : [-10, -6], d : [0, 3]
/// a * b * c * d : [-300, 150]
///
/// The intervals are folded in from left to right: the running interval is multiplied
/// corner-by-corner with the next one and the smallest and largest of the four corner
/// products become the new running interval. An empty list yields `(1, 1)`, the
/// product of zero numbers.
fn product_of_ranges(ranges: Vec<&(i32, i32)>) -> (i32, i32) {
    let Some((&&first, rest)) = ranges.split_first() else {
        return (1, 1); // product of zero numbers = 1
    };

    rest.iter().fold(first, |(low, high), &&(a, b)| {
        let corners = [low * a, low * b, high * a, high * b];
        let smallest = corners[1..].iter().fold(corners[0], |m, &c| m.min(c));
        let largest = corners[1..].iter().fold(corners[0], |m, &c| m.max(c));
        (smallest, largest)
    })
}
411

            
412
/// Converts product of SATInts to a single SATInt
413
///
414
/// ```text
415
/// Product(SATInt(a), SATInt(b), ...) ~> SATInt(c)
416
///
417
/// ```
418
#[register_rule(("SAT_Log", 9000))]
419
58893
fn cnf_int_product(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
420
58893
    let Expr::Product(_, exprs) = expr else {
421
58890
        return Err(RuleNotApplicable);
422
    };
423

            
424
3
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
425
        return Err(RuleNotApplicable);
426
    };
427

            
428
3
    let ranges: Result<Vec<_>, _> = exprs_list
429
3
        .iter()
430
6
        .map(|e| match e {
431
6
            Expr::SATInt(_, _, _, x) => Ok(x),
432
            _ => Err(RuleNotApplicable),
433
6
        })
434
3
        .collect();
435

            
436
3
    let ranges = ranges?; // propagate error if any
437

            
438
3
    let (min, max) = product_of_ranges(ranges.clone());
439

            
440
3
    let exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
441

            
442
3
    let mut new_symbols = symbols.clone();
443
3
    let mut new_clauses = vec![];
444

            
445
3
    let (result, _) = exprs_bits
446
3
        .iter()
447
3
        .cloned()
448
3
        .zip(ranges.into_iter().copied())
449
3
        .reduce(|lhs, rhs| {
450
            // Make both bit vectors the same length
451
3
            let (lhs_bits, rhs_bits) = match_bits_length(lhs.0.clone(), rhs.0.clone());
452

            
453
            // Multiply operands
454
3
            let mut values = cnf_shift_add_multiply(
455
3
                &lhs_bits,
456
3
                &rhs_bits,
457
3
                lhs_bits.len(),
458
3
                &mut new_clauses,
459
3
                &mut new_symbols,
460
            );
461

            
462
            // Determine new range of result
463
3
            let (mut cum_min, mut cum_max) = lhs.1;
464
3
            let candidates = [
465
3
                cum_min * rhs.1.0,
466
3
                cum_min * rhs.1.1,
467
3
                cum_max * rhs.1.0,
468
3
                cum_max * rhs.1.1,
469
3
            ];
470
3
            cum_min = *candidates.iter().min().unwrap();
471
3
            cum_max = *candidates.iter().max().unwrap();
472

            
473
3
            let new_bit_count = bit_magnitude(cum_min).max(bit_magnitude(cum_max));
474
3
            values.truncate(new_bit_count);
475

            
476
3
            (values, (cum_min, cum_max))
477
3
        })
478
3
        .unwrap();
479

            
480
3
    Ok(Reduction::cnf(
481
3
        Expr::SATInt(
482
3
            Metadata::new(),
483
3
            SATIntEncoding::Log,
484
3
            Moo::new(into_matrix_expr!(result)),
485
3
            (min, max),
486
3
        ),
487
3
        new_clauses,
488
3
        new_symbols,
489
3
    ))
490
58893
}
491

            
492
/// Converts negation of a SATInt to a SATInt
493
///
494
/// ```text
495
/// -SATInt(a) ~> SATInt(b)
496
///
497
/// ```
498
#[register_rule(("SAT_Log", 4100))]
499
1464
fn cnf_int_neg(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
500
1464
    let Expr::Neg(_, expr) = expr else {
501
1446
        return Err(RuleNotApplicable);
502
    };
503

            
504
18
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
505
        return Err(RuleNotApplicable);
506
    };
507

            
508
18
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
509
18
    let [bits] = binding.as_slice() else {
510
        return Err(RuleNotApplicable);
511
    };
512

            
513
18
    let mut new_clauses = vec![];
514
18
    let mut new_symbols = symbols.clone();
515

            
516
18
    let result = tseytin_negate(bits, bits.len(), &mut new_clauses, &mut new_symbols);
517

            
518
18
    Ok(Reduction::cnf(
519
18
        Expr::SATInt(
520
18
            Metadata::new(),
521
18
            SATIntEncoding::Log,
522
18
            Moo::new(into_matrix_expr!(result)),
523
18
            (-max, -min),
524
18
        ),
525
18
        new_clauses,
526
18
        new_symbols,
527
18
    ))
528
1464
}
529

            
530
18
fn tseytin_negate(
531
18
    expr: &Vec<Expr>,
532
18
    bits: usize,
533
18
    clauses: &mut Vec<CnfClause>,
534
18
    symbols: &mut SymbolTable,
535
18
) -> Vec<Expr> {
536
18
    let mut result = vec![];
537
    // invert bits
538
51
    for bit in expr {
539
51
        result.push(tseytin_not(bit.clone(), clauses, symbols));
540
51
    }
541

            
542
    // add one
543
18
    result = tseytin_add_two_power(&result, 0, bits, clauses, symbols);
544

            
545
18
    result
546
18
}
547

            
548
/// Converts min of SATInts to a single SATInt
549
///
550
/// ```text
551
/// Min(SATInt(a), SATInt(b), ...) ~> SATInt(c)
552
///
553
/// ```
554
#[register_rule(("SAT_Log", 4100))]
555
1464
fn cnf_int_min(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
556
1464
    let Expr::Min(_, exprs) = expr else {
557
1452
        return Err(RuleNotApplicable);
558
    };
559

            
560
12
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
561
        return Err(RuleNotApplicable);
562
    };
563

            
564
12
    let ranges: Result<Vec<_>, _> = exprs_list
565
12
        .iter()
566
24
        .map(|e| match e {
567
24
            Expr::SATInt(_, _, _, x) => Ok(x),
568
            _ => Err(RuleNotApplicable),
569
24
        })
570
12
        .collect();
571

            
572
12
    let ranges = ranges?; // propagate error if any
573

            
574
    // Is this optimal?
575
12
    let min = ranges.iter().map(|(a, _)| *a).min().unwrap();
576
12
    let max = ranges.iter().map(|(_, b)| *b).min().unwrap();
577

            
578
12
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
579

            
580
12
    let mut new_symbols = symbols.clone();
581
    let mut values;
582
12
    let mut new_clauses = vec![];
583

            
584
24
    while exprs_bits.len() > 1 {
585
12
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
586
12
        let mut iter = exprs_bits.into_iter();
587

            
588
24
        while let Some(a) = iter.next() {
589
12
            if let Some(b) = iter.next() {
590
12
                values = tseytin_binary_min_max(&a, &b, true, &mut new_clauses, &mut new_symbols);
591
12
                next.push(values);
592
12
            } else {
593
                next.push(a);
594
            }
595
        }
596

            
597
12
        exprs_bits = next;
598
    }
599

            
600
12
    let result = exprs_bits.pop().unwrap();
601

            
602
12
    Ok(Reduction::cnf(
603
12
        Expr::SATInt(
604
12
            Metadata::new(),
605
12
            SATIntEncoding::Log,
606
12
            Moo::new(into_matrix_expr!(result)),
607
12
            (min, max),
608
12
        ),
609
12
        new_clauses,
610
12
        new_symbols,
611
12
    ))
612
1464
}
613

            
614
/// General function for getting the min or max of two log integers.
615
24
fn tseytin_binary_min_max(
616
24
    x: &[Expr],
617
24
    y: &[Expr],
618
24
    min: bool,
619
24
    clauses: &mut Vec<CnfClause>,
620
24
    symbols: &mut SymbolTable,
621
24
) -> Vec<Expr> {
622
24
    let mask = if min {
623
        // mask is 1 if x > y
624
12
        inequality_boolean(x.to_owned(), y.to_owned(), true, clauses, symbols)
625
    } else {
626
        // flip the args if getting maximum x < y -> 1
627
12
        inequality_boolean(y.to_owned(), x.to_owned(), true, clauses, symbols)
628
    };
629

            
630
24
    tseytin_select_array(mask, x, y, clauses, symbols)
631
24
}
632

            
633
// Selects between two boolean vectors depending on a condition (both vectors must be the same length)
634
/// cond ? b : a
635
///
636
/// cond = 1 => b
637
/// cond = 0 => a
638
24
fn tseytin_select_array(
639
24
    cond: Expr,
640
24
    a: &[Expr],
641
24
    b: &[Expr],
642
24
    clauses: &mut Vec<CnfClause>,
643
24
    symbols: &mut SymbolTable,
644
24
) -> Vec<Expr> {
645
24
    assert_eq!(
646
24
        a.len(),
647
24
        b.len(),
648
        "Input vectors 'a' and 'b' must have the same length"
649
    );
650

            
651
24
    let mut out = vec![];
652

            
653
24
    let bit_count = a.len();
654

            
655
90
    for i in 0..bit_count {
656
90
        out.push(tseytin_mux(
657
90
            cond.clone(),
658
90
            a[i].clone(),
659
90
            b[i].clone(),
660
90
            clauses,
661
90
            symbols,
662
90
        ));
663
90
    }
664

            
665
24
    out
666
24
}
667

            
668
/// Converts max of SATInts to a single SATInt
669
///
670
/// ```text
671
/// Max(SATInt(a), SATInt(b), ...) ~> SATInt(c)
672
///
673
/// ```
674
#[register_rule(("SAT_Log", 4100))]
675
1464
fn cnf_int_max(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
676
1464
    let Expr::Max(_, exprs) = expr else {
677
1452
        return Err(RuleNotApplicable);
678
    };
679

            
680
12
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
681
        return Err(RuleNotApplicable);
682
    };
683

            
684
12
    let ranges: Result<Vec<_>, _> = exprs_list
685
12
        .iter()
686
24
        .map(|e| match e {
687
24
            Expr::SATInt(_, _, _, x) => Ok(x),
688
            _ => Err(RuleNotApplicable),
689
24
        })
690
12
        .collect();
691

            
692
12
    let ranges = ranges?; // propagate error if any
693

            
694
    // Is this optimal?
695
12
    let min = ranges.iter().map(|(a, _)| *a).max().unwrap();
696
12
    let max = ranges.iter().map(|(_, b)| *b).max().unwrap();
697

            
698
12
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
699

            
700
12
    let mut new_symbols = symbols.clone();
701
    let mut values;
702
12
    let mut new_clauses = vec![];
703

            
704
24
    while exprs_bits.len() > 1 {
705
12
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
706
12
        let mut iter = exprs_bits.into_iter();
707

            
708
24
        while let Some(a) = iter.next() {
709
12
            if let Some(b) = iter.next() {
710
12
                values = tseytin_binary_min_max(&a, &b, false, &mut new_clauses, &mut new_symbols);
711
12
                next.push(values);
712
12
            } else {
713
                next.push(a);
714
            }
715
        }
716

            
717
12
        exprs_bits = next;
718
    }
719

            
720
12
    let result = exprs_bits.pop().unwrap();
721

            
722
12
    Ok(Reduction::cnf(
723
12
        Expr::SATInt(
724
12
            Metadata::new(),
725
12
            SATIntEncoding::Log,
726
12
            Moo::new(into_matrix_expr!(result)),
727
12
            (min, max),
728
12
        ),
729
12
        new_clauses,
730
12
        new_symbols,
731
12
    ))
732
1464
}
733

            
734
/// Converts Abs of a SATInt to a SATInt
735
///
736
/// ```text
737
/// |SATInt(a)| ~> SATInt(b)
738
///
739
/// ```
740
#[register_rule(("SAT_Log", 4100))]
741
1464
fn cnf_int_abs(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
742
1464
    let Expr::Abs(_, expr) = expr else {
743
1452
        return Err(RuleNotApplicable);
744
    };
745

            
746
12
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
747
6
        return Err(RuleNotApplicable);
748
    };
749

            
750
6
    let range = (
751
6
        cmp::max(0, cmp::max(*min, -*max)),
752
6
        cmp::max(min.abs(), max.abs()),
753
6
    );
754

            
755
6
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
756
6
    let [bits] = binding.as_slice() else {
757
        return Err(RuleNotApplicable);
758
    };
759

            
760
6
    let mut new_clauses = vec![];
761
6
    let mut new_symbols = symbols.clone();
762

            
763
6
    let mut result = vec![];
764

            
765
    // How does this handle negatives edge cases: -(-8) = 8, an extra bit is needed
766

            
767
    // invert bits
768
30
    for bit in bits {
769
30
        result.push(tseytin_not(bit.clone(), &mut new_clauses, &mut new_symbols));
770
30
    }
771

            
772
6
    let bit_count = result.len();
773

            
774
    // add one
775
6
    result = tseytin_add_two_power(&result, 0, bit_count, &mut new_clauses, &mut new_symbols);
776

            
777
30
    for i in 0..bit_count {
778
30
        result[i] = tseytin_mux(
779
30
            bits[bit_count - 1].clone(),
780
30
            bits[i].clone(),
781
30
            result[i].clone(),
782
30
            &mut new_clauses,
783
30
            &mut new_symbols,
784
30
        )
785
    }
786

            
787
6
    Ok(Reduction::cnf(
788
6
        Expr::SATInt(
789
6
            Metadata::new(),
790
6
            SATIntEncoding::Log,
791
6
            Moo::new(into_matrix_expr!(result)),
792
6
            range,
793
6
        ),
794
6
        new_clauses,
795
6
        new_symbols,
796
6
    ))
797
1464
}
798

            
799
/// Converts SafeDiv of SATInts to a single SATInt
800
///
801
/// ```text
802
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
803
///
804
/// ```
805
#[register_rule(("SAT_Log", 4100))]
806
1464
fn cnf_int_safediv(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
807
    // Using "Restoring division" algorithm
808
    // https://en.wikipedia.org/wiki/Division_algorithm#Restoring_division
809
1464
    let Expr::SafeDiv(_, numer, denom) = expr else {
810
1464
        return Err(RuleNotApplicable);
811
    };
812

            
813
    let Expr::SATInt(_, _, _, (numer_min, numer_max)) = numer.as_ref() else {
814
        return Err(RuleNotApplicable);
815
    };
816

            
817
    let Expr::SATInt(_, _, _, (denom_min, denom_max)) = denom.as_ref() else {
818
        return Err(RuleNotApplicable);
819
    };
820

            
821
    let candidates = [
822
        numer_min / denom_min,
823
        numer_min / denom_max,
824
        numer_max / denom_min,
825
        numer_max / denom_max,
826
    ];
827

            
828
    let min = *candidates.iter().min().unwrap();
829
    let max = *candidates.iter().max().unwrap();
830

            
831
    let binding =
832
        validate_log_int_operands(vec![numer.as_ref().clone(), denom.as_ref().clone()], None)?;
833
    let [numer_bits, denom_bits] = binding.as_slice() else {
834
        return Err(RuleNotApplicable);
835
    };
836

            
837
    let bit_count = numer_bits.len();
838

            
839
    // TODO: Separate into division/mod function
840

            
841
    let mut new_symbols = symbols.clone();
842
    let mut new_clauses = vec![];
843
    let mut quotient = vec![false.into(); bit_count];
844

            
845
    let minus_numer = tseytin_negate(
846
        &numer_bits.clone(),
847
        bit_count,
848
        &mut new_clauses,
849
        &mut new_symbols,
850
    );
851
    let minus_denom = tseytin_negate(
852
        &denom_bits.clone(),
853
        bit_count,
854
        &mut new_clauses,
855
        &mut new_symbols,
856
    );
857

            
858
    let sign_bit = tseytin_xor(
859
        numer_bits[bit_count - 1].clone(),
860
        denom_bits[bit_count - 1].clone(),
861
        &mut new_clauses,
862
        &mut new_symbols,
863
    );
864

            
865
    let numer_bits = tseytin_select_array(
866
        numer_bits[bit_count - 1].clone(),
867
        &numer_bits.clone(),
868
        &minus_numer,
869
        &mut new_clauses,
870
        &mut new_symbols,
871
    );
872
    let denom_bits = tseytin_select_array(
873
        denom_bits[bit_count - 1].clone(),
874
        &denom_bits.clone(),
875
        &minus_denom,
876
        &mut new_clauses,
877
        &mut new_symbols,
878
    );
879

            
880
    let mut r = numer_bits;
881
    r.extend(std::iter::repeat_n(r[bit_count - 1].clone(), bit_count));
882
    let mut d = std::iter::repeat_n(false.into(), bit_count).collect_vec();
883
    d.extend(denom_bits);
884

            
885
    let minus_d = tseytin_negate(
886
        &d.clone(),
887
        2 * bit_count,
888
        &mut new_clauses,
889
        &mut new_symbols,
890
    );
891
    let mut rminusd;
892

            
893
    for i in (0..bit_count).rev() {
894
        // r << 1
895
        for j in (1..bit_count * 2).rev() {
896
            r[j] = r[j - 1].clone();
897
        }
898
        r[0] = false.into();
899

            
900
        rminusd = tseytin_int_adder(
901
            &r.clone(),
902
            &minus_d.clone(),
903
            2 * bit_count,
904
            &mut new_clauses,
905
            &mut new_symbols,
906
        );
907

            
908
        // TODO: For mod don't calculate on final iter
909
        quotient[i] = tseytin_not(
910
            // q[i] = inverse of sign bit - 1 if positive, 0 if negative
911
            rminusd[2 * bit_count - 1].clone(),
912
            &mut new_clauses,
913
            &mut new_symbols,
914
        );
915

            
916
        // TODO: For div don't calculate on final iter
917
        for j in 0..(2 * bit_count) {
918
            r[j] = tseytin_mux(
919
                quotient[i].clone(),
920
                r[j].clone(),       // use r if negative
921
                rminusd[j].clone(), // use r-d if positive
922
                &mut new_clauses,
923
                &mut new_symbols,
924
            );
925
        }
926
    }
927

            
928
    let minus_quotient = tseytin_negate(
929
        &quotient.clone(),
930
        bit_count,
931
        &mut new_clauses,
932
        &mut new_symbols,
933
    );
934

            
935
    let out = tseytin_select_array(
936
        sign_bit,
937
        &quotient,
938
        &minus_quotient,
939
        &mut new_clauses,
940
        &mut new_symbols,
941
    );
942

            
943
    Ok(Reduction::cnf(
944
        Expr::SATInt(
945
            Metadata::new(),
946
            SATIntEncoding::Log,
947
            Moo::new(into_matrix_expr!(out)),
948
            (min, max),
949
        ),
950
        new_clauses,
951
        new_symbols,
952
    ))
953
1464
}
954

            
955
/*
956
/// Converts SafeMod of SATInts to a single SATInt
957
///
958
/// ```text
959
/// SafeMod(SATInt(a), SATInt(b)) ~> SATInt(c)
960
///
961
/// ```
962
#[register_rule(("SAT_Log", 4100))]
963
fn cnf_int_safemod(expr: &Expr, _: &SymbolTable) -> ApplicationResult {}
964

            
965
/// Converts SafePow of SATInts to a single SATInt
966
///
967
/// ```text
968
/// SafePow(SATInt(a), SATInt(b)) ~> SATInt(c)
969
///
970
/// ```
971
#[register_rule(("SAT", 4100))]
972
fn cnf_int_safepow(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
973
    // use 'Exponentiation by squaring'
974
}
975
*/