1
use conjure_cp::ast::Expression as Expr;
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
5
};
6

            
7
use conjure_cp::ast::AbstractLiteral::Matrix;
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use itertools::Itertools;
13

            
14
use super::boolean::{
15
    tseytin_and, tseytin_iff, tseytin_imply, tseytin_mux, tseytin_not, tseytin_or, tseytin_xor,
16
};
17
use super::integer_repr::{bit_magnitude, match_bits_length, validate_log_int_operands};
18

            
19
use conjure_cp::ast::CnfClause;
20

            
21
use std::cmp;
22

            
23
/// Converts an inequality expression between two SATInts to a boolean expression in cnf.
24
///
25
/// ```text
26
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
27
///
28
/// ```
29
#[register_rule(("SAT_Log", 4100))]
30
9
fn cnf_int_ineq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
31
9
    let (lhs, rhs, strict) = match expr {
32
        Expr::Lt(_, x, y) => (y, x, true),
33
        Expr::Gt(_, x, y) => (x, y, true),
34
        Expr::Leq(_, x, y) => (y, x, false),
35
        Expr::Geq(_, x, y) => (x, y, false),
36
9
        _ => return Err(RuleNotApplicable),
37
    };
38

            
39
    let binding =
40
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
41
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
42
        return Err(RuleNotApplicable);
43
    };
44

            
45
    let mut new_symbols = symbols.clone();
46
    let mut new_clauses = vec![];
47

            
48
    let output = inequality_boolean(
49
        lhs_bits.clone(),
50
        rhs_bits.clone(),
51
        strict,
52
        &mut new_clauses,
53
        &mut new_symbols,
54
    );
55
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
56
9
}
57

            
58
/// Converts a = expression between two SATInts to a boolean expression in cnf
59
///
60
/// ```text
61
/// SATInt(a) = SATInt(b) ~> Bool
62
///
63
/// ```
64
#[register_rule(("SAT_Log", 9100))]
65
9
fn cnf_int_eq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
66
9
    let Expr::Eq(_, lhs, rhs) = expr else {
67
9
        return Err(RuleNotApplicable);
68
    };
69

            
70
    let binding =
71
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
72
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
73
        return Err(RuleNotApplicable);
74
    };
75

            
76
    let bit_count = lhs_bits.len();
77

            
78
    let mut output = true.into();
79
    let mut new_symbols = symbols.clone();
80
    let mut new_clauses = vec![];
81
    let mut comparison;
82

            
83
    for i in 0..bit_count {
84
        comparison = tseytin_iff(
85
            lhs_bits[i].clone(),
86
            rhs_bits[i].clone(),
87
            &mut new_clauses,
88
            &mut new_symbols,
89
        );
90
        output = tseytin_and(
91
            &vec![comparison, output],
92
            &mut new_clauses,
93
            &mut new_symbols,
94
        );
95
    }
96

            
97
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
98
9
}
99

            
100
/// Converts a != expression between two SATInts to a boolean expression in cnf
101
///
102
/// ```text
103
/// SATInt(a) != SATInt(b) ~> Bool
104
///
105
/// ```
106
#[register_rule(("SAT_Log", 4100))]
107
9
fn cnf_int_neq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
108
9
    let Expr::Neq(_, lhs, rhs) = expr else {
109
9
        return Err(RuleNotApplicable);
110
    };
111

            
112
    let binding =
113
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
114
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
115
        return Err(RuleNotApplicable);
116
    };
117

            
118
    let bit_count = lhs_bits.len();
119

            
120
    let mut output = false.into();
121
    let mut new_symbols = symbols.clone();
122
    let mut new_clauses = vec![];
123
    let mut comparison;
124

            
125
    for i in 0..bit_count {
126
        comparison = tseytin_xor(
127
            lhs_bits[i].clone(),
128
            rhs_bits[i].clone(),
129
            &mut new_clauses,
130
            &mut new_symbols,
131
        );
132
        output = tseytin_or(
133
            &vec![comparison, output],
134
            &mut new_clauses,
135
            &mut new_symbols,
136
        );
137
    }
