1
use conjure_cp::ast::Expression as Expr;
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
5
};
6

            
7
use conjure_cp::ast::AbstractLiteral::Matrix;
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use itertools::Itertools;
13

            
14
use super::boolean::{
15
    tseytin_and, tseytin_iff, tseytin_imply, tseytin_mux, tseytin_not, tseytin_or, tseytin_xor,
16
};
17
use super::integer_repr::{bit_magnitude, match_bits_length, validate_log_int_operands};
18

            
19
use conjure_cp::ast::CnfClause;
20

            
21
use std::cmp;
22

            
23
/// Converts an inequality expression between two SATInts to a boolean expression in cnf.
24
///
25
/// ```text
26
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
27
///
28
/// ```
29
#[register_rule(("SAT_Log", 4100))]
30
3024
fn cnf_int_ineq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
31
3024
    let (lhs, rhs, strict) = match expr {
32
33
        Expr::Lt(_, x, y) => (y, x, true),
33
33
        Expr::Gt(_, x, y) => (x, y, true),
34
444
        Expr::Leq(_, x, y) => (y, x, false),
35
384
        Expr::Geq(_, x, y) => (x, y, false),
36
2130
        _ => return Err(RuleNotApplicable),
37
    };
38

            
39
804
    let binding =
40
894
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
41
804
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
42
        return Err(RuleNotApplicable);
43
    };
44

            
45
804
    let mut new_symbols = symbols.clone();
46
804
    let mut new_clauses = vec![];
47

            
48
804
    let output = inequality_boolean(
49
804
        lhs_bits.clone(),
50
804
        rhs_bits.clone(),
51
804
        strict,
52
804
        &mut new_clauses,
53
804
        &mut new_symbols,
54
    );
55
804
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
56
3024
}
57

            
58
/// Converts a = expression between two SATInts to a boolean expression in cnf
59
///
60
/// ```text
61
/// SATInt(a) = SATInt(b) ~> Bool
62
///
63
/// ```
64
#[register_rule(("SAT_Log", 9100))]
65
172560
fn cnf_int_eq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
66
172560
    let Expr::Eq(_, lhs, rhs) = expr else {
67
171792
        return Err(RuleNotApplicable);
68
    };
69

            
70
156
    let binding =
71
768
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
72
156
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
73
        return Err(RuleNotApplicable);
74
    };
75

            
76
156
    let bit_count = lhs_bits.len();
77

            
78
156
    let mut output = true.into();
79
156
    let mut new_symbols = symbols.clone();
80
156
    let mut new_clauses = vec![];
81
    let mut comparison;
82

            
83
912
    for i in 0..bit_count {
84
912
        comparison = tseytin_iff(
85
912
            lhs_bits[i].clone(),
86
912
            rhs_bits[i].clone(),
87
912
            &mut new_clauses,
88
912
            &mut new_symbols,
89
912
        );
90
912
        output = tseytin_and(
91
912
            &vec![comparison, output],
92
912
            &mut new_clauses,
93
912
            &mut new_symbols,
94
912
        );
95
912
    }
96

            
97
156
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
98
172560
}
99

            
100
/// Converts a != expression between two SATInts to a boolean expression in cnf
101
///
102
/// ```text
103
/// SATInt(a) != SATInt(b) ~> Bool
104
///
105
/// ```
106
#[register_rule(("SAT_Log", 4100))]
107
3024
fn cnf_int_neq(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
108
3024
    let Expr::Neq(_, lhs, rhs) = expr else {
109
2934
        return Err(RuleNotApplicable);
110
    };
111

            
112
90
    let binding =
113
90
        validate_log_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()], None)?;
114
90
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
115
        return Err(RuleNotApplicable);
116
    };
117

            
118
90
    let bit_count = lhs_bits.len();
119

            
120
90
    let mut output = false.into();
121
90
    let mut new_symbols = symbols.clone();
122
90
    let mut new_clauses = vec![];
123
    let mut comparison;
124

            
125
288
    for i in 0..bit_count {
126
288
        comparison = tseytin_xor(
127
288
            lhs_bits[i].clone(),
128
288
            rhs_bits[i].clone(),
129
288
            &mut new_clauses,
130
288
            &mut new_symbols,
131
288
        );
132
288
        output = tseytin_or(
133
288
            &vec![comparison, output],
134
288
            &mut new_clauses,
135
288
            &mut new_symbols,
136
288
        );
137
288
    }
