1
use conjure_cp::ast::Expression as Expr;
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
5
};
6

            
7
use conjure_cp::ast::AbstractLiteral::Matrix;
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use itertools::Itertools;
13

            
14
use super::boolean::{
15
    tseytin_and, tseytin_iff, tseytin_imply, tseytin_mux, tseytin_not, tseytin_or, tseytin_xor,
16
};
17
use super::integer_repr::{bit_magnitude, match_bits_length, validate_log_int_operands};
18

            
19
use conjure_cp::ast::CnfClause;
20

            
21
use std::cmp;
22

            
23
/// Converts an inequality expression between two SATInts to a boolean expression in cnf.
24
///
25
/// ```text
26
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
27
///
28
/// ```
29
#[register_rule(("SAT_Log", 4100))]
30
84
fn cnf_int_ineq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
31
84
    let (lhs, rhs, strict) = match expr {
32
3
        Expr::Lt(_, x, y) => (y, x, true),
33
3
        Expr::Gt(_, x, y) => (x, y, true),
34
12
        Expr::Leq(_, x, y) => (y, x, false),
35
12
        Expr::Geq(_, x, y) => (x, y, false),
36
54
        _ => return Err(RuleNotApplicable),
37
    };
38

            
39
30
    let binding =
40
30
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
41
30
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
42
        return Err(RuleNotApplicable);
43
    };
44

            
45
30
    let mut new_symbols = symbols.clone();
46
30
    let mut new_clauses = vec![];
47

            
48
30
    let output = inequality_boolean(
49
30
        lhs_bits.clone(),
50
30
        rhs_bits.clone(),
51
30
        strict,
52
30
        &mut new_clauses,
53
30
        &mut new_symbols,
54
    );
55
30
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
56
84
}
57

            
58
/// Converts a = expression between two SATInts to a boolean expression in cnf
59
///
60
/// ```text
61
/// SATInt(a) = SATInt(b) ~> Bool
62
///
63
/// ```
64
#[register_rule(("SAT_Log", 9100))]
65
5244
fn cnf_int_eq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
66
5244
    let Expr::Eq(_, lhs, rhs) = expr else {
67
5244
        return Err(RuleNotApplicable);
68
    };
69

            
70
    let binding =
71
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
72
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
73
        return Err(RuleNotApplicable);
74
    };
75

            
76
    let bit_count = lhs_bits.len();
77

            
78
    let mut output = true.into();
79
    let mut new_symbols = symbols.clone();
80
    let mut new_clauses = vec![];
81
    let mut comparison;
82

            
83
    for i in 0..bit_count {
84
        comparison = tseytin_iff(
85
            lhs_bits[i].clone(),
86
            rhs_bits[i].clone(),
87
            &mut new_clauses,
88
            &mut new_symbols,
89
        );
90
        output = tseytin_and(
91
            &vec![comparison, output],
92
            &mut new_clauses,
93
            &mut new_symbols,
94
        );
95
    }
96

            
97
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
98
5244
}
99

            
100
/// Converts a != expression between two SATInts to a boolean expression in cnf
101
///
102
/// ```text
103
/// SATInt(a) != SATInt(b) ~> Bool
104
///
105
/// ```
106
#[register_rule(("SAT_Log", 4100))]
107
84
fn cnf_int_neq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
108
84
    let Expr::Neq(_, lhs, rhs) = expr else {
109
84
        return Err(RuleNotApplicable);
110
    };
111

            
112
    let binding =
113
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
114
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
115
        return Err(RuleNotApplicable);
116
    };
117

            
118
    let bit_count = lhs_bits.len();
119

            
120
    let mut output = false.into();
121
    let mut new_symbols = symbols.clone();
122
    let mut new_clauses = vec![];
123
    let mut comparison;
124

            
125
    for i in 0..bit_count {
126
        comparison = tseytin_xor(
127
            lhs_bits[i].clone(),
128
            rhs_bits[i].clone(),
129
            &mut new_clauses,
130
            &mut new_symbols,
131
        );
132
        output = tseytin_or(
133
            &vec![comparison, output],
134
            &mut new_clauses,
135
            &mut new_symbols,
136
        );
137
    }
