1
use conjure_cp::ast::Expression as Expr;
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
5
};
6

            
7
use conjure_cp::ast::AbstractLiteral::Matrix;
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use itertools::Itertools;
13

            
14
use super::boolean::{
15
    tseytin_and, tseytin_iff, tseytin_imply, tseytin_mux, tseytin_not, tseytin_or, tseytin_xor,
16
};
17
use super::integer_repr::{bit_magnitude, match_bits_length, validate_log_int_operands};
18

            
19
use conjure_cp::ast::CnfClause;
20

            
21
use std::cmp;
22

            
23
/// Converts an inequality expression between two SATInts to a boolean expression in cnf.
24
///
25
/// ```text
26
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
27
///
28
/// ```
29
#[register_rule(("SAT_Log", 4100))]
30
4392
fn cnf_int_ineq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
31
4392
    let (lhs, rhs, strict) = match expr {
32
63
        Expr::Lt(_, x, y) => (y, x, true),
33
54
        Expr::Gt(_, x, y) => (x, y, true),
34
648
        Expr::Leq(_, x, y) => (y, x, false),
35
558
        Expr::Geq(_, x, y) => (x, y, false),
36
3069
        _ => return Err(RuleNotApplicable),
37
    };
38

            
39
1179
    let binding =
40
1323
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
41
1179
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
42
        return Err(RuleNotApplicable);
43
    };
44

            
45
1179
    let mut new_symbols = symbols.clone();
46
1179
    let mut new_clauses = vec![];
47

            
48
1179
    let output = inequality_boolean(
49
1179
        lhs_bits.clone(),
50
1179
        rhs_bits.clone(),
51
1179
        strict,
52
1179
        &mut new_clauses,
53
1179
        &mut new_symbols,
54
    );
55
1179
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
56
4392
}
57

            
58
/// Converts a = expression between two SATInts to a boolean expression in cnf
59
///
60
/// ```text
61
/// SATInt(a) = SATInt(b) ~> Bool
62
///
63
/// ```
64
#[register_rule(("SAT_Log", 9100))]
65
247203
fn cnf_int_eq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
66
247203
    let Expr::Eq(_, lhs, rhs) = expr else {
67
246420
        return Err(RuleNotApplicable);
68
    };
69

            
70
207
    let binding =
71
783
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
72
207
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
73
        return Err(RuleNotApplicable);
74
    };
75

            
76
207
    let bit_count = lhs_bits.len();
77

            
78
207
    let mut output = true.into();
79
207
    let mut new_symbols = symbols.clone();
80
207
    let mut new_clauses = vec![];
81
    let mut comparison;
82

            
83
1260
    for i in 0..bit_count {
84
1260
        comparison = tseytin_iff(
85
1260
            lhs_bits[i].clone(),
86
1260
            rhs_bits[i].clone(),
87
1260
            &mut new_clauses,
88
1260
            &mut new_symbols,
89
1260
        );
90
1260
        output = tseytin_and(
91
1260
            &vec![comparison, output],
92
1260
            &mut new_clauses,
93
1260
            &mut new_symbols,
94
1260
        );
95
1260
    }
96

            
97
207
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
98
247203
}
99

            
100
/// Converts a != expression between two SATInts to a boolean expression in cnf
101
///
102
/// ```text
103
/// SATInt(a) != SATInt(b) ~> Bool
104
///
105
/// ```
106
#[register_rule(("SAT_Log", 4100))]
107
4392
fn cnf_int_neq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
108
4392
    let Expr::Neq(_, lhs, rhs) = expr else {
109
4257
        return Err(RuleNotApplicable);
110
    };
111

            
112
135
    let binding =
113
135
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
114
135
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
115
        return Err(RuleNotApplicable);
116
    };
117

            
118
135
    let bit_count = lhs_bits.len();
119

            
120
135
    let mut output = false.into();
121
135
    let mut new_symbols = symbols.clone();
122
135
    let mut new_clauses = vec![];
123
    let mut comparison;
124

            
125
432
    for i in 0..bit_count {
126
432
        comparison = tseytin_xor(
127
432
            lhs_bits[i].clone(),
128
432
            rhs_bits[i].clone(),
129
432
            &mut new_clauses,
130
432
            &mut new_symbols,
131
432
        );
132
432
        output = tseytin_or(
133
432
            &vec![comparison, output],
134
432
            &mut new_clauses,
135
432
            &mut new_symbols,
136
432
        );
137
432
    }