138

            
139
90
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
140
3024
}
141

            
142
// Creates a boolean expression for > or >=
143
// a > b or a >= b
144
// This can also be used for < and <= by reversing the order of the inputs
145
// Returns result, new symbol table, new clauses
146
852
fn inequality_boolean(
147
852
    a: Vec<Expr>,
148
852
    b: Vec<Expr>,
149
852
    strict: bool,
150
852
    clauses: &mut Vec<CnfClause>,
151
852
    symbols: &mut SymbolTable,
152
852
) -> Expr {
153
    let mut notb;
154
    let mut output;
155

            
156
852
    if strict {
157
102
        notb = tseytin_not(b[0].clone(), clauses, symbols);
158
102
        output = tseytin_and(&vec![a[0].clone(), notb], clauses, symbols);
159
750
    } else {
160
750
        output = tseytin_imply(b[0].clone(), a[0].clone(), clauses, symbols);
161
750
    }
162

            
163
    //TODO: There may be room for simplification, and constant optimization
164

            
165
852
    let bit_count = a.len();
166

            
167
    let mut lhs;
168
    let mut rhs;
169
    let mut iff;
170
1800
    for n in 1..(bit_count - 1) {
171
1800
        notb = tseytin_not(b[n].clone(), clauses, symbols);
172
1800
        lhs = tseytin_and(&vec![a[n].clone(), notb.clone()], clauses, symbols);
173
1800
        iff = tseytin_iff(a[n].clone(), b[n].clone(), clauses, symbols);
174
1800
        rhs = tseytin_and(&vec![iff.clone(), output.clone()], clauses, symbols);
175
1800
        output = tseytin_or(&vec![lhs.clone(), rhs.clone()], clauses, symbols);
176
1800
    }
177

            
178
    // final bool is the sign bit and should be handled inversely
179
852
    let nota = tseytin_not(a[bit_count - 1].clone(), clauses, symbols);
180
852
    lhs = tseytin_and(&vec![nota, b[bit_count - 1].clone()], clauses, symbols);
181
852
    iff = tseytin_iff(
182
852
        a[bit_count - 1].clone(),
183
852
        b[bit_count - 1].clone(),
184
852
        clauses,
185
852
        symbols,
186
    );
187
852
    rhs = tseytin_and(&vec![iff, output.clone()], clauses, symbols);
188
852
    output = tseytin_or(&vec![lhs, rhs], clauses, symbols);
189

            
190
852
    output
191
852
}
192

            
193
/// Converts sum of SATInts to a single SATInt
194
///
195
/// ```text
196
/// Sum(SATInt(a), SATInt(b), ...) ~> SATInt(c)
197
///
198
/// ```
199
#[register_rule(("SAT_Log", 4100))]
200
3024
fn cnf_int_sum(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
201
3024
    let Expr::Sum(_, exprs) = expr else {
202
2964
        return Err(RuleNotApplicable);
203
    };
204

            
205
60
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
206
        return Err(RuleNotApplicable);
207
    };
208

            
209
60
    let ranges: Result<Vec<_>, _> = exprs_list
210
60
        .iter()
211
120
        .map(|e| match e {
212
90
            Expr::SATInt(_, _, _, x) => Ok(x),
213
30
            _ => Err(RuleNotApplicable),
214
120
        })
215
60
        .collect();
216

            
217
60
    let ranges = ranges?;
218

            
219
30
    let min = ranges.iter().map(|(a, _)| *a).sum();
220
30
    let max = ranges.iter().map(|(_, a)| *a).sum();
221

            
222
30
    let output_size = cmp::max(bit_magnitude(min), bit_magnitude(max));
223

            
224
    // Check operands are valid log ints
225
30
    let mut exprs_bits =
226
30
        validate_log_int_operands(exprs_list.clone(), Some(output_size.try_into().unwrap()))?;
227

            
228
30
    let mut new_symbols = symbols.clone();
229
    let mut values;
230
30
    let mut new_clauses = vec![];
231

            
232
66
    while exprs_bits.len() > 1 {
233
36
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
234
36
        let mut iter = exprs_bits.into_iter();
235

            
236
78
        while let Some(a) = iter.next() {
237
42
            if let Some(b) = iter.next() {
238
36
                values = tseytin_int_adder(&a, &b, output_size, &mut new_clauses, &mut new_symbols);
239
36
                next.push(values);
240
36
            } else {
241
6
                next.push(a);
242
6
            }
243
        }
244

            
245
36
        exprs_bits = next;
246
    }
247

            
248
30
    let result = exprs_bits.pop().unwrap();