138

            
139
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
140
84
}
141

            
142
// Creates a boolean expression for > or >=
143
// a > b or a >= b
144
// This can also be used for < and <= by reversing the order of the inputs
145
// Returns result, new symbol table, new clauses
146
30
fn inequality_boolean(
147
30
    a: Vec<Expr>,
148
30
    b: Vec<Expr>,
149
30
    strict: bool,
150
30
    clauses: &mut Vec<CnfClause>,
151
30
    symbols: &mut SymbolTable,
152
30
) -> Expr {
153
    let mut notb;
154
    let mut output;
155

            
156
30
    if strict {
157
6
        notb = tseytin_not(b[0].clone(), clauses, symbols);
158
6
        output = tseytin_and(&vec![a[0].clone(), notb], clauses, symbols);
159
24
    } else {
160
24
        output = tseytin_imply(b[0].clone(), a[0].clone(), clauses, symbols);
161
24
    }
162

            
163
    //TODO: There may be room for simplification, and constant optimization
164

            
165
30
    let bit_count = a.len();
166

            
167
    let mut lhs;
168
    let mut rhs;
169
    let mut iff;
170
60
    for n in 1..(bit_count - 1) {
171
60
        notb = tseytin_not(b[n].clone(), clauses, symbols);
172
60
        lhs = tseytin_and(&vec![a[n].clone(), notb.clone()], clauses, symbols);
173
60
        iff = tseytin_iff(a[n].clone(), b[n].clone(), clauses, symbols);
174
60
        rhs = tseytin_and(&vec![iff.clone(), output.clone()], clauses, symbols);
175
60
        output = tseytin_or(&vec![lhs.clone(), rhs.clone()], clauses, symbols);
176
60
    }
177

            
178
    // final bool is the sign bit and should be handled inversely
179
30
    let nota = tseytin_not(a[bit_count - 1].clone(), clauses, symbols);
180
30
    lhs = tseytin_and(&vec![nota, b[bit_count - 1].clone()], clauses, symbols);
181
30
    iff = tseytin_iff(
182
30
        a[bit_count - 1].clone(),
183
30
        b[bit_count - 1].clone(),
184
30
        clauses,
185
30
        symbols,
186
    );
187
30
    rhs = tseytin_and(&vec![iff, output.clone()], clauses, symbols);
188
30
    output = tseytin_or(&vec![lhs, rhs], clauses, symbols);
189

            
190
30
    output
191
30
}
192

            
193
/// Converts sum of SATInts to a single SATInt
194
///
195
/// ```text
196
/// Sum(SATInt(a), SATInt(b), ...) ~> SATInt(c)
197
///
198
/// ```
199
#[register_rule(("SAT_Log", 4100))]
200
84
fn cnf_int_sum(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
201
84
    let Expr::Sum(_, exprs) = expr else {
202
84
        return Err(RuleNotApplicable);
203
    };
204

            
205
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
206
        return Err(RuleNotApplicable);
207
    };
208

            
209
    let ranges: Result<Vec<_>, _> = exprs_list
210
        .iter()
211
        .map(|e| match e {
212
            Expr::SATInt(_, _, _, x) => Ok(x),
213
            _ => Err(RuleNotApplicable),
214
        })
215
        .collect();
216

            
217
    let ranges = ranges?;
218

            
219
    let min = ranges.iter().map(|(a, _)| *a).sum();
220
    let max = ranges.iter().map(|(_, a)| *a).sum();
221

            
222
    let output_size = cmp::max(bit_magnitude(min), bit_magnitude(max));
223

            
224
    // Check operands are valid log ints
225
    let mut exprs_bits =
226
        validate_log_int_operands(exprs_list.clone(), Some(output_size.try_into().unwrap()))?;
227

            
228
    let mut new_symbols = symbols.clone();
229
    let mut values;
230
    let mut new_clauses = vec![];
231

            
232
    while exprs_bits.len() > 1 {
233
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
234
        let mut iter = exprs_bits.into_iter();
235

            
236
        while let Some(a) = iter.next() {
237
            if let Some(b) = iter.next() {
238
                values = tseytin_int_adder(&a, &b, output_size, &mut new_clauses, &mut new_symbols);
239
                next.push(values);
240
            } else {
241
                next.push(a);
242
            }
243
        }
244

            
245
        exprs_bits = next;
246
    }