138

            
139
135
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
140
4392
}
141

            
142
// Creates a boolean expression for > or >=
143
// a > b or a >= b
144
// This can also be used for < and <= by reversing the order of the inputs
145
// Returns result, new symbol table, new clauses
146
1251
fn inequality_boolean(
147
1251
    a: Vec<Expr>,
148
1251
    b: Vec<Expr>,
149
1251
    strict: bool,
150
1251
    clauses: &mut Vec<CnfClause>,
151
1251
    symbols: &mut SymbolTable,
152
1251
) -> Expr {
153
    let mut notb;
154
    let mut output;
155

            
156
1251
    if strict {
157
162
        notb = tseytin_not(b[0].clone(), clauses, symbols);
158
162
        output = tseytin_and(&vec![a[0].clone(), notb], clauses, symbols);
159
1089
    } else {
160
1089
        output = tseytin_imply(b[0].clone(), a[0].clone(), clauses, symbols);
161
1089
    }
162

            
163
    //TODO: There may be room for simplification, and constant optimization
164

            
165
1251
    let bit_count = a.len();
166

            
167
    let mut lhs;
168
    let mut rhs;
169
    let mut iff;
170
2646
    for n in 1..(bit_count - 1) {
171
2646
        notb = tseytin_not(b[n].clone(), clauses, symbols);
172
2646
        lhs = tseytin_and(&vec![a[n].clone(), notb.clone()], clauses, symbols);
173
2646
        iff = tseytin_iff(a[n].clone(), b[n].clone(), clauses, symbols);
174
2646
        rhs = tseytin_and(&vec![iff.clone(), output.clone()], clauses, symbols);
175
2646
        output = tseytin_or(&vec![lhs.clone(), rhs.clone()], clauses, symbols);
176
2646
    }
177

            
178
    // final bool is the sign bit and should be handled inversely
179
1251
    let nota = tseytin_not(a[bit_count - 1].clone(), clauses, symbols);
180
1251
    lhs = tseytin_and(&vec![nota, b[bit_count - 1].clone()], clauses, symbols);
181
1251
    iff = tseytin_iff(
182
1251
        a[bit_count - 1].clone(),
183
1251
        b[bit_count - 1].clone(),
184
1251
        clauses,
185
1251
        symbols,
186
    );
187
1251
    rhs = tseytin_and(&vec![iff, output.clone()], clauses, symbols);
188
1251
    output = tseytin_or(&vec![lhs, rhs], clauses, symbols);
189

            
190
1251
    output
191
1251
}
192

            
193
/// Converts sum of SATInts to a single SATInt
194
///
195
/// ```text
196
/// Sum(SATInt(a), SATInt(b), ...) ~> SATInt(c)
197
///
198
/// ```
199
#[register_rule(("SAT_Log", 4100))]
200
4392
fn cnf_int_sum(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
201
4392
    let Expr::Sum(_, exprs) = expr else {
202
4302
        return Err(RuleNotApplicable);
203
    };
204

            
205
90
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
206
        return Err(RuleNotApplicable);
207
    };
208

            
209
90
    let ranges: Result<Vec<_>, _> = exprs_list
210
90
        .iter()
211
180
        .map(|e| match e {
212
135
            Expr::SATInt(_, _, _, x) => Ok(x),
213
45
            _ => Err(RuleNotApplicable),
214
180
        })
215
90
        .collect();
216

            
217
90
    let ranges = ranges?;
218

            
219
45
    let min = ranges.iter().map(|(a, _)| *a).sum();
220
45
    let max = ranges.iter().map(|(_, a)| *a).sum();
221

            
222
45
    let output_size = cmp::max(bit_magnitude(min), bit_magnitude(max));
223

            
224
    // Check operands are valid log ints
225
45
    let mut exprs_bits =
226
45
        validate_log_int_operands(exprs_list.clone(), Some(output_size.try_into().unwrap()))?;
227

            
228
45
    let mut new_symbols = symbols.clone();
229
    let mut values;
230
45
    let mut new_clauses = vec![];
231

            
232
99
    while exprs_bits.len() > 1 {
233
54
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
234
54
        let mut iter = exprs_bits.into_iter();
235

            
236
117
        while let Some(a) = iter.next() {
237
63
            if let Some(b) = iter.next() {
238
54
                values = tseytin_int_adder(&a, &b, output_size, &mut new_clauses, &mut new_symbols);
239
54
                next.push(values);
240
54
            } else {
241
9
                next.push(a);
242
9
            }
243
        }
244

            
245
54
        exprs_bits = next;
246
    }