249

            
250
30
    Ok(Reduction::cnf(
251
30
        Expr::SATInt(
252
30
            Metadata::new(),
253
30
            SATIntEncoding::Log,
254
30
            Moo::new(into_matrix_expr!(result)),
255
30
            (min, max),
256
30
        ),
257
30
        new_clauses,
258
30
        new_symbols,
259
30
    ))
260
3024
}
261

            
262
/// Returns result, new symbol table, new clauses
263
/// This function expects bits to match the lengths of x and y
264
66
fn tseytin_int_adder(
265
66
    x: &[Expr],
266
66
    y: &[Expr],
267
66
    bits: usize,
268
66
    clauses: &mut Vec<CnfClause>,
269
66
    symbols: &mut SymbolTable,
270
66
) -> Vec<Expr> {
271
    //TODO: Optimizing for constants
272
66
    let (mut result, mut carry) = tseytin_half_adder(x[0].clone(), y[0].clone(), clauses, symbols);
273

            
274
66
    let mut output = vec![result];
275
510
    for i in 1..bits {
276
510
        (result, carry) =
277
510
            tseytin_full_adder(x[i].clone(), y[i].clone(), carry.clone(), clauses, symbols);
278
510
        output.push(result);
279
510
    }
280

            
281
66
    output
282
66
}
283

            
284
/// This function adds two booleans and a carry boolean using the full-adder logic circuit, it is intended for use in a binary adder.
285
510
fn tseytin_full_adder(
286
510
    a: Expr,
287
510
    b: Expr,
288
510
    carry: Expr,
289
510
    clauses: &mut Vec<CnfClause>,
290
510
    symbols: &mut SymbolTable,
291
510
) -> (Expr, Expr) {
292
510
    let axorb = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
293
510
    let result = tseytin_xor(axorb.clone(), carry.clone(), clauses, symbols);
294
510
    let aandb = tseytin_and(&vec![a, b], clauses, symbols);
295
510
    let carryandaxorb = tseytin_and(&vec![carry, axorb], clauses, symbols);
296
510
    let carryout = tseytin_or(&vec![aandb, carryandaxorb], clauses, symbols);
297

            
298
510
    (result, carryout)
299
510
}
300

            
301
/// This function adds two booleans using the half-adder logic circuit, it is intended for use in a binary adder.
302
66
fn tseytin_half_adder(
303
66
    a: Expr,
304
66
    b: Expr,
305
66
    clauses: &mut Vec<CnfClause>,
306
66
    symbols: &mut SymbolTable,
307
66
) -> (Expr, Expr) {
308
66
    let result = tseytin_xor(a.clone(), b.clone(), clauses, symbols);
309
66
    let carry = tseytin_and(&vec![a, b], clauses, symbols);
310

            
311
66
    (result, carry)
312
66
}
313

            
314
/// this function is for specifically adding a power of two constant to a cnf int.
315
54
fn tseytin_add_two_power(
316
54
    expr: &[Expr],
317
54
    exponent: usize,
318
54
    bits: usize,
319
54
    clauses: &mut Vec<CnfClause>,
320
54
    symbols: &mut SymbolTable,
321
54
) -> Vec<Expr> {
322
54
    let mut result = vec![];
323
54
    let mut product = expr[exponent].clone();
324

            
325
54
    for item in expr.iter().take(exponent) {
326
        result.push(item.clone());
327
    }
328

            
329
54
    result.push(tseytin_not(expr[exponent].clone(), clauses, symbols));
330

            
331
132
    for item in expr.iter().take(bits).skip(exponent + 1) {
332
132
        result.push(tseytin_xor(product.clone(), item.clone(), clauses, symbols));
333
132
        product = tseytin_and(&vec![product, item.clone()], clauses, symbols);
334
132
    }
335

            
336
54
    result
337
54
}
338

            
339
/// This function multiplies two binary values using the shift-add multiplication algorithm.
340
6
fn cnf_shift_add_multiply(
341
6
    x: &[Expr],
342
6
    y: &[Expr],
343
6
    bits: usize,
344
6
    clauses: &mut Vec<CnfClause>,
345
6
    symbols: &mut SymbolTable,
346
6
) -> Vec<Expr> {
347
6
    let mut x = x.to_owned();
348
6
    let mut y = y.to_owned();
349

            
350
    //TODO Optimizing for constants
351
    //TODO Optimize addition for i left shifted values - skip first i bits
352

            
353
    // extend sign bits of operands to 2*`bits`
354
6
    x.extend(std::iter::repeat_n(x[bits - 1].clone(), bits));
355
6
    y.extend(std::iter::repeat_n(y[bits - 1].clone(), bits));
356

            
357
6
    let mut s: Vec<Expr> = vec![];
358
    let mut x_0andy_i;
359

            
360
72
    for bit in &y {
361
72
        x_0andy_i = tseytin_and(&vec![x[0].clone(), bit.clone()], clauses, symbols);
362
72
        s.push(x_0andy_i);
363
72
    }
364

            
365
    let mut sum;
366
    let mut if_true;
367
    let mut not_x_n;
368
    let mut if_false;
369

            
370
30
    for item in x.iter().take(bits).skip(1) {
371
        // y << 1
372
330
        for i in (1..bits * 2).rev() {
373
330
            y[i] = y[i - 1].clone();
374
330
        }
375
30
        y[0] = false.into();
376

            
377
        // TODO switch to multiplexer
378
        // TODO Add negatives support once MUX is added
379
30
        sum = tseytin_int_adder(&s, &y, bits * 2, clauses, symbols);
380
30
        not_x_n = tseytin_not(item.clone(), clauses, symbols);
381

            
382
360
        for i in 0..(bits * 2) {
383
360
            if_true = tseytin_and(&vec![item.clone(), sum[i].clone()], clauses, symbols);
384
360
            if_false = tseytin_and(&vec![not_x_n.clone(), s[i].clone()], clauses, symbols);
385
360
            s[i] = tseytin_or(&vec![if_true.clone(), if_false.clone()], clauses, symbols);
386
360
        }
387
    }
388

            
389
6
    s
390
6
}
391

            
392
/// Computes the `(min, max)` range of the product of several integers, each given by an
/// inclusive `(min, max)` range.
///
/// At each step, the extremes of the running product are found among the four products of
/// the running bounds with the next operand's bounds. An empty list yields `(1, 1)`, the
/// empty product.
///
/// E.g.
/// a : [2, 5], b : [-1, 2], c : [-10, -6], d : [0, 3]
/// a * b * c *d : [-300, 150]
fn product_of_ranges(ranges: Vec<&(i32, i32)>) -> (i32, i32) {
    let mut remaining = ranges.into_iter();

    let Some(&first) = remaining.next() else {
        return (1, 1); // product of zero numbers = 1
    };

    remaining.fold(first, |(lo, hi), &(a, b)| {
        let corners = [lo * a, lo * b, hi * a, hi * b];
        let new_lo = corners.iter().copied().fold(i32::MAX, i32::min);
        let new_hi = corners.iter().copied().fold(i32::MIN, i32::max);
        (new_lo, new_hi)
    })
}
411

            
412
/// Converts product of SATInts to a single SATInt
413
///
414
/// ```text
415
/// Product(SATInt(a), SATInt(b), ...) ~> SATInt(c)
416
///
417
/// ```
418
#[register_rule(("SAT_Log", 9000))]
419
123003
fn cnf_int_product(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
420
123003
    let Expr::Product(_, exprs) = expr else {
421
122997
        return Err(RuleNotApplicable);
422
    };
423

            
424
6
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
425
        return Err(RuleNotApplicable);
426
    };