247

            
248
    let result = exprs_bits.pop().unwrap();
249

            
250
    Ok(Reduction::cnf(
251
        Expr::SATInt(
252
            Metadata::new(),
253
            SATIntEncoding::Log,
254
            Moo::new(into_matrix_expr!(result)),
255
            (min, max),
256
        ),
257
        new_clauses,
258
        new_symbols,
259
    ))
260
84
}
261

            
262
/// Returns result, new symbol table, new clauses
263
/// This function expects bits to match the lengths of x and y
264
fn tseytin_int_adder(
265
    x: &[Expr],
266
    y: &[Expr],
267
    bits: usize,
268
    clauses: &mut Vec<CnfClause>,
269
    symbols: &mut SymbolTable,
270
) -> Vec<Expr> {
271
    //TODO: Optimizing for constants
272
    let (mut result, mut carry) = tseytin_half_adder(x[0].clone(), y[0].clone(), clauses, symbols);
273

            
274
    let mut output = vec![result];
275
    for i in 1..bits {
276
        (result, carry) =
277
            tseytin_full_adder(x[i].clone(), y[i].clone(), carry.clone(), clauses, symbols);
278
        output.push(result);
279
    }
280

            
281
    output
282
}
283

            
284
/// This function adds two booleans and a carry boolean using the full-adder logic circuit, it is intended for use in a binary adder.
285
fn tseytin_full_adder(
286
    a: Expr,
287
    b: Expr,
288
    carry: Expr,
289
    clauses: &mut Vec<CnfClause>,
290
    symbols: &mut SymbolTable,
291
) -> (Expr, Expr) {
292
    let axorb = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
293
    let result = tseytin_xor(axorb.clone(), carry.clone(), clauses, symbols);
294
    let aandb = tseytin_and(&vec![a, b], clauses, symbols);
295
    let carryandaxorb = tseytin_and(&vec![carry, axorb], clauses, symbols);
296
    let carryout = tseytin_or(&vec![aandb, carryandaxorb], clauses, symbols);
297

            
298
    (result, carryout)
299
}
300

            
301
/// This function adds two booleans using the half-adder logic circuit, it is intended for use in a binary adder.
302
fn tseytin_half_adder(
303
    a: Expr,
304
    b: Expr,
305
    clauses: &mut Vec<CnfClause>,
306
    symbols: &mut SymbolTable,
307
) -> (Expr, Expr) {
308
    let result = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
309
    let carry = tseytin_and(&vec![a, b], clauses, symbols);
310

            
311
    (result, carry)
312
}
313

            
314
/// this function is for specifically adding a power of two constant to a cnf int.
315
fn tseytin_add_two_power(
316
    expr: &[Expr],
317
    exponent: usize,
318
    bits: usize,
319
    clauses: &mut Vec<CnfClause>,
320
    symbols: &mut SymbolTable,
321
) -> Vec<Expr> {
322
    let mut result = vec![];
323
    let mut product = expr[exponent].clone();
324

            
325
    for item in expr.iter().take(exponent) {
326
        result.push(item.clone());
327
    }
328

            
329
    result.push(tseytin_not(expr[exponent].clone(), clauses, symbols));
330

            
331
    for item in expr.iter().take(bits).skip(exponent + 1) {
332
        result.push(tseytin_xor(product.clone(), item.clone(), clauses, symbols));
333
        product = tseytin_and(&vec![product, item.clone()], clauses, symbols);
334
    }
335

            
336
    result
337
}
338

            
339
/// This function multiplies two binary values using the shift-add multiplication algorithm.
340
fn cnf_shift_add_multiply(
341
    x: &[Expr],
342
    y: &[Expr],
343
    bits: usize,
344
    clauses: &mut Vec<CnfClause>,
345
    symbols: &mut SymbolTable,
346
) -> Vec<Expr> {
347
    let mut x = x.to_owned();
348
    let mut y = y.to_owned();
349

            
350
    //TODO Optimizing for constants
351
    //TODO Optimize addition for i left shifted values - skip first i bits
352

            
353
    // extend sign bits of operands to 2*`bits`
354
    x.extend(std::iter::repeat_n(x[bits - 1].clone(), bits));
355
    y.extend(std::iter::repeat_n(y[bits - 1].clone(), bits));
356

            
357
    let mut s: Vec<Expr> = vec![];
358
    let mut x_0andy_i;
359

            
360
    for bit in &y {
361
        x_0andy_i = tseytin_and(&vec![x[0].clone(), bit.clone()], clauses, symbols);
362
        s.push(x_0andy_i);
363
    }
364

            
365
    let mut sum;
366
    let mut if_true;
367
    let mut not_x_n;
368
    let mut if_false;
369

            
370
    for item in x.iter().take(bits).skip(1) {
371
        // y << 1
372
        for i in (1..bits * 2).rev() {
373
            y[i] = y[i - 1].clone();
374
        }
375
        y[0] = false.into();
376

            
377
        // TODO switch to multiplexer
378
        // TODO Add negatives support once MUX is added
379
        sum = tseytin_int_adder(&s, &y, bits * 2, clauses, symbols);
380
        not_x_n = tseytin_not(item.clone(), clauses, symbols);
381

            
382
        for i in 0..(bits * 2) {
383
            if_true = tseytin_and(&vec![item.clone(), sum[i].clone()], clauses, symbols);
384
            if_false = tseytin_and(&vec![not_x_n.clone(), s[i].clone()], clauses, symbols);
385
            s[i] = tseytin_or(&vec![if_true.clone(), if_false.clone()], clauses, symbols);
386
        }
387
    }
388

            
389
    s
390
}
391

            
392
/// Computes the interval spanned by the product of several integer intervals.
///
/// Each range is an inclusive `(min, max)` pair. The intervals are multiplied left to
/// right, keeping the smallest and largest of the four endpoint products at each
/// step. An empty list yields `(1, 1)`, the range of the empty product.
///
/// E.g.
/// a : [2, 5], b : [-1, 2], c : [-10, -6], d : [0, 3]
/// a * b * c * d : [-300, 150]
fn product_of_ranges(ranges: Vec<&(i32, i32)>) -> (i32, i32) {
    let mut remaining = ranges.into_iter();
    let Some(&first) = remaining.next() else {
        return (1, 1); // product of zero numbers = 1
    };

    remaining.fold(first, |(lo, hi), &(a, b)| {
        let corners = [lo * a, lo * b, hi * a, hi * b];
        let lowest = corners.iter().copied().min().unwrap();
        let highest = corners.iter().copied().max().unwrap();
        (lowest, highest)
    })
}
411

            
412
/// Converts product of SATInts to a single SATInt
413
///
414
/// ```text
415
/// Product(SATInt(a), SATInt(b), ...) ~> SATInt(c)
416
///
417
/// ```
418
#[register_rule(("SAT_Log", 9000))]
419
3873
fn cnf_int_product(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
420
3873
    let Expr::Product(_, exprs) = expr else {
421
3873
        return Err(RuleNotApplicable);
422
    };
423

            
424
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
425
        return Err(RuleNotApplicable);
426
    };