247

            
248
45
    let result = exprs_bits.pop().unwrap();
249

            
250
45
    Ok(Reduction::cnf(
251
45
        Expr::SATInt(
252
45
            Metadata::new(),
253
45
            SATIntEncoding::Log,
254
45
            Moo::new(into_matrix_expr!(result)),
255
45
            (min, max),
256
45
        ),
257
45
        new_clauses,
258
45
        new_symbols,
259
45
    ))
260
4392
}
261

            
262
/// Returns result, new symbol table, new clauses
263
/// This function expects bits to match the lengths of x and y
264
99
fn tseytin_int_adder(
265
99
    x: &[Expr],
266
99
    y: &[Expr],
267
99
    bits: usize,
268
99
    clauses: &mut Vec<CnfClause>,
269
99
    symbols: &mut SymbolTable,
270
99
) -> Vec<Expr> {
271
    //TODO: Optimizing for constants
272
99
    let (mut result, mut carry) = tseytin_half_adder(x[0].clone(), y[0].clone(), clauses, symbols);
273

            
274
99
    let mut output = vec![result];
275
765
    for i in 1..bits {
276
765
        (result, carry) =
277
765
            tseytin_full_adder(x[i].clone(), y[i].clone(), carry.clone(), clauses, symbols);
278
765
        output.push(result);
279
765
    }
280

            
281
99
    output
282
99
}
283

            
284
/// This function adds two booleans and a carry boolean using the full-adder logic circuit, it is intended for use in a binary adder.
285
765
fn tseytin_full_adder(
286
765
    a: Expr,
287
765
    b: Expr,
288
765
    carry: Expr,
289
765
    clauses: &mut Vec<CnfClause>,
290
765
    symbols: &mut SymbolTable,
291
765
) -> (Expr, Expr) {
292
765
    let axorb = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
293
765
    let result = tseytin_xor(axorb.clone(), carry.clone(), clauses, symbols);
294
765
    let aandb = tseytin_and(&vec![a, b], clauses, symbols);
295
765
    let carryandaxorb = tseytin_and(&vec![carry, axorb], clauses, symbols);
296
765
    let carryout = tseytin_or(&vec![aandb, carryandaxorb], clauses, symbols);
297

            
298
765
    (result, carryout)
299
765
}
300

            
301
/// This function adds two booleans using the half-adder logic circuit, it is intended for use in a binary adder.
302
99
fn tseytin_half_adder(
303
99
    a: Expr,
304
99
    b: Expr,
305
99
    clauses: &mut Vec<CnfClause>,
306
99
    symbols: &mut SymbolTable,
307
99
) -> (Expr, Expr) {
308
99
    let result = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
309
99
    let carry = tseytin_and(&vec![a, b], clauses, symbols);
310

            
311
99
    (result, carry)
312
99
}
313

            
314
/// this function is for specifically adding a power of two constant to a cnf int.
315
72
fn tseytin_add_two_power(
316
72
    expr: &[Expr],
317
72
    exponent: usize,
318
72
    bits: usize,
319
72
    clauses: &mut Vec<CnfClause>,
320
72
    symbols: &mut SymbolTable,
321
72
) -> Vec<Expr> {
322
72
    let mut result = vec![];
323
72
    let mut product = expr[exponent].clone();
324

            
325
72
    for item in expr.iter().take(exponent) {
326
        result.push(item.clone());
327
    }
328

            
329
72
    result.push(tseytin_not(expr[exponent].clone(), clauses, symbols));
330

            
331
171
    for item in expr.iter().take(bits).skip(exponent + 1) {
332
171
        result.push(tseytin_xor(product.clone(), item.clone(), clauses, symbols));
333
171
        product = tseytin_and(&vec![product, item.clone()], clauses, symbols);
334
171
    }
335

            
336
72
    result
337
72
}
338

            
339
/// This function multiplies two binary values using the shift-add multiplication algorithm.
340
9
fn cnf_shift_add_multiply(
341
9
    x: &[Expr],
342
9
    y: &[Expr],
343
9
    bits: usize,
344
9
    clauses: &mut Vec<CnfClause>,
345
9
    symbols: &mut SymbolTable,
346
9
) -> Vec<Expr> {
347
9
    let mut x = x.to_owned();
348
9
    let mut y = y.to_owned();
349

            
350
    //TODO Optimizing for constants
351
    //TODO Optimize addition for i left shifted values - skip first i bits
352

            
353
    // extend sign bits of operands to 2*`bits`
354
9
    x.extend(std::iter::repeat_n(x[bits - 1].clone(), bits));
355
9
    y.extend(std::iter::repeat_n(y[bits - 1].clone(), bits));
356

            
357
9
    let mut s: Vec<Expr> = vec![];
358
    let mut x_0andy_i;
359

            
360
108
    for bit in &y {
361
108
        x_0andy_i = tseytin_and(&vec![x[0].clone(), bit.clone()], clauses, symbols);
362
108
        s.push(x_0andy_i);
363
108
    }
364

            
365
    let mut sum;
366
    let mut if_true;
367
    let mut not_x_n;
368
    let mut if_false;
369

            
370
45
    for item in x.iter().take(bits).skip(1) {
371
        // y << 1
372
495
        for i in (1..bits * 2).rev() {
373
495
            y[i] = y[i - 1].clone();
374
495
        }
375
45
        y[0] = false.into();
376

            
377
        // TODO switch to multiplexer
378
        // TODO Add negatives support once MUX is added
379
45
        sum = tseytin_int_adder(&s, &y, bits * 2, clauses, symbols);
380
45
        not_x_n = tseytin_not(item.clone(), clauses, symbols);
381

            
382
540
        for i in 0..(bits * 2) {
383
540
            if_true = tseytin_and(&vec![item.clone(), sum[i].clone()], clauses, symbols);
384
540
            if_false = tseytin_and(&vec![not_x_n.clone(), s[i].clone()], clauses, symbols);
385
540
            s[i] = tseytin_or(&vec![if_true.clone(), if_false.clone()], clauses, symbols);
386
540
        }
387
    }
388

            
389
9
    s
390
9
}
391

            
392
/// Computes the bounds of the product of several integers, each drawn from an inclusive range.
///
/// Multiplying the running interval by the next factor's interval attains its extremes at one
/// of the four corner products, so only those are checked at each step. An empty list gives
/// `(1, 1)`, the empty product. Arithmetic is plain `i32`.
///
/// E.g.
/// a : [2, 5], b : [-1, 2], c : [-10, -6], d : [0, 3]
/// a * b * c * d : [-300, 150]
fn product_of_ranges(ranges: Vec<&(i32, i32)>) -> (i32, i32) {
    let mut factors = ranges.into_iter();
    let Some(&first) = factors.next() else {
        // product of zero numbers = 1
        return (1, 1);
    };

    factors.fold(first, |(lo, hi), &(a, b)| {
        let corners = [lo * a, lo * b, hi * a, hi * b];
        let smallest = corners.iter().copied().min().unwrap();
        let largest = corners.iter().copied().max().unwrap();
        (smallest, largest)
    })
}
411

            
412
/// Converts product of SATInts to a single SATInt
413
///
414
/// ```text
415
/// Product(SATInt(a), SATInt(b), ...) ~> SATInt(c)
416
///
417
/// ```
418
#[register_rule(("SAT_Log", 9000))]
419
176679
fn cnf_int_product(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
420
176679
    let Expr::Product(_, exprs) = expr else {
421
176670
        return Err(RuleNotApplicable);
422
    };
423

            
424
9
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
425
        return Err(RuleNotApplicable);
426
    };