138

            
139
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
140
9
}
141

            
142
// Creates a boolean expression for > or >=
143
// a > b or a >= b
144
// This can also be used for < and <= by reversing the order of the inputs
145
// Returns result, new symbol table, new clauses
146
fn inequality_boolean(
147
    a: Vec<Expr>,
148
    b: Vec<Expr>,
149
    strict: bool,
150
    clauses: &mut Vec<CnfClause>,
151
    symbols: &mut SymbolTable,
152
) -> Expr {
153
    let mut notb;
154
    let mut output;
155

            
156
    if strict {
157
        notb = tseytin_not(b[0].clone(), clauses, symbols);
158
        output = tseytin_and(&vec![a[0].clone(), notb], clauses, symbols);
159
    } else {
160
        output = tseytin_imply(b[0].clone(), a[0].clone(), clauses, symbols);
161
    }
162

            
163
    //TODO: There may be room for simplification, and constant optimization
164

            
165
    let bit_count = a.len();
166

            
167
    let mut lhs;
168
    let mut rhs;
169
    let mut iff;
170
    for n in 1..(bit_count - 1) {
171
        notb = tseytin_not(b[n].clone(), clauses, symbols);
172
        lhs = tseytin_and(&vec![a[n].clone(), notb.clone()], clauses, symbols);
173
        iff = tseytin_iff(a[n].clone(), b[n].clone(), clauses, symbols);
174
        rhs = tseytin_and(&vec![iff.clone(), output.clone()], clauses, symbols);
175
        output = tseytin_or(&vec![lhs.clone(), rhs.clone()], clauses, symbols);
176
    }
177

            
178
    // final bool is the sign bit and should be handled inversely
179
    let nota = tseytin_not(a[bit_count - 1].clone(), clauses, symbols);
180
    lhs = tseytin_and(&vec![nota, b[bit_count - 1].clone()], clauses, symbols);
181
    iff = tseytin_iff(
182
        a[bit_count - 1].clone(),
183
        b[bit_count - 1].clone(),
184
        clauses,
185
        symbols,
186
    );
187
    rhs = tseytin_and(&vec![iff, output.clone()], clauses, symbols);
188
    output = tseytin_or(&vec![lhs, rhs], clauses, symbols);
189

            
190
    output
191
}
192

            
193
/// Converts sum of SATInts to a single SATInt
194
///
195
/// ```text
196
/// Sum(SATInt(a), SATInt(b), ...) ~> SATInt(c)
197
///
198
/// ```
199
#[register_rule(("SAT_Log", 4100))]
200
9
fn cnf_int_sum(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
201
9
    let Expr::Sum(_, exprs) = expr else {
202
9
        return Err(RuleNotApplicable);
203
    };
204

            
205
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
206
        return Err(RuleNotApplicable);
207
    };
208

            
209
    let ranges: Result<Vec<_>, _> = exprs_list
210
        .iter()
211
        .map(|e| match e {
212
            Expr::SATInt(_, _, _, x) => Ok(x),
213
            _ => Err(RuleNotApplicable),
214
        })
215
        .collect();
216

            
217
    let ranges = ranges?;
218

            
219
    let min = ranges.iter().map(|(a, _)| *a).sum();
220
    let max = ranges.iter().map(|(_, a)| *a).sum();
221

            
222
    let output_size = cmp::max(bit_magnitude(min), bit_magnitude(max));
223

            
224
    // Check operands are valid log ints
225
    let mut exprs_bits =
226
        validate_log_int_operands(exprs_list.clone(), Some(output_size.try_into().unwrap()))?;
227

            
228
    let mut new_symbols = symbols.clone();
229
    let mut values;
230
    let mut new_clauses = vec![];