427

            
428
    let ranges: Result<Vec<_>, _> = exprs_list
429
        .iter()
430
        .map(|e| match e {
431
            Expr::SATInt(_, _, _, x) => Ok(x),
432
            _ => Err(RuleNotApplicable),
433
        })
434
        .collect();
435

            
436
    let ranges = ranges?; // propagate error if any
437

            
438
    let (min, max) = product_of_ranges(ranges.clone());
439

            
440
    let exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
441

            
442
    let mut new_symbols = symbols.clone();
443
    let mut new_clauses = vec![];
444

            
445
    let (result, _) = exprs_bits
446
        .iter()
447
        .cloned()
448
        .zip(ranges.into_iter().copied())
449
        .reduce(|lhs, rhs| {
450
            // Make both bit vectors the same length
451
            let (lhs_bits, rhs_bits) = match_bits_length(lhs.0.clone(), rhs.0.clone());
452

            
453
            // Multiply operands
454
            let mut values = cnf_shift_add_multiply(
455
                &lhs_bits,
456
                &rhs_bits,
457
                lhs_bits.len(),
458
                &mut new_clauses,
459
                &mut new_symbols,
460
            );
461

            
462
            // Determine new range of result
463
            let (mut cum_min, mut cum_max) = lhs.1;
464
            let candidates = [
465
                cum_min * rhs.1.0,
466
                cum_min * rhs.1.1,
467
                cum_max * rhs.1.0,
468
                cum_max * rhs.1.1,
469
            ];
470
            cum_min = *candidates.iter().min().unwrap();
471
            cum_max = *candidates.iter().max().unwrap();
472

            
473
            let new_bit_count = bit_magnitude(cum_min).max(bit_magnitude(cum_max));
474
            values.truncate(new_bit_count);
475

            
476
            (values, (cum_min, cum_max))
477
        })