427

            
428
9
    let ranges: Result<Vec<_>, _> = exprs_list
429
9
        .iter()
430
18
        .map(|e| match e {
431
18
            Expr::SATInt(_, _, _, x) => Ok(x),
432
            _ => Err(RuleNotApplicable),
433
18
        })
434
9
        .collect();
435

            
436
9
    let ranges = ranges?; // propagate error if any
437

            
438
9
    let (min, max) = product_of_ranges(ranges.clone());
439

            
440
9
    let exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
441

            
442
9
    let mut new_symbols = symbols.clone();
443
9
    let mut new_clauses = vec![];
444

            
445
9
    let (result, _) = exprs_bits
446
9
        .iter()
447
9
        .cloned()
448
9
        .zip(ranges.into_iter().copied())
449
9
        .reduce(|lhs, rhs| {
450
            // Make both bit vectors the same length
451
9
            let (lhs_bits, rhs_bits) = match_bits_length(lhs.0.clone(), rhs.0.clone());
452

            
453
            // Multiply operands
454
9
            let mut values = cnf_shift_add_multiply(
455
9
                &lhs_bits,
456
9
                &rhs_bits,
457
9
                lhs_bits.len(),
458
9
                &mut new_clauses,
459
9
                &mut new_symbols,
460
            );
461

            
462
            // Determine new range of result
463
9
            let (mut cum_min, mut cum_max) = lhs.1;
464
9
            let candidates = [
465
9
                cum_min * rhs.1.0,
466
9
                cum_min * rhs.1.1,
467
9
                cum_max * rhs.1.0,
468
9
                cum_max * rhs.1.1,
469
9
            ];
470
9
            cum_min = *candidates.iter().min().unwrap();
471
9
            cum_max = *candidates.iter().max().unwrap();
472

            
473
9
            let new_bit_count = bit_magnitude(cum_min).max(bit_magnitude(cum_max));
474
9
            values.truncate(new_bit_count);
475

            
476
9
            (values, (cum_min, cum_max))
477
9
        })