231

            
232
    while exprs_bits.len() > 1 {
233
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
234
        let mut iter = exprs_bits.into_iter();
235

            
236
        while let Some(a) = iter.next() {
237
            if let Some(b) = iter.next() {
238
                values = tseytin_int_adder(&a, &b, output_size, &mut new_clauses, &mut new_symbols);
239
                next.push(values);
240
            } else {
241
                next.push(a);
242
            }
243
        }
244

            
245
        exprs_bits = next;
246
    }
247

            
248
    let result = exprs_bits.pop().unwrap();
249

            
250
    Ok(Reduction::cnf(
251
        Expr::SATInt(
252
            Metadata::new(),
253
            SATIntEncoding::Log,
254
            Moo::new(into_matrix_expr!(result)),
255
            (min, max),
256
        ),
257
        new_clauses,
258
        new_symbols,
259
    ))
260
9
}
261

            
262
/// Returns result, new symbol table, new clauses
263
/// This function expects bits to match the lengths of x and y
264
fn tseytin_int_adder(
265
    x: &[Expr],
266
    y: &[Expr],
267
    bits: usize,
268
    clauses: &mut Vec<CnfClause>,
269
    symbols: &mut SymbolTable,
270
) -> Vec<Expr> {
271
    //TODO: Optimizing for constants
272
    let (mut result, mut carry) = tseytin_half_adder(x[0].clone(), y[0].clone(), clauses, symbols);
273

            
274
    let mut output = vec![result];
275
    for i in 1..bits {
276
        (result, carry) =
277
            tseytin_full_adder(x[i].clone(), y[i].clone(), carry.clone(), clauses, symbols);
278
        output.push(result);
279
    }
280

            
281
    output
282
}
283

            
284
// Returns: result, carry, new symbol table, new clauses
285
fn tseytin_full_adder(
286
    a: Expr,
287
    b: Expr,
288
    carry: Expr,
289
    clauses: &mut Vec<CnfClause>,
290
    symbols: &mut SymbolTable,
291
) -> (Expr, Expr) {
292
    let axorb = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
293
    let result = tseytin_xor(axorb.clone(), carry.clone(), clauses, symbols);
294
    let aandb = tseytin_and(&vec![a, b], clauses, symbols);
295
    let carryandaxorb = tseytin_and(&vec![carry, axorb], clauses, symbols);
296
    let carryout = tseytin_or(&vec![aandb, carryandaxorb], clauses, symbols);
297

            
298
    (result, carryout)
299
}
300

            
301
fn tseytin_half_adder(
302
    a: Expr,
303
    b: Expr,
304
    clauses: &mut Vec<CnfClause>,
305
    symbols: &mut SymbolTable,
306
) -> (Expr, Expr) {
307
    let result = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
308
    let carry = tseytin_and(&vec![a, b], clauses, symbols);
309

            
310
    (result, carry)
311
}
312

            
313
/// this function is for specifically adding a power of two constant to a cnf int
314
fn tseytin_add_two_power(
315
    expr: &[Expr],
316
    exponent: usize,
317
    bits: usize,
318
    clauses: &mut Vec<CnfClause>,
319
    symbols: &mut SymbolTable,
320
) -> Vec<Expr> {
321
    let mut result = vec![];
322
    let mut product = expr[exponent].clone();
323

            
324
    for item in expr.iter().take(exponent) {
325
        result.push(item.clone());
326
    }
327

            
328
    result.push(tseytin_not(expr[exponent].clone(), clauses, symbols));
329

            
330
    for item in expr.iter().take(bits).skip(exponent + 1) {
331
        result.push(tseytin_xor(product.clone(), item.clone(), clauses, symbols));
332
        product = tseytin_and(&vec![product, item.clone()], clauses, symbols);
333
    }
334

            
335
    result
336
}
337

            
338
// Returns result, new symbol table, new clauses
339
fn cnf_shift_add_multiply(
340
    x: &[Expr],
341
    y: &[Expr],
342
    bits: usize,
343
    clauses: &mut Vec<CnfClause>,
344
    symbols: &mut SymbolTable,
345
) -> Vec<Expr> {
346
    let mut x = x.to_owned();
347
    let mut y = y.to_owned();
348

            
349
    //TODO Optimizing for constants
350
    //TODO Optimize addition for i left shifted values - skip first i bits
351

            
352
    // extend sign bits of operands to 2*`bits`
353
    x.extend(std::iter::repeat_n(x[bits - 1].clone(), bits));
354
    y.extend(std::iter::repeat_n(y[bits - 1].clone(), bits));
355

            
356
    let mut s: Vec<Expr> = vec![];
357
    let mut x_0andy_i;
358

            
359
    for bit in &y {
360
        x_0andy_i = tseytin_and(&vec![x[0].clone(), bit.clone()], clauses, symbols);
361
        s.push(x_0andy_i);
362
    }
363

            
364
    let mut sum;
365
    let mut if_true;
366
    let mut not_x_n;
367
    let mut if_false;
368

            
369
    for item in x.iter().take(bits).skip(1) {
370
        // y << 1
371
        for i in (1..bits * 2).rev() {
372
            y[i] = y[i - 1].clone();
373
        }
374
        y[0] = false.into();
375

            
376
        // TODO switch to multiplexer
377
        // TODO Add negatives support once MUX is added
378
        sum = tseytin_int_adder(&s, &y, bits * 2, clauses, symbols);
379
        not_x_n = tseytin_not(item.clone(), clauses, symbols);
380

            
381
        for i in 0..(bits * 2) {
382
            if_true = tseytin_and(&vec![item.clone(), sum[i].clone()], clauses, symbols);
383
            if_false = tseytin_and(&vec![not_x_n.clone(), s[i].clone()], clauses, symbols);
384
            s[i] = tseytin_or(&vec![if_true.clone(), if_false.clone()], clauses, symbols);
385
        }
386
    }
387

            
388
    s
389
}
390

            
391
/// Bounds the product of integers drawn from each of the given inclusive ranges.
///
/// Returns `(1, 1)` for no ranges, since the empty product is 1. Otherwise it folds left
/// to right, keeping the smallest and largest of the four corner products at each step.
fn product_of_ranges(ranges: Vec<&(i32, i32)>) -> (i32, i32) {
    let mut remaining = ranges.into_iter();
    let Some(&first) = remaining.next() else {
        return (1, 1); // product of zero numbers = 1
    };

    remaining.fold(first, |(lo, hi), &(a, b)| {
        let corners = [lo * a, lo * b, hi * a, hi * b];
        let smallest = corners.iter().copied().min().unwrap();
        let largest = corners.iter().copied().max().unwrap();
        (smallest, largest)
    })
}
406

            
407
/// Converts product of SATInts to a single SATInt
408
///
409
/// ```text
410
/// Product(SATInt(a), SATInt(b), ...) ~> SATInt(c)
411
///
412
/// ```
413
#[register_rule(("SAT_Log", 9000))]
414
9
fn cnf_int_product(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
415
9
    let Expr::Product(_, exprs) = expr else {
416
9
        return Err(RuleNotApplicable);
417
    };
418

            
419
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
420
        return Err(RuleNotApplicable);
421
    };