427

            
428
6
    let ranges: Result<Vec<_>, _> = exprs_list
429
6
        .iter()
430
12
        .map(|e| match e {
431
12
            Expr::SATInt(_, _, _, x) => Ok(x),
432
            _ => Err(RuleNotApplicable),
433
12
        })
434
6
        .collect();
435

            
436
6
    let ranges = ranges?; // propagate error if any
437

            
438
6
    let (min, max) = product_of_ranges(ranges.clone());
439

            
440
6
    let exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
441

            
442
6
    let mut new_symbols = symbols.clone();
443
6
    let mut new_clauses = vec![];
444

            
445
6
    let (result, _) = exprs_bits
446
6
        .iter()
447
6
        .cloned()
448
6
        .zip(ranges.into_iter().copied())
449
6
        .reduce(|lhs, rhs| {
450
            // Make both bit vectors the same length
451
6
            let (lhs_bits, rhs_bits) = match_bits_length(lhs.0.clone(), rhs.0.clone());
452

            
453
            // Multiply operands
454
6
            let mut values = cnf_shift_add_multiply(
455
6
                &lhs_bits,
456
6
                &rhs_bits,
457
6
                lhs_bits.len(),
458
6
                &mut new_clauses,
459
6
                &mut new_symbols,
460
            );
461

            
462
            // Determine new range of result
463
6
            let (mut cum_min, mut cum_max) = lhs.1;
464
6
            let candidates = [
465
6
                cum_min * rhs.1.0,
466
6
                cum_min * rhs.1.1,
467
6
                cum_max * rhs.1.0,
468
6
                cum_max * rhs.1.1,
469
6
            ];
470
6
            cum_min = *candidates.iter().min().unwrap();
471
6
            cum_max = *candidates.iter().max().unwrap();
472

            
473
6
            let new_bit_count = bit_magnitude(cum_min).max(bit_magnitude(cum_max));
474
6
            values.truncate(new_bit_count);
475

            
476
6
            (values, (cum_min, cum_max))
477
6
        })