478
9
        .unwrap();
479

            
480
9
    Ok(Reduction::cnf(
481
9
        Expr::SATInt(
482
9
            Metadata::new(),
483
9
            SATIntEncoding::Log,
484
9
            Moo::new(into_matrix_expr!(result)),
485
9
            (min, max),
486
9
        ),
487
9
        new_clauses,
488
9
        new_symbols,
489
9
    ))
490
176679
}
491

            
492
/// Converts negation of a SATInt to a SATInt
493
///
494
/// ```text
495
/// -SATInt(a) ~> SATInt(b)
496
///
497
/// ```
498
#[register_rule(("SAT_Log", 4100))]
499
4392
fn cnf_int_neg(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
500
4392
    let Expr::Neg(_, expr) = expr else {
501
4338
        return Err(RuleNotApplicable);
502
    };
503

            
504
54
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
505
        return Err(RuleNotApplicable);
506
    };
507

            
508
54
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
509
54
    let [bits] = binding.as_slice() else {
510
        return Err(RuleNotApplicable);
511
    };
512

            
513
54
    let mut new_clauses = vec![];
514
54
    let mut new_symbols = symbols.clone();
515

            
516
54
    let result = tseytin_negate(bits, bits.len(), &mut new_clauses, &mut new_symbols);
517

            
518
54
    Ok(Reduction::cnf(
519
54
        Expr::SATInt(
520
54
            Metadata::new(),
521
54
            SATIntEncoding::Log,
522
54
            Moo::new(into_matrix_expr!(result)),
523
54
            (-max, -min),
524
54
        ),
525
54
        new_clauses,
526
54
        new_symbols,
527
54
    ))
528
4392
}
529

            
530
54
fn tseytin_negate(
531
54
    expr: &Vec<Expr>,
532
54
    bits: usize,
533
54
    clauses: &mut Vec<CnfClause>,
534
54
    symbols: &mut SymbolTable,
535
54
) -> Vec<Expr> {
536
54
    let mut result = vec![];
537
    // invert bits
538
153
    for bit in expr {
539
153
        result.push(tseytin_not(bit.clone(), clauses, symbols));
540
153
    }
541

            
542
    // add one
543
54
    result = tseytin_add_two_power(&result, 0, bits, clauses, symbols);
544

            
545
54
    result
546
54
}
547

            
548
/// Converts min of SATInts to a single SATInt
549
///
550
/// ```text
551
/// Min(SATInt(a), SATInt(b), ...) ~> SATInt(c)
552
///
553
/// ```
554
#[register_rule(("SAT_Log", 4100))]
555
4392
fn cnf_int_min(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
556
4392
    let Expr::Min(_, exprs) = expr else {
557
4356
        return Err(RuleNotApplicable);
558
    };
559

            
560
36
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
561
        return Err(RuleNotApplicable);
562
    };
563

            
564
36
    let ranges: Result<Vec<_>, _> = exprs_list
565
36
        .iter()
566
72
        .map(|e| match e {
567
72
            Expr::SATInt(_, _, _, x) => Ok(x),
568
            _ => Err(RuleNotApplicable),
569
72
        })
570
36
        .collect();
571

            
572
36
    let ranges = ranges?; // propagate error if any
573

            
574
    // Is this optimal?
575
36
    let min = ranges.iter().map(|(a, _)| *a).min().unwrap();
576
36
    let max = ranges.iter().map(|(_, b)| *b).min().unwrap();
577

            
578
36
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
579

            
580
36
    let mut new_symbols = symbols.clone();
581
    let mut values;
582
36
    let mut new_clauses = vec![];