422

            
423
    let ranges: Result<Vec<_>, _> = exprs_list
424
        .iter()
425
        .map(|e| match e {
426
            Expr::SATInt(_, _, _, x) => Ok(x),
427
            _ => Err(RuleNotApplicable),
428
        })
429
        .collect();
430

            
431
    let ranges = ranges?; // propagate error if any
432

            
433
    let (min, max) = product_of_ranges(ranges.clone());
434

            
435
    let exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
436

            
437
    let mut new_symbols = symbols.clone();
438
    let mut new_clauses = vec![];
439

            
440
    let (result, _) = exprs_bits
441
        .iter()
442
        .cloned()
443
        .zip(ranges.into_iter().copied())
444
        .reduce(|lhs, rhs| {
445
            // Make both bit vectors the same length
446
            let (lhs_bits, rhs_bits) = match_bits_length(lhs.0.clone(), rhs.0.clone());
447

            
448
            // Multiply operands
449
            let mut values = cnf_shift_add_multiply(
450
                &lhs_bits,
451
                &rhs_bits,
452
                lhs_bits.len(),
453
                &mut new_clauses,
454
                &mut new_symbols,
455
            );
456

            
457
            // Determine new range of result
458
            let (mut cum_min, mut cum_max) = lhs.1;
459
            let candidates = [
460
                cum_min * rhs.1.0,
461
                cum_min * rhs.1.1,
462
                cum_max * rhs.1.0,
463
                cum_max * rhs.1.1,
464
            ];
465
            cum_min = *candidates.iter().min().unwrap();
466
            cum_max = *candidates.iter().max().unwrap();
467

            
468
            let new_bit_count = bit_magnitude(cum_min).max(bit_magnitude(cum_max));
469
            values.truncate(new_bit_count);
470

            
471
            (values, (cum_min, cum_max))
472
        })