478
        .unwrap();
479

            
480
    Ok(Reduction::cnf(
481
        Expr::SATInt(
482
            Metadata::new(),
483
            SATIntEncoding::Log,
484
            Moo::new(into_matrix_expr!(result)),
485
            (min, max),
486
        ),
487
        new_clauses,
488
        new_symbols,
489
    ))
490
3873
}
491

            
492
/// Converts negation of a SATInt to a SATInt
493
///
494
/// ```text
495
/// -SATInt(a) ~> SATInt(b)
496
///
497
/// ```
498
#[register_rule(("SAT_Log", 4100))]
499
84
fn cnf_int_neg(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
500
84
    let Expr::Neg(_, expr) = expr else {
501
84
        return Err(RuleNotApplicable);
502
    };
503

            
504
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
505
        return Err(RuleNotApplicable);
506
    };
507

            
508
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
509
    let [bits] = binding.as_slice() else {
510
        return Err(RuleNotApplicable);
511
    };
512

            
513
    let mut new_clauses = vec![];
514
    let mut new_symbols = symbols.clone();
515

            
516
    let result = tseytin_negate(bits, bits.len(), &mut new_clauses, &mut new_symbols);
517

            
518
    Ok(Reduction::cnf(
519
        Expr::SATInt(
520
            Metadata::new(),
521
            SATIntEncoding::Log,
522
            Moo::new(into_matrix_expr!(result)),
523
            (-max, -min),
524
        ),
525
        new_clauses,
526
        new_symbols,
527
    ))
528
84
}
529

            
530
fn tseytin_negate(
531
    expr: &Vec<Expr>,
532
    bits: usize,
533
    clauses: &mut Vec<CnfClause>,
534
    symbols: &mut SymbolTable,
535
) -> Vec<Expr> {
536
    let mut result = vec![];
537
    // invert bits
538
    for bit in expr {
539
        result.push(tseytin_not(bit.clone(), clauses, symbols));
540
    }
541

            
542
    // add one
543
    result = tseytin_add_two_power(&result, 0, bits, clauses, symbols);
544

            
545
    result
546
}
547

            
548
/// Converts min of SATInts to a single SATInt
549
///
550
/// ```text
551
/// Min(SATInt(a), SATInt(b), ...) ~> SATInt(c)
552
///
553
/// ```
554
#[register_rule(("SAT_Log", 4100))]
555
84
fn cnf_int_min(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
556
84
    let Expr::Min(_, exprs) = expr else {
557
84
        return Err(RuleNotApplicable);
558
    };
559

            
560
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
561
        return Err(RuleNotApplicable);
562
    };
563

            
564
    let ranges: Result<Vec<_>, _> = exprs_list
565
        .iter()
566
        .map(|e| match e {
567
            Expr::SATInt(_, _, _, x) => Ok(x),
568
            _ => Err(RuleNotApplicable),
569
        })
570
        .collect();
571

            
572
    let ranges = ranges?; // propagate error if any
573

            
574
    // Is this optimal?
575
    let min = ranges.iter().map(|(a, _)| *a).min().unwrap();
576
    let max = ranges.iter().map(|(_, b)| *b).min().unwrap();
577

            
578
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
579

            
580
    let mut new_symbols = symbols.clone();
581
    let mut values;
582
    let mut new_clauses = vec![];