583

            
584
72
    while exprs_bits.len() > 1 {
585
36
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
586
36
        let mut iter = exprs_bits.into_iter();
587

            
588
72
        while let Some(a) = iter.next() {
589
36
            if let Some(b) = iter.next() {
590
36
                values = tseytin_binary_min_max(&a, &b, true, &mut new_clauses, &mut new_symbols);
591
36
                next.push(values);
592
36
            } else {
593
                next.push(a);
594
            }
595
        }
596

            
597
36
        exprs_bits = next;
598
    }
599

            
600
36
    let result = exprs_bits.pop().unwrap();
601

            
602
36
    Ok(Reduction::cnf(
603
36
        Expr::SATInt(
604
36
            Metadata::new(),
605
36
            SATIntEncoding::Log,
606
36
            Moo::new(into_matrix_expr!(result)),
607
36
            (min, max),
608
36
        ),
609
36
        new_clauses,
610
36
        new_symbols,
611
36
    ))
612
4392
}
613

            
614
/// General function for getting the min or max of two log integers.
615
72
fn tseytin_binary_min_max(
616
72
    x: &[Expr],
617
72
    y: &[Expr],
618
72
    min: bool,
619
72
    clauses: &mut Vec<CnfClause>,
620
72
    symbols: &mut SymbolTable,
621
72
) -> Vec<Expr> {
622
72
    let mask = if min {
623
        // mask is 1 if x > y
624
36
        inequality_boolean(x.to_owned(), y.to_owned(), true, clauses, symbols)
625
    } else {
626
        // flip the args if getting maximum x < y -> 1
627
36
        inequality_boolean(y.to_owned(), x.to_owned(), true, clauses, symbols)
628
    };
629

            
630
72
    tseytin_select_array(mask, x, y, clauses, symbols)
631
72
}
632

            
633
// Selects between two boolean vectors depending on a condition (both vectors must be the same length)
634
/// cond ? b : a
635
///
636
/// cond = 1 => b
637
/// cond = 0 => a
638
72
fn tseytin_select_array(
639
72
    cond: Expr,
640
72
    a: &[Expr],
641
72
    b: &[Expr],
642
72
    clauses: &mut Vec<CnfClause>,
643
72
    symbols: &mut SymbolTable,
644
72
) -> Vec<Expr> {
645
72
    assert_eq!(
646
72
        a.len(),
647
72
        b.len(),
648
        "Input vectors 'a' and 'b' must have the same length"
649
    );
650

            
651
72
    let mut out = vec![];
652

            
653
72
    let bit_count = a.len();
654

            
655
270
    for i in 0..bit_count {
656
270
        out.push(tseytin_mux(
657
270
            cond.clone(),
658
270
            a[i].clone(),
659
270
            b[i].clone(),
660
270
            clauses,
661
270
            symbols,
662
270
        ));
663
270
    }
664

            
665
72
    out
666
72
}
667

            
668
/// Converts max of SATInts to a single SATInt
669
///
670
/// ```text
671
/// Max(SATInt(a), SATInt(b), ...) ~> SATInt(c)
672
///
673
/// ```
674
#[register_rule(("SAT_Log", 4100))]
675
4392
fn cnf_int_max(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
676
4392
    let Expr::Max(_, exprs) = expr else {
677
4356
        return Err(RuleNotApplicable);
678
    };
679

            
680
36
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
681
        return Err(RuleNotApplicable);
682
    };
683

            
684
36
    let ranges: Result<Vec<_>, _> = exprs_list
685
36
        .iter()
686
72
        .map(|e| match e {
687
72
            Expr::SATInt(_, _, _, x) => Ok(x),
688
            _ => Err(RuleNotApplicable),
689
72
        })
690
36
        .collect();
691

            
692
36
    let ranges = ranges?; // propagate error if any
693

            
694
    // Is this optimal?
695
36
    let min = ranges.iter().map(|(a, _)| *a).max().unwrap();
696
36
    let max = ranges.iter().map(|(_, b)| *b).max().unwrap();
697

            
698
36
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
699

            
700
36
    let mut new_symbols = symbols.clone();
701
    let mut values;
702
36
    let mut new_clauses = vec![];
703

            
704
72
    while exprs_bits.len() > 1 {
705
36
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
706
36
        let mut iter = exprs_bits.into_iter();
707

            
708
72
        while let Some(a) = iter.next() {
709
36
            if let Some(b) = iter.next() {
710
36
                values = tseytin_binary_min_max(&a, &b, false, &mut new_clauses, &mut new_symbols);
711
36
                next.push(values);
712
36
            } else {
713
                next.push(a);
714
            }
715
        }
716

            
717
36
        exprs_bits = next;
718
    }