473
        .unwrap();
474

            
475
    Ok(Reduction::cnf(
476
        Expr::SATInt(
477
            Metadata::new(),
478
            SATIntEncoding::Log,
479
            Moo::new(into_matrix_expr!(result)),
480
            (min, max),
481
        ),
482
        new_clauses,
483
        new_symbols,
484
    ))
485
9
}
486

            
487
/// Converts negation of a SATInt to a SATInt
488
///
489
/// ```text
490
/// -SATInt(a) ~> SATInt(b)
491
///
492
/// ```
493
#[register_rule(("SAT_Log", 4100))]
494
9
fn cnf_int_neg(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
495
9
    let Expr::Neg(_, expr) = expr else {
496
9
        return Err(RuleNotApplicable);
497
    };
498

            
499
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
500
        return Err(RuleNotApplicable);
501
    };
502

            
503
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
504
    let [bits] = binding.as_slice() else {
505
        return Err(RuleNotApplicable);
506
    };
507

            
508
    let mut new_clauses = vec![];
509
    let mut new_symbols = symbols.clone();
510

            
511
    let result = tseytin_negate(bits, bits.len(), &mut new_clauses, &mut new_symbols);
512

            
513
    Ok(Reduction::cnf(
514
        Expr::SATInt(
515
            Metadata::new(),
516
            SATIntEncoding::Log,
517
            Moo::new(into_matrix_expr!(result)),
518
            (-max, -min),
519
        ),
520
        new_clauses,
521
        new_symbols,
522
    ))
523
9
}
524

            
525
fn tseytin_negate(
526
    expr: &Vec<Expr>,
527
    bits: usize,
528
    clauses: &mut Vec<CnfClause>,
529
    symbols: &mut SymbolTable,
530
) -> Vec<Expr> {
531
    let mut result = vec![];
532
    // invert bits
533
    for bit in expr {
534
        result.push(tseytin_not(bit.clone(), clauses, symbols));
535
    }
536

            
537
    // add one
538
    result = tseytin_add_two_power(&result, 0, bits, clauses, symbols);
539

            
540
    result
541
}
542

            
543
/// Converts min of SATInts to a single SATInt
544
///
545
/// ```text
546
/// Min(SATInt(a), SATInt(b), ...) ~> SATInt(c)
547
///
548
/// ```
549
#[register_rule(("SAT_Log", 4100))]
550
9
fn cnf_int_min(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
551
9
    let Expr::Min(_, exprs) = expr else {
552
9
        return Err(RuleNotApplicable);
553
    };
554

            
555
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
556
        return Err(RuleNotApplicable);
557
    };