583

            
584
    while exprs_bits.len() > 1 {
585
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
586
        let mut iter = exprs_bits.into_iter();
587

            
588
        while let Some(a) = iter.next() {
589
            if let Some(b) = iter.next() {
590
                values = tseytin_binary_min_max(&a, &b, true, &mut new_clauses, &mut new_symbols);
591
                next.push(values);
592
            } else {
593
                next.push(a);
594
            }
595
        }
596

            
597
        exprs_bits = next;
598
    }
599

            
600
    let result = exprs_bits.pop().unwrap();
601

            
602
    Ok(Reduction::cnf(
603
        Expr::SATInt(
604
            Metadata::new(),
605
            SATIntEncoding::Log,
606
            Moo::new(into_matrix_expr!(result)),
607
            (min, max),
608
        ),
609
        new_clauses,
610
        new_symbols,
611
    ))
612
84
}
613

            
614
/// General function for getting the min or max of two log integers.
615
fn tseytin_binary_min_max(
616
    x: &[Expr],
617
    y: &[Expr],
618
    min: bool,
619
    clauses: &mut Vec<CnfClause>,
620
    symbols: &mut SymbolTable,
621
) -> Vec<Expr> {
622
    let mask = if min {
623
        // mask is 1 if x > y
624
        inequality_boolean(x.to_owned(), y.to_owned(), true, clauses, symbols)
625
    } else {
626
        // flip the args if getting maximum x < y -> 1
627
        inequality_boolean(y.to_owned(), x.to_owned(), true, clauses, symbols)
628
    };
629

            
630
    tseytin_select_array(mask, x, y, clauses, symbols)
631
}
632

            
633
// Selects between two boolean vectors depending on a condition (both vectors must be the same length)
634
/// cond ? b : a
635
///
636
/// cond = 1 => b
637
/// cond = 0 => a
638
fn tseytin_select_array(
639
    cond: Expr,
640
    a: &[Expr],
641
    b: &[Expr],
642
    clauses: &mut Vec<CnfClause>,
643
    symbols: &mut SymbolTable,
644
) -> Vec<Expr> {
645
    assert_eq!(
646
        a.len(),
647
        b.len(),
648
        "Input vectors 'a' and 'b' must have the same length"
649
    );
650

            
651
    let mut out = vec![];
652

            
653
    let bit_count = a.len();
654

            
655
    for i in 0..bit_count {
656
        out.push(tseytin_mux(
657
            cond.clone(),
658
            a[i].clone(),
659
            b[i].clone(),
660
            clauses,
661
            symbols,
662
        ));
663
    }
664

            
665
    out
666
}
667

            
668
/// Converts max of SATInts to a single SATInt
669
///
670
/// ```text
671
/// Max(SATInt(a), SATInt(b), ...) ~> SATInt(c)
672
///
673
/// ```
674
#[register_rule(("SAT_Log", 4100))]
675
84
fn cnf_int_max(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
676
84
    let Expr::Max(_, exprs) = expr else {
677
84
        return Err(RuleNotApplicable);
678
    };
679

            
680
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
681
        return Err(RuleNotApplicable);
682
    };
683

            
684
    let ranges: Result<Vec<_>, _> = exprs_list
685
        .iter()
686
        .map(|e| match e {
687
            Expr::SATInt(_, _, _, x) => Ok(x),
688
            _ => Err(RuleNotApplicable),
689
        })
690
        .collect();
691

            
692
    let ranges = ranges?; // propagate error if any
693

            
694
    // Is this optimal?
695
    let min = ranges.iter().map(|(a, _)| *a).max().unwrap();
696
    let max = ranges.iter().map(|(_, b)| *b).max().unwrap();
697

            
698
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
699

            
700
    let mut new_symbols = symbols.clone();
701
    let mut values;
702
    let mut new_clauses = vec![];
703

            
704
    while exprs_bits.len() > 1 {
705
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
706
        let mut iter = exprs_bits.into_iter();
707

            
708
        while let Some(a) = iter.next() {
709
            if let Some(b) = iter.next() {
710
                values = tseytin_binary_min_max(&a, &b, false, &mut new_clauses, &mut new_symbols);
711
                next.push(values);
712
            } else {
713
                next.push(a);
714
            }
715
        }
716

            
717
        exprs_bits = next;
718
    }
719

            
720
    let result = exprs_bits.pop().unwrap();