719

            
720
36
    let result = exprs_bits.pop().unwrap();
721

            
722
36
    Ok(Reduction::cnf(
723
36
        Expr::SATInt(
724
36
            Metadata::new(),
725
36
            SATIntEncoding::Log,
726
36
            Moo::new(into_matrix_expr!(result)),
727
36
            (min, max),
728
36
        ),
729
36
        new_clauses,
730
36
        new_symbols,
731
36
    ))
732
4392
}
733

            
734
/// Converts Abs of a SATInt to a SATInt
735
///
736
/// ```text
737
/// |SATInt(a)| ~> SATInt(b)
738
///
739
/// ```
740
#[register_rule(("SAT_Log", 4100))]
741
4392
fn cnf_int_abs(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
742
4392
    let Expr::Abs(_, expr) = expr else {
743
4356
        return Err(RuleNotApplicable);
744
    };
745

            
746
36
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
747
18
        return Err(RuleNotApplicable);
748
    };
749

            
750
18
    let range = (
751
18
        cmp::max(0, cmp::max(*min, -*max)),
752
18
        cmp::max(min.abs(), max.abs()),
753
18
    );
754

            
755
18
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
756
18
    let [bits] = binding.as_slice() else {
757
        return Err(RuleNotApplicable);
758
    };
759

            
760
18
    let mut new_clauses = vec![];
761
18
    let mut new_symbols = symbols.clone();
762

            
763
18
    let mut result = vec![];
764

            
765
    // How does this handle negatives edge cases: -(-8) = 8, an extra bit is needed
766

            
767
    // invert bits
768
90
    for bit in bits {
769
90
        result.push(tseytin_not(bit.clone(), &mut new_clauses, &mut new_symbols));
770
90
    }
771

            
772
18
    let bit_count = result.len();
773

            
774
    // add one
775
18
    result = tseytin_add_two_power(&result, 0, bit_count, &mut new_clauses, &mut new_symbols);
776

            
777
90
    for i in 0..bit_count {
778
90
        result[i] = tseytin_mux(
779
90
            bits[bit_count - 1].clone(),
780
90
            bits[i].clone(),
781
90
            result[i].clone(),
782
90
            &mut new_clauses,
783
90
            &mut new_symbols,
784
90
        )
785
    }
786

            
787
18
    Ok(Reduction::cnf(
788
18
        Expr::SATInt(
789
18
            Metadata::new(),
790
18
            SATIntEncoding::Log,
791
18
            Moo::new(into_matrix_expr!(result)),
792
18
            range,
793
18
        ),
794
18
        new_clauses,
795
18
        new_symbols,
796
18
    ))