478
6
        .unwrap();
479

            
480
6
    Ok(Reduction::cnf(
481
6
        Expr::SATInt(
482
6
            Metadata::new(),
483
6
            SATIntEncoding::Log,
484
6
            Moo::new(into_matrix_expr!(result)),
485
6
            (min, max),
486
6
        ),
487
6
        new_clauses,
488
6
        new_symbols,
489
6
    ))
490
123003
}
491

            
492
/// Converts negation of a SATInt to a SATInt
493
///
494
/// ```text
495
/// -SATInt(a) ~> SATInt(b)
496
///
497
/// ```
498
#[register_rule(("SAT_Log", 4100))]
499
3024
fn cnf_int_neg(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
500
3024
    let Expr::Neg(_, expr) = expr else {
501
2982
        return Err(RuleNotApplicable);
502
    };
503

            
504
42
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
505
        return Err(RuleNotApplicable);
506
    };
507

            
508
42
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
509
42
    let [bits] = binding.as_slice() else {
510
        return Err(RuleNotApplicable);
511
    };
512

            
513
42
    let mut new_clauses = vec![];
514
42
    let mut new_symbols = symbols.clone();
515

            
516
42
    let result = tseytin_negate(bits, bits.len(), &mut new_clauses, &mut new_symbols);
517

            
518
42
    Ok(Reduction::cnf(
519
42
        Expr::SATInt(
520
42
            Metadata::new(),
521
42
            SATIntEncoding::Log,
522
42
            Moo::new(into_matrix_expr!(result)),
523
42
            (-max, -min),
524
42
        ),
525
42
        new_clauses,
526
42
        new_symbols,
527
42
    ))
528
3024
}
529

            
530
42
fn tseytin_negate(
531
42
    expr: &Vec<Expr>,
532
42
    bits: usize,
533
42
    clauses: &mut Vec<CnfClause>,
534
42
    symbols: &mut SymbolTable,
535
42
) -> Vec<Expr> {
536
42
    let mut result = vec![];
537
    // invert bits
538
126
    for bit in expr {
539
126
        result.push(tseytin_not(bit.clone(), clauses, symbols));
540
126
    }
541

            
542
    // add one
543
42
    result = tseytin_add_two_power(&result, 0, bits, clauses, symbols);
544

            
545
42
    result
546
42
}
547

            
548
/// Converts min of SATInts to a single SATInt
549
///
550
/// ```text
551
/// Min(SATInt(a), SATInt(b), ...) ~> SATInt(c)
552
///
553
/// ```
554
#[register_rule(("SAT_Log", 4100))]
555
3024
fn cnf_int_min(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
556
3024
    let Expr::Min(_, exprs) = expr else {
557
3000
        return Err(RuleNotApplicable);
558
    };
559

            
560
24
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
561
        return Err(RuleNotApplicable);
562
    };
563

            
564
24
    let ranges: Result<Vec<_>, _> = exprs_list
565
24
        .iter()
566
48
        .map(|e| match e {
567
48
            Expr::SATInt(_, _, _, x) => Ok(x),
568
            _ => Err(RuleNotApplicable),
569
48
        })
570
24
        .collect();
571

            
572
24
    let ranges = ranges?; // propagate error if any
573

            
574
    // Is this optimal?
575
24
    let min = ranges.iter().map(|(a, _)| *a).min().unwrap();
576
24
    let max = ranges.iter().map(|(_, b)| *b).min().unwrap();
577

            
578
24
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
579

            
580
24
    let mut new_symbols = symbols.clone();
581
    let mut values;
582
24
    let mut new_clauses = vec![];
583

            
584
48
    while exprs_bits.len() > 1 {
585
24
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
586
24
        let mut iter = exprs_bits.into_iter();
587

            
588
48
        while let Some(a) = iter.next() {
589
24
            if let Some(b) = iter.next() {
590
24
                values = tseytin_binary_min_max(&a, &b, true, &mut new_clauses, &mut new_symbols);
591
24
                next.push(values);
592
24
            } else {
593
                next.push(a);
594
            }
595
        }
596

            
597
24
        exprs_bits = next;
598
    }