721

            
722
    Ok(Reduction::cnf(
723
        Expr::SATInt(
724
            Metadata::new(),
725
            SATIntEncoding::Log,
726
            Moo::new(into_matrix_expr!(result)),
727
            (min, max),
728
        ),
729
        new_clauses,
730
        new_symbols,
731
    ))
732
84
}
733

            
734
/// Converts Abs of a SATInt to a SATInt
735
///
736
/// ```text
737
/// |SATInt(a)| ~> SATInt(b)
738
///
739
/// ```
740
#[register_rule(("SAT_Log", 4100))]
741
84
fn cnf_int_abs(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
742
84
    let Expr::Abs(_, expr) = expr else {
743
84
        return Err(RuleNotApplicable);
744
    };
745

            
746
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
747
        return Err(RuleNotApplicable);
748
    };
749

            
750
    let range = (
751
        cmp::max(0, cmp::max(*min, -*max)),
752
        cmp::max(min.abs(), max.abs()),
753
    );
754

            
755
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
756
    let [bits] = binding.as_slice() else {
757
        return Err(RuleNotApplicable);
758
    };
759

            
760
    let mut new_clauses = vec![];
761
    let mut new_symbols = symbols.clone();
762

            
763
    let mut result = vec![];
764

            
765
    // How does this handle negatives edge cases: -(-8) = 8, an extra bit is needed
766

            
767
    // invert bits
768
    for bit in bits {
769
        result.push(tseytin_not(bit.clone(), &mut new_clauses, &mut new_symbols));
770
    }
771

            
772
    let bit_count = result.len();
773

            
774
    // add one
775
    result = tseytin_add_two_power(&result, 0, bit_count, &mut new_clauses, &mut new_symbols);
776

            
777
    for i in 0..bit_count {
778
        result[i] = tseytin_mux(
779
            bits[bit_count - 1].clone(),
780
            bits[i].clone(),
781
            result[i].clone(),
782
            &mut new_clauses,
783
            &mut new_symbols,
784
        )
785
    }
786

            
787
    Ok(Reduction::cnf(
788
        Expr::SATInt(
789
            Metadata::new(),
790
            SATIntEncoding::Log,
791
            Moo::new(into_matrix_expr!(result)),
792
            range,
793
        ),
794
        new_clauses,
795
        new_symbols,
796
    ))