797
4392
}
798

            
799
/// Converts SafeDiv of SATInts to a single SATInt
800
///
801
/// ```text
802
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
803
///
804
/// ```
805
#[register_rule(("SAT_Log", 4100))]
806
4392
fn cnf_int_safediv(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
807
    // Using "Restoring division" algorithm
808
    // https://en.wikipedia.org/wiki/Division_algorithm#Restoring_division
809
4392
    let Expr::SafeDiv(_, numer, denom) = expr else {
810
4392
        return Err(RuleNotApplicable);
811
    };
812

            
813
    let Expr::SATInt(_, _, _, (numer_min, numer_max)) = numer.as_ref() else {
814
        return Err(RuleNotApplicable);
815
    };
816

            
817
    let Expr::SATInt(_, _, _, (denom_min, denom_max)) = denom.as_ref() else {
818
        return Err(RuleNotApplicable);
819
    };
820

            
821
    let candidates = [
822
        numer_min / denom_min,
823
        numer_min / denom_max,
824
        numer_max / denom_min,
825
        numer_max / denom_max,
826
    ];
827

            
828
    let min = *candidates.iter().min().unwrap();
829
    let max = *candidates.iter().max().unwrap();
830

            
831
    let binding =
832
        validate_log_int_operands(vec![numer.as_ref().clone(), denom.as_ref().clone()], None)?;
833
    let [numer_bits, denom_bits] = binding.as_slice() else {
834
        return Err(RuleNotApplicable);
835
    };
836

            
837
    let bit_count = numer_bits.len();
838

            
839
    // TODO: Separate into division/mod function
840

            
841
    let mut new_symbols = symbols.clone();
842
    let mut new_clauses = vec![];
843
    let mut quotient = vec![false.into(); bit_count];
844

            
845
    let minus_numer = tseytin_negate(
846
        &numer_bits.clone(),
847
        bit_count,
848
        &mut new_clauses,
849
        &mut new_symbols,
850
    );
851
    let minus_denom = tseytin_negate(
852
        &denom_bits.clone(),
853
        bit_count,
854
        &mut new_clauses,
855
        &mut new_symbols,
856
    );
857

            
858
    let sign_bit = tseytin_xor(
859
        numer_bits[bit_count - 1].clone(),
860
        denom_bits[bit_count - 1].clone(),
861
        &mut new_clauses,
862
        &mut new_symbols,
863
    );
864

            
865
    let numer_bits = tseytin_select_array(
866
        numer_bits[bit_count - 1].clone(),
867
        &numer_bits.clone(),
868
        &minus_numer,
869
        &mut new_clauses,
870
        &mut new_symbols,
871
    );
872
    let denom_bits = tseytin_select_array(
873
        denom_bits[bit_count - 1].clone(),
874
        &denom_bits.clone(),
875
        &minus_denom,
876
        &mut new_clauses,
877
        &mut new_symbols,
878
    );
879

            
880
    let mut r = numer_bits;
881
    r.extend(std::iter::repeat_n(r[bit_count - 1].clone(), bit_count));
882
    let mut d = std::iter::repeat_n(false.into(), bit_count).collect_vec();
883
    d.extend(denom_bits);
884

            
885
    let minus_d = tseytin_negate(
886
        &d.clone(),
887
        2 * bit_count,
888
        &mut new_clauses,
889
        &mut new_symbols,
890
    );
891
    let mut rminusd;
892

            
893
    for i in (0..bit_count).rev() {
894
        // r << 1
895
        for j in (1..bit_count * 2).rev() {
896
            r[j] = r[j - 1].clone();
897
        }
898
        r[0] = false.into();
899

            
900
        rminusd = tseytin_int_adder(
901
            &r.clone(),
902
            &minus_d.clone(),
903
            2 * bit_count,
904
            &mut new_clauses,
905
            &mut new_symbols,
906
        );
907

            
908
        // TODO: For mod don't calculate on final iter
909
        quotient[i] = tseytin_not(
910
            // q[i] = inverse of sign bit - 1 if positive, 0 if negative
911
            rminusd[2 * bit_count - 1].clone(),
912
            &mut new_clauses,
913
            &mut new_symbols,
914
        );
915

            
916
        // TODO: For div don't calculate on final iter
917
        for j in 0..(2 * bit_count) {
918
            r[j] = tseytin_mux(
919
                quotient[i].clone(),
920
                r[j].clone(),       // use r if negative
921
                rminusd[j].clone(), // use r-d if positive
922
                &mut new_clauses,
923
                &mut new_symbols,
924
            );
925
        }
926
    }
927

            
928
    let minus_quotient = tseytin_negate(
929
        &quotient.clone(),
930
        bit_count,
931
        &mut new_clauses,
932
        &mut new_symbols,
933
    );
934

            
935
    let out = tseytin_select_array(
936
        sign_bit,
937
        &quotient,
938
        &minus_quotient,
939
        &mut new_clauses,
940
        &mut new_symbols,
941
    );
942

            
943
    Ok(Reduction::cnf(
944
        Expr::SATInt(
945
            Metadata::new(),
946
            SATIntEncoding::Log,
947
            Moo::new(into_matrix_expr!(out)),
948
            (min, max),
949
        ),
950
        new_clauses,
951
        new_symbols,
952
    ))
953
4392
}
954

            
955
/*
956
/// Converts SafeMod of SATInts to a single SATInt
957
///
958
/// ```text
959
/// SafeMod(SATInt(a), SATInt(b)) ~> SATInt(c)
960
///
961
/// ```
962
#[register_rule(("SAT_Log", 4100))]
963
fn cnf_int_safemod(expr: &Expr, _: &SymbolTable) -> ApplicationResult {}
964

            
965
/// Converts SafePow of SATInts to a single SATInt
966
///
967
/// ```text
968
/// SafePow(SATInt(a), SATInt(b)) ~> SATInt(c)
969
///
970
/// ```
971
#[register_rule(("SAT", 4100))]
972
fn cnf_int_safepow(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
973
    // use 'Exponentiation by squaring'
974
}
975
*/