558

            
559
    let ranges: Result<Vec<_>, _> = exprs_list
560
        .iter()
561
        .map(|e| match e {
562
            Expr::SATInt(_, _, _, x) => Ok(x),
563
            _ => Err(RuleNotApplicable),
564
        })
565
        .collect();
566

            
567
    let ranges = ranges?; // propagate error if any
568

            
569
    // Is this optimal?
570
    let min = ranges.iter().map(|(a, _)| *a).min().unwrap();
571
    let max = ranges.iter().map(|(_, b)| *b).min().unwrap();
572

            
573
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
574

            
575
    let mut new_symbols = symbols.clone();
576
    let mut values;
577
    let mut new_clauses = vec![];
578

            
579
    while exprs_bits.len() > 1 {
580
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
581
        let mut iter = exprs_bits.into_iter();
582

            
583
        while let Some(a) = iter.next() {
584
            if let Some(b) = iter.next() {
585
                values = tseytin_binary_min_max(&a, &b, true, &mut new_clauses, &mut new_symbols);
586
                next.push(values);
587
            } else {
588
                next.push(a);
589
            }
590
        }
591

            
592
        exprs_bits = next;
593
    }
594

            
595
    let result = exprs_bits.pop().unwrap();
596

            
597
    Ok(Reduction::cnf(
598
        Expr::SATInt(
599
            Metadata::new(),
600
            SATIntEncoding::Log,
601
            Moo::new(into_matrix_expr!(result)),
602
            (min, max),
603
        ),
604
        new_clauses,
605
        new_symbols,
606
    ))
607
9
}
608

            
609
fn tseytin_binary_min_max(
610
    x: &[Expr],
611
    y: &[Expr],
612
    min: bool,
613
    clauses: &mut Vec<CnfClause>,
614
    symbols: &mut SymbolTable,
615
) -> Vec<Expr> {
616
    let mut out = vec![];
617

            
618
    let bit_count = x.len();
619

            
620
    for i in 0..bit_count {
621
        out.push(tseytin_xor(x[i].clone(), y[i].clone(), clauses, symbols))
622
    }
623

            
624
    // TODO: compare generated expression to using MUX
625

            
626
    let mask = if min {
627
        // mask is 1 if x > y
628
        inequality_boolean(x.to_owned(), y.to_owned(), true, clauses, symbols)
629
    } else {
630
        // flip the args if getting maximum x < y -> 1
631
        inequality_boolean(y.to_owned(), x.to_owned(), true, clauses, symbols)
632
    };
633

            
634
    for item in out.iter_mut().take(bit_count) {
635
        *item = tseytin_and(&vec![item.clone(), mask.clone()], clauses, symbols);
636
    }
637

            
638
    for i in 0..bit_count {
639
        out[i] = tseytin_xor(x[i].clone(), out[i].clone(), clauses, symbols);
640
    }
641

            
642
    out
643
}
644

            
645
/// Converts max of SATInts to a single SATInt
646
///
647
/// ```text
648
/// Max(SATInt(a), SATInt(b), ...) ~> SATInt(c)
649
///
650
/// ```
651
#[register_rule(("SAT_Log", 4100))]
652
9
fn cnf_int_max(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
653
9
    let Expr::Max(_, exprs) = expr else {
654
9
        return Err(RuleNotApplicable);
655
    };
656

            
657
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
658
        return Err(RuleNotApplicable);
659
    };
660

            
661
    let ranges: Result<Vec<_>, _> = exprs_list
662
        .iter()
663
        .map(|e| match e {
664
            Expr::SATInt(_, _, _, x) => Ok(x),
665
            _ => Err(RuleNotApplicable),
666
        })
667
        .collect();
668

            
669
    let ranges = ranges?; // propagate error if any
670

            
671
    // Is this optimal?
