1
use conjure_cp::ast::{Atom, Expression as Expr, Literal};
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
6
};
7

            
8
use crate::sat::boolean::{tseytin_and, tseytin_iff, tseytin_not, tseytin_or};
9
use conjure_cp::ast::Metadata;
10
use conjure_cp::ast::Moo;
11
use conjure_cp::into_matrix_expr;
12

            
13
/// This function confirms that all of the input expressions are order SATInts, and returns vectors for each input of their bits
14
/// This function also normalizes order SATInt operands to a common value range.
15
2508
pub fn validate_order_int_operands(
16
2508
    exprs: Vec<Expr>,
17
2508
) -> Result<(Vec<Vec<Expr>>, i32, i32), ApplicationError> {
18
    // Iterate over all inputs
19
    // Check they are order and calulate a lower and upper bound
20
2508
    let mut global_min: i32 = i32::MAX;
21
2508
    let mut global_max: i32 = i32::MIN;
22

            
23
4578
    for operand in &exprs {
24
3204
        let Expr::SATInt(_, SATIntEncoding::Order, _, (local_min, local_max)) = operand else {
25
1374
            return Err(RuleNotApplicable);
26
        };
27
3204
        global_min = global_min.min(*local_min);
28
3204
        global_max = global_max.max(*local_max);
29
    }
30

            
31
    // build out by iterating over each operand and expanding it to match the new bounds
32
1134
    let out: Vec<Vec<Expr>> = exprs
33
1134
        .into_iter()
34
2256
        .map(|expr| {
35
2256
            let Expr::SATInt(_, SATIntEncoding::Order, inner, (local_min, local_max)) = expr else {
36
                return Err(RuleNotApplicable);
37
            };
38

            
39
2256
            let Some(v) = inner.as_ref().clone().unwrap_list() else {
40
                return Err(RuleNotApplicable);
41
            };
42

            
43
            // calulcate how many trues/falses to prepend/append
44
2256
            let prefix_len = (local_min - global_min) as usize;
45
2256
            let postfix_len = (global_max - local_max) as usize;
46

            
47
2256
            let mut bits = Vec::with_capacity(v.len() + prefix_len + postfix_len);
48

            
49
            // add `true`s to start
50
2256
            bits.extend(std::iter::repeat_n(
51
2256
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
52
2256
                prefix_len,
53
            ));
54

            
55
2256
            bits.extend(v);
56

            
57
            // add `false`s to end
58
2256
            bits.extend(std::iter::repeat_n(
59
2256
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
60
2256
                postfix_len,
61
            ));
62

            
63
2256
            Ok(bits)
64
2256
        })
65
1134
        .collect::<Result<_, _>>()?;
66

            
67
1134
    Ok((out, global_min, global_max))
68
2508
}
69

            
70
/// Encodes a < b for order integers.
71
///
72
/// `x < y` iff `exists i . (NOT x_i AND y_i)`
73
915
fn sat_order_lt(
74
915
    a_bits: Vec<Expr>,
75
915
    b_bits: Vec<Expr>,
76
915
    clauses: &mut Vec<conjure_cp::ast::CnfClause>,
77
915
    symbols: &mut SymbolTable,
78
915
) -> Expr {
79
915
    let mut result = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
80

            
81
5847
    for (a_i, b_i) in a_bits.iter().zip(b_bits.iter()) {
82
        // (NOT a_i AND b_i)
83
5847
        let not_a_i = tseytin_not(a_i.clone(), clauses, symbols);
84
5847
        let current_term = tseytin_and(&vec![not_a_i, b_i.clone()], clauses, symbols);
85
5847

            
86
        // accumulate (NOT a_i AND b_i) into OR term
87
5847
        result = tseytin_or(&vec![result, current_term], clauses, symbols);
88
5847
    }
89
915
    result
90
915
}
91

            
92
/// Converts an integer literal to SATInt form
93
///
94
/// ```text
95
///  3
96
///  ~~>
97
///  SATInt([true;int(1..), (3, 3)])
98
///
99
/// ```
100
#[register_rule(("SAT_Order", 9500))]
101
144750
fn literal_sat_order_int(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
102
759
    let value = {
103
6870
        if let Expr::Atomic(_, Atom::Literal(Literal::Int(value))) = expr {
104
759
            *value
105
        } else {
106
143991
            return Err(RuleNotApplicable);
107
        }
108
    };
109

            
110
759
    Ok(Reduction::pure(Expr::SATInt(
111
759
        Metadata::new(),
112
759
        SATIntEncoding::Order,
113
759
        Moo::new(into_matrix_expr!(vec![Expr::Atomic(
114
759
            Metadata::new(),
115
759
            Atom::Literal(Literal::Bool(true)),
116
759
        )])),
117
759
        (value, value),
118
759
    )))
119
144750
}
120

            
121
/// Converts a = expression between two order SATInts to a boolean expression in cnf
122
///
123
/// ```text
124
/// SATInt(a) = SATInt(b) ~> Bool
125
/// ```
126
#[register_rule(("SAT_Order", 9100))]
127
46554
fn eq_sat_order(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
128
46554
    let Expr::Eq(_, lhs, rhs) = expr else {
129
46326
        return Err(RuleNotApplicable);
130
    };
131

            
132
207
    let (binding, _, _) =
133
228
        validate_order_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
134
207
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
135
        return Err(RuleNotApplicable);
136
    };