599

            
600
24
    let result = exprs_bits.pop().unwrap();
601

            
602
24
    Ok(Reduction::cnf(
603
24
        Expr::SATInt(
604
24
            Metadata::new(),
605
24
            SATIntEncoding::Log,
606
24
            Moo::new(into_matrix_expr!(result)),
607
24
            (min, max),
608
24
        ),
609
24
        new_clauses,
610
24
        new_symbols,
611
24
    ))
612
3024
}
613

            
614
/// General function for getting the min or max of two log integers.
615
48
fn tseytin_binary_min_max(
616
48
    x: &[Expr],
617
48
    y: &[Expr],
618
48
    min: bool,
619
48
    clauses: &mut Vec<CnfClause>,
620
48
    symbols: &mut SymbolTable,
621
48
) -> Vec<Expr> {
622
48
    let mask = if min {
623
        // mask is 1 if x > y
624
24
        inequality_boolean(x.to_owned(), y.to_owned(), true, clauses, symbols)
625
    } else {
626
        // flip the args if getting maximum x < y -> 1
627
24
        inequality_boolean(y.to_owned(), x.to_owned(), true, clauses, symbols)
628
    };
629

            
630
48
    tseytin_select_array(mask, x, y, clauses, symbols)
631
48
}
632

            
633
// Selects between two boolean vectors depending on a condition (both vectors must be the same length)
634
/// cond ? b : a
635
///
636
/// cond = 1 => b
637
/// cond = 0 => a
638
48
fn tseytin_select_array(
639
48
    cond: Expr,
640
48
    a: &[Expr],
641
48
    b: &[Expr],
642
48
    clauses: &mut Vec<CnfClause>,
643
48
    symbols: &mut SymbolTable,
644
48
) -> Vec<Expr> {
645
48
    assert_eq!(
646
48
        a.len(),
647
48
        b.len(),
648
        "Input vectors 'a' and 'b' must have the same length"
649
    );
650

            
651
48
    let mut out = vec![];
652

            
653
48
    let bit_count = a.len();
654

            
655
180
    for i in 0..bit_count {
656
180
        out.push(tseytin_mux(
657
180
            cond.clone(),
658
180
            a[i].clone(),
659
180
            b[i].clone(),
660
180
            clauses,
661
180
            symbols,
662
180
        ));
663
180
    }
664

            
665
48
    out
666
48
}
667

            
668
/// Converts max of SATInts to a single SATInt
669
///
670
/// ```text
671
/// Max(SATInt(a), SATInt(b), ...) ~> SATInt(c)
672
///
673
/// ```
674
#[register_rule(("SAT_Log", 4100))]
675
3024
fn cnf_int_max(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
676
3024
    let Expr::Max(_, exprs) = expr else {
677
3000
        return Err(RuleNotApplicable);
678
    };
679

            
680
24
    let Expr::AbstractLiteral(_, Matrix(exprs_list, _)) = exprs.as_ref() else {
681
        return Err(RuleNotApplicable);
682
    };
683

            
684
24
    let ranges: Result<Vec<_>, _> = exprs_list
685
24
        .iter()
686
48
        .map(|e| match e {
687
48
            Expr::SATInt(_, _, _, x) => Ok(x),
688
            _ => Err(RuleNotApplicable),
689
48
        })
690
24
        .collect();
691

            
692
24
    let ranges = ranges?; // propagate error if any
693

            
694
    // Is this optimal?
695
24
    let min = ranges.iter().map(|(a, _)| *a).max().unwrap();
696
24
    let max = ranges.iter().map(|(_, b)| *b).max().unwrap();
697

            
698
24
    let mut exprs_bits = validate_log_int_operands(exprs_list.clone(), None)?;
699

            
700
24
    let mut new_symbols = symbols.clone();
701
    let mut values;
702
24
    let mut new_clauses = vec![];
703

            
704
48
    while exprs_bits.len() > 1 {
705
24
        let mut next = Vec::with_capacity(exprs_bits.len().div_ceil(2));
706
24
        let mut iter = exprs_bits.into_iter();
707

            
708
48
        while let Some(a) = iter.next() {
709
24
            if let Some(b) = iter.next() {
710
24
                values = tseytin_binary_min_max(&a, &b, false, &mut new_clauses, &mut new_symbols);
711
24
                next.push(values);
712
24
            } else {
713
                next.push(a);
714
            }
715
        }
716

            
717
24
        exprs_bits = next;
718
    }
719

            
720
24
    let result = exprs_bits.pop().unwrap();
721

            
722
24
    Ok(Reduction::cnf(
723
24
        Expr::SATInt(
724
24
            Metadata::new(),
725
24
            SATIntEncoding::Log,
726
24
            Moo::new(into_matrix_expr!(result)),
727
24
            (min, max),
728
24
        ),
729
24
        new_clauses,
730
24
        new_symbols,
731
24
    ))
732
3024
}
733

            
734
/// Converts Abs of a SATInt to a SATInt
735
///
736
/// ```text
737
/// |SATInt(a)| ~> SATInt(b)
738
///
739
/// ```
740
#[register_rule(("SAT_Log", 4100))]
741
3024
fn cnf_int_abs(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
742
3024
    let Expr::Abs(_, expr) = expr else {
743
3000
        return Err(RuleNotApplicable);
744
    };
745

            
746
24
    let Expr::SATInt(_, _, _, (min, max)) = expr.as_ref() else {
747
12
        return Err(RuleNotApplicable);
748
    };
749

            
750
12
    let range = (
751
12
        cmp::max(0, cmp::max(*min, -*max)),
752
12
        cmp::max(min.abs(), max.abs()),
753
12
    );
754

            
755
12
    let binding = validate_log_int_operands(vec![expr.as_ref().clone()], None)?;
756
12
    let [bits] = binding.as_slice() else {
757
        return Err(RuleNotApplicable);
758
    };
759

            
760
12
    let mut new_clauses = vec![];
761
12
    let mut new_symbols = symbols.clone();
762

            
763
12
    let mut result = vec![];
764

            
765
    // How does this handle negatives edge cases: -(-8) = 8, an extra bit is needed
766

            
767
    // invert bits
768
60
    for bit in bits {
769
60
        result.push(tseytin_not(bit.clone(), &mut new_clauses, &mut new_symbols));
770
60
    }
771

            
772
12
    let bit_count = result.len();
773

            
774
    // add one
775
12
    result = tseytin_add_two_power(&result, 0, bit_count, &mut new_clauses, &mut new_symbols);
776

            
777
60
    for i in 0..bit_count {
778
60
        result[i] = tseytin_mux(
779
60
            bits[bit_count - 1].clone(),
780
60
            bits[i].clone(),
781
60
            result[i].clone(),
782
60
            &mut new_clauses,
783
60
            &mut new_symbols,
784
60
        )
785
    }
786

            
787
12
    Ok(Reduction::cnf(
788
12
        Expr::SATInt(
789
12
            Metadata::new(),
790
12
            SATIntEncoding::Log,
791
12
            Moo::new(into_matrix_expr!(result)),
792
12
            range,
793
12
        ),
794
12
        new_clauses,
795
12
        new_symbols,
796
12
    ))
797
3024
}
798

            
799
/// Converts SafeDiv of SATInts to a single SATInt
800
///
801
/// ```text
802
/// SafeDiv(SATInt(a), SATInt(b)) ~> SATInt(c)
803
///
804
/// ```
805
#[register_rule(("SAT_Log", 4100))]
806
3024
fn cnf_int_safediv(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
807
    // Using "Restoring division" algorithm
808
    // https://en.wikipedia.org/wiki/Division_algorithm#Restoring_division
809
3024
    let Expr::SafeDiv(_, numer, denom) = expr else {
810
3024
        return Err(RuleNotApplicable);
811
    };
812

            
813
    let Expr::SATInt(_, _, _, (numer_min, numer_max)) = numer.as_ref() else {
814
        return Err(RuleNotApplicable);
815
    };
816

            
817
    let Expr::SATInt(_, _, _, (denom_min, denom_max)) = denom.as_ref() else {
818
        return Err(RuleNotApplicable);
819
    };
820

            
821
    let candidates = [
822
        numer_min / denom_min,
823
        numer_min / denom_max,
824
        numer_max / denom_min,
825
        numer_max / denom_max,
826
    ];
827

            
828
    let min = *candidates.iter().min().unwrap();
829
    let max = *candidates.iter().max().unwrap();
830

            
831
    let binding =
832
        validate_log_int_operands(vec![numer.as_ref().clone(), denom.as_ref().clone()], None)?;
833
    let [numer_bits, denom_bits] = binding.as_slice() else {
834
        return Err(RuleNotApplicable);
835
    };
836

            
837
    let bit_count = numer_bits.len();
838

            
839
    // TODO: Separate into division/mod function
840

            
841
    let mut new_symbols = symbols.clone();
842
    let mut new_clauses = vec![];
843
    let mut quotient = vec![false.into(); bit_count];
844

            
845
    let minus_numer = tseytin_negate(
846
        &numer_bits.clone(),
847
        bit_count,
848
        &mut new_clauses,
849
        &mut new_symbols,
850
    );
851
    let minus_denom = tseytin_negate(
852
        &denom_bits.clone(),
853
        bit_count,
854
        &mut new_clauses,
855
        &mut new_symbols,
856
    );
857

            
858
    let sign_bit = tseytin_xor(
859
        numer_bits[bit_count - 1].clone(),
860
        denom_bits[bit_count - 1].clone(),
861
        &mut new_clauses,
862
        &mut new_symbols,
863
    );
864

            
865
    let numer_bits = tseytin_select_array(
866
        numer_bits[bit_count - 1].clone(),
867
        &numer_bits.clone(),
868
        &minus_numer,
869
        &mut new_clauses,
870
        &mut new_symbols,
871
    );
872
    let denom_bits = tseytin_select_array(
873
        denom_bits[bit_count - 1].clone(),
874
        &denom_bits.clone(),
875
        &minus_denom,
876
        &mut new_clauses,
877
        &mut new_symbols,
878
    );
879

            
880
    let mut r = numer_bits;
881
    r.extend(std::iter::repeat_n(r[bit_count - 1].clone(), bit_count));
882
    let mut d = std::iter::repeat_n(false.into(), bit_count).collect_vec();
883
    d.extend(denom_bits);
884

            
885
    let minus_d = tseytin_negate(
886
        &d.clone(),
887
        2 * bit_count,
888
        &mut new_clauses,
889
        &mut new_symbols,
890
    );
891
    let mut rminusd;
892

            
893
    for i in (0..bit_count).rev() {
894
        // r << 1
895
        for j in (1..bit_count * 2).rev() {
896
            r[j] = r[j - 1].clone();
897
        }
898
        r[0] = false.into();
899

            
900
        rminusd = tseytin_int_adder(
901
            &r.clone(),
902
            &minus_d.clone(),
903
            2 * bit_count,
904
            &mut new_clauses,
905
            &mut new_symbols,
906
        );
907

            
908
        // TODO: For mod don't calculate on final iter
909
        quotient[i] = tseytin_not(
910
            // q[i] = inverse of sign bit - 1 if positive, 0 if negative
911
            rminusd[2 * bit_count - 1].clone(),
912
            &mut new_clauses,
913
            &mut new_symbols,
914
        );
915

            
916
        // TODO: For div don't calculate on final iter
917
        for j in 0..(2 * bit_count) {
918
            r[j] = tseytin_mux(
919
                quotient[i].clone(),
920
                r[j].clone(),       // use r if negative
921
                rminusd[j].clone(), // use r-d if positive
922
                &mut new_clauses,
923
                &mut new_symbols,
924
            );
925
        }
926
    }
927

            
928
    let minus_quotient = tseytin_negate(
929
        &quotient.clone(),
930
        bit_count,
931
        &mut new_clauses,
932
        &mut new_symbols,
933
    );
934

            
935
    let out = tseytin_select_array(
936
        sign_bit,
937
        &quotient,
938
        &minus_quotient,
939
        &mut new_clauses,
940
        &mut new_symbols,
941
    );
942

            
943
    Ok(Reduction::cnf(
944
        Expr::SATInt(
945
            Metadata::new(),
946
            SATIntEncoding::Log,
947
            Moo::new(into_matrix_expr!(out)),
948
            (min, max),
949
        ),
950
        new_clauses,
951
        new_symbols,
952
    ))
953
3024
}
954

            
955
/*
956
/// Converts SafeMod of SATInts to a single SATInt
957
///
958
/// ```text
959
/// SafeMod(SATInt(a), SATInt(b)) ~> SATInt(c)
960
///
961
/// ```
962
#[register_rule(("SAT_Log", 4100))]
963
fn cnf_int_safemod(expr: &Expr, _: &SymbolTable) -> ApplicationResult {}
964

            
965
/// Converts SafePow of SATInts to a single SATInt
966
///
967
/// ```text
968
/// SafePow(SATInt(a), SATInt(b)) ~> SATInt(c)
969
///
970
/// ```
971
#[register_rule(("SAT", 4100))]
972
fn cnf_int_safepow(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
973
    // use 'Exponentiation by squaring'
974
}
975
*/