672
    let min = ranges.iter().map(|(a, _)| *a).max().unwrap();
673
    let max = ranges.iter().map(|(_, b)| *b).max().unwrap();
674

            
675
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
676

            
677
    let mut new_symbols = symbols.clone();
678
    let mut values;
679
    let mut new_clauses = vec![];
680

            
681
    while exprs_bits.len() > 1 {
682
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
683
        let mut iter = exprs_bits.into_iter();
684

            
685
        while let Some(a) = iter.next() {
686
            if let Some(b) = iter.next() {
687
                values = tseytin_binary_min_max(&a, &b, false, &mut new_clauses, &mut new_symbols);
688
                next.push(values);
689
            } else {
690
                next.push(a);
691
            }
692
        }
693

            
694
        exprs_bits = next;
695
    }
696

            
697
    let result = exprs_bits.pop().unwrap();
698

            
699
    Ok(Reduction::cnf(
700
        Expr::SATInt(
701
            Metadata::new(),
702
            SATIntEncoding::Log,
703
            Moo::new(into_matrix_expr!(result)),
704
            (min, max),
705
        ),
706
        new_clauses,
707
        new_symbols,
708
    ))
709
9
}
710

            
711
/// Converts Abs of a SATInt to a SATInt
712
///
713
/// ```text
714
/// |SATInt(a)| ~> SATInt(b)
715
///
716
/// ```
717
#[register_rule(("SAT_Log", 4100))]
718
9
fn cnf_int_abs(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
719
9
    let Expr::Abs(_, expr) = expr else {
720
9
        return Err(RuleNotApplicable);
721
    };
722

            
723
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
724
        return Err(RuleNotApplicable);
725
    };
726

            
727
    let range = (
728
        cmp::max(0, cmp::max(*min, -*max)),
729
        cmp::max(min.abs(), max.abs()),
730
    );
731

            
732
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
733
    let [bits] = binding.as_slice() else {
734
        return Err(RuleNotApplicable);
735
    };
736

            
737
    let mut new_clauses = vec![];
738
    let mut new_symbols = symbols.clone();
739

            
740
    let mut result = vec![];
741

            
742
    // How does this handle negatives edge cases: -(-8) = 8, an extra bit is needed
743

            
744
    // invert bits
745
    for bit in bits {
746
        result.push(tseytin_not(bit.clone(), &mut new_clauses, &mut new_symbols));
747
    }
748

            
749
    let bit_count = result.len();
750

            
751
    // add one
752
    result = tseytin_add_two_power(&result, 0, bit_count, &mut new_clauses, &mut new_symbols);
753

            
754
    for i in 0..bit_count {
755
        result[i] = tseytin_mux(
756
            bits[bit_count - 1].clone(),
757
            bits[i].clone(),
758
            result[i].clone(),
759
            &mut new_clauses,
760
            &mut new_symbols,
761
        )
762
    }
763

            
764
    Ok(Reduction::cnf(
765
        Expr::SATInt(
766
            Metadata::new(),
767
            SATIntEncoding::Log,
768
            Moo::new(into_matrix_expr!(result)),
769
            range,
770
        ),
771
        new_clauses,
772
        new_symbols,
773
    ))