137

            
138
207
    let bit_count = lhs_bits.len();
139

            
140
207
    let mut output = true.into();
141
207
    let mut new_symbols = symbols.clone();
142
207
    let mut new_clauses = vec![];
143
    let mut comparison;
144

            
145
894
    for i in 0..bit_count {
146
894
        comparison = tseytin_iff(
147
894
            lhs_bits[i].clone(),
148
894
            rhs_bits[i].clone(),
149
894
            &mut new_clauses,
150
894
            &mut new_symbols,
151
894
        );
152
894
        output = tseytin_and(
153
894
            &vec![comparison, output],
154
894
            &mut new_clauses,
155
894
            &mut new_symbols,
156
894
        );
157
894
    }
158

            
159
207
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
160
46554
}
161

            
162
/// Converts a </>/<=/>= expression between two order SATInts to a boolean expression in cnf
163
///
164
/// ```text
165
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
166
///
167
/// ```
168
/// Note: < and <= are rewritten by swapping operands to reuse lt logic.
169
#[register_rule(("SAT_Order", 9100))]
170
46554
fn ineq_sat_order(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
171
46554
    let (lhs, rhs, negate) = match expr {
172
        // A < B -> sat_order_lt(A, B)
173
27
        Expr::Lt(_, x, y) => (x, y, false),
174
        // A > B -> sat_order_lt(B, A)
175
        Expr::Gt(_, x, y) => (y, x, false),
176
        // A <= B -> NOT (B < A)
177
1407
        Expr::Leq(_, x, y) => (y, x, true),
178
        // A >= B -> NOT (A < B)
179
831
        Expr::Geq(_, x, y) => (x, y, true),
180
44289
        _ => return Err(RuleNotApplicable),
181
    };
182

            
183
915
    let (binding, _, _) =
184
2265
        validate_order_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
185
915
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
186
        return Err(RuleNotApplicable);
187
    };
188

            
189
915
    let mut new_symbols = symbols.clone();
190
915
    let mut new_clauses = vec![];
191

            
192
915
    let mut output = sat_order_lt(
193
915
        lhs_bits.clone(),
194
915
        rhs_bits.clone(),
195
915
        &mut new_clauses,
196
915
        &mut new_symbols,
197
    );
198

            
199
915
    if negate {
200
888
        output = tseytin_not(output, &mut new_clauses, &mut new_symbols);
201
888
    }
202

            
203
915
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
204
46554
}