797
84
}
798

            
799
/// Converts SafeDiv of SATInts to a single SATInt
800
///
801
/// ```text
802
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
803
///
804
/// ```
805
#[register_rule(("SAT_Log", 4100))]
806
84
fn cnf_int_safediv(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
807
    // Using "Restoring division" algorithm
808
    // https://en.wikipedia.org/wiki/Division_algorithm#Restoring_division
809
84
    let Expr::SafeDiv(_, numer, denom) = expr else {
810
84
        return Err(RuleNotApplicable);
811
    };
812

            
813
    let Expr::SATInt(_, _, _, (numer_min, numer_max)) = numer.as_ref() else {
814
        return Err(RuleNotApplicable);
815
    };
816

            
817
    let Expr::SATInt(_, _, _, (denom_min, denom_max)) = denom.as_ref() else {
818
        return Err(RuleNotApplicable);
819
    };
820

            
821
    let candidates = [
822
        numer_min / denom_min,
823
        numer_min / denom_max,
824
        numer_max / denom_min,
825
        numer_max / denom_max,
826
    ];
827

            
828
    let min = *candidates.iter().min().unwrap();
829
    let max = *candidates.iter().max().unwrap();
830

            
831
    let binding =
832
        validate_log_int_operands(vec![numer.as_ref().clone(), denom.as_ref().clone()], None)?;
833
    let [numer_bits, denom_bits] = binding.as_slice() else {
834
        return Err(RuleNotApplicable);
835
    };
836

            
837
    let bit_count = numer_bits.len();
838

            
839
    // TODO: Separate into division/mod function
840

            
841
    let mut new_symbols = symbols.clone();
842
    let mut new_clauses = vec![];
843
    let mut quotient = vec![false.into(); bit_count];
844

            
845
    let minus_numer = tseytin_negate(
846
        &numer_bits.clone(),
847
        bit_count,
848
        &mut new_clauses,
849
        &mut new_symbols,
850
    );
851
    let minus_denom = tseytin_negate(
852
        &denom_bits.clone(),
853
        bit_count,
854
        &mut new_clauses,
855
        &mut new_symbols,
856
    );
857

            
858
    let sign_bit = tseytin_xor(
859
        numer_bits[bit_count - 1].clone(),
860
        denom_bits[bit_count - 1].clone(),
861
        &mut new_clauses,
862
        &mut new_symbols,
863
    );
864

            
865
    let numer_bits = tseytin_select_array(
866
        numer_bits[bit_count - 1].clone(),
867
        &numer_bits.clone(),
868
        &minus_numer,
869
        &mut new_clauses,
870
        &mut new_symbols,
871
    );
872
    let denom_bits = tseytin_select_array(
873
        denom_bits[bit_count - 1].clone(),
874
        &denom_bits.clone(),
875
        &minus_denom,
876
        &mut new_clauses,
877
        &mut new_symbols,
878
    );
879

            
880
    let mut r = numer_bits;
881
    r.extend(std::iter::repeat_n(r[bit_count - 1].clone(), bit_count));
882
    let mut d = std::iter::repeat_n(false.into(), bit_count).collect_vec();
883
    d.extend(denom_bits);
884

            
885
    let minus_d = tseytin_negate(
886
        &d.clone(),
887
        2 * bit_count,
888
        &mut new_clauses,
889
        &mut new_symbols,
890
    );
891
    let mut rminusd;
892

            
893
    for i in (0..bit_count).rev() {
894
        // r << 1
895
        for j in (1..bit_count * 2).rev() {
896
            r[j] = r[j - 1].clone();
897
        }
898
        r[0] = false.into();
899

            
900
        rminusd = tseytin_int_adder(
901
            &r.clone(),
902
            &minus_d.clone(),
903
            2 * bit_count,
904
            &mut new_clauses,
905
            &mut new_symbols,
906
        );
907

            
908
        // TODO: For mod don't calculate on final iter
909
        quotient[i] = tseytin_not(
910
            // q[i] = inverse of sign bit - 1 if positive, 0 if negative
911
            rminusd[2 * bit_count - 1].clone(),
912
            &mut new_clauses,
913
            &mut new_symbols,
914
        );
915

            
916
        // TODO: For div don't calculate on final iter
917
        for j in 0..(2 * bit_count) {
918
            r[j] = tseytin_mux(
919
                quotient[i].clone(),
920
                r[j].clone(),       // use r if negative
921
                rminusd[j].clone(), // use r-d if positive
922
                &mut new_clauses,
923
                &mut new_symbols,
924
            );
925
        }
926
    }
927

            
928
    let minus_quotient = tseytin_negate(
929
        &quotient.clone(),
930
        bit_count,
931
        &mut new_clauses,
932
        &mut new_symbols,
933
    );
934

            
935
    let out = tseytin_select_array(
936
        sign_bit,
937
        &quotient,
938
        &minus_quotient,
939
        &mut new_clauses,
940
        &mut new_symbols,
941
    );
942

            
943
    Ok(Reduction::cnf(
944
        Expr::SATInt(
945
            Metadata::new(),
946
            SATIntEncoding::Log,
947
            Moo::new(into_matrix_expr!(out)),
948
            (min, max),
949
        ),
950
        new_clauses,
951
        new_symbols,
952
    ))
953
84
}
954

            
955
/*
956
/// Converts SafeMod of SATInts to a single SATInt
957
///
958
/// ```text
959
/// SafeMod(SATInt(a), SATInt(b)) ~> SATInt(c)
960
///
961
/// ```
962
#[register_rule(("SAT_Log", 4100))]
963
fn cnf_int_safemod(expr: &Expr, _: &SymbolTable) -> ApplicationResult {}
964

            
965
/// Converts SafePow of SATInts to a single SATInt
966
///
967
/// ```text
968
/// SafePow(SATInt(a), SATInt(b)) ~> SATInt(c)
969
///
970
/// ```
971
#[register_rule(("SAT", 4100))]
972
fn cnf_int_safepow(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
973
    // use 'Exponentiation by squaring'
974
}
975
*/