774
9
}
775

            
776
/// Converts SafeDiv of SATInts to a single SATInt
777
///
778
/// ```text
779
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
780
///
781
/// ```
782
#[register_rule(("SAT_Log", 4100))]
783
9
fn cnf_int_safediv(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
784
    // Using "Restoring division" algorithm
785
    // https://en.wikipedia.org/wiki/Division_algorithm#Restoring_division
786
9
    let Expr::SafeDiv(_, numer, denom) = expr else {
787
9
        return Err(RuleNotApplicable);
788
    };
789

            
790
    let Expr::SATInt(_, _, _, (numer_min, numer_max)) = numer.as_ref() else {
791
        return Err(RuleNotApplicable);
792
    };
793

            
794
    let Expr::SATInt(_, _, _, (denom_min, denom_max)) = denom.as_ref() else {
795
        return Err(RuleNotApplicable);
796
    };
797

            
798
    let candidates = [
799
        numer_min / denom_min,
800
        numer_min / denom_max,
801
        numer_max / denom_min,
802
        numer_max / denom_max,
803
    ];
804

            
805
    let min = *candidates.iter().min().unwrap();
806
    let max = *candidates.iter().max().unwrap();
807

            
808
    let binding =
809
        validate_log_int_operands(vec![numer.as_ref().clone(), denom.as_ref().clone()], None)?;
810
    let [numer_bits, denom_bits] = binding.as_slice() else {
811
        return Err(RuleNotApplicable);
812
    };
813

            
814
    let bit_count = numer_bits.len();
815

            
816
    // TODO: Separate into division/mod function
817
    // TODO: Support negatives
818

            
819
    let mut new_symbols = symbols.clone();
820
    let mut new_clauses = vec![];
821
    let mut quotient = vec![false.into(); bit_count];
822

            
823
    let mut r = numer_bits.clone();
824
    r.extend(std::iter::repeat_n(r[bit_count - 1].clone(), bit_count));
825
    let mut d = std::iter::repeat_n(false.into(), bit_count).collect_vec();
826
    d.extend(denom_bits.clone());
827

            
828
    let minus_d = tseytin_negate(
829
        &d.clone(),
830
        2 * bit_count,
831
        &mut new_clauses,
832
        &mut new_symbols,
833
    );
834
    let mut rminusd;
835

            
836
    for i in (0..bit_count).rev() {
837
        // r << 1
838
        for j in (1..bit_count * 2).rev() {
839
            r[j] = r[j - 1].clone();
840
        }
841
        r[0] = false.into();
842

            
843
        rminusd = tseytin_int_adder(
844
            &r.clone(),
845
            &minus_d.clone(),
846
            2 * bit_count,
847
            &mut new_clauses,
848
            &mut new_symbols,
849
        );
850

            
851
        // TODO: For mod don't calculate on final iter
852
        quotient[i] = tseytin_not(
853
            // q[i] = inverse of sign bit - 1 if positive, 0 if negative
854
            rminusd[2 * bit_count - 1].clone(),
855
            &mut new_clauses,
856
            &mut new_symbols,
857
        );
858

            
859
        // TODO: For div don't calculate on final iter
860
        for j in 0..(2 * bit_count) {
861
            r[j] = tseytin_mux(
862
                quotient[i].clone(),
863
                r[j].clone(),       // use r if negative
864
                rminusd[j].clone(), // use r-d if positive
865
                &mut new_clauses,
866
                &mut new_symbols,
867
            );
868
        }
869
    }
870

            
871
    Ok(Reduction::cnf(
872
        Expr::SATInt(
873
            Metadata::new(),
874
            SATIntEncoding::Log,
875
            Moo::new(into_matrix_expr!(quotient)),
876
            (min, max),
877
        ),
878
        new_clauses,
879
        new_symbols,
880
    ))
881
9
}
882

            
883
/*
884
/// Converts SafeMod of SATInts to a single SATInt
885
///
886
/// ```text
887
/// SafeMod(SATInt(a), SATInt(b)) ~> SATInt(c)
888
///
889
/// ```
890
#[register_rule(("SAT_Log", 4100))]
891
fn cnf_int_safemod(expr: &Expr, _: &SymbolTable) -> ApplicationResult {}
892

            
893
/// Converts SafePow of SATInts to a single SATInt
894
///
895
/// ```text
896
/// SafePow(SATInt(a), SATInt(b)) ~> SATInt(c)
897
///
898
/// ```
899
#[register_rule(("SAT", 4100))]
900
fn cnf_int_safepow(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
901
    // use 'Exponentiation by squaring'
902
}
903
*/