1
use conjure_cp::ast::{Atom, Expression as Expr, Literal};
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::ApplicationError;
4
use conjure_cp::rule_engine::{
5
    ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
6
};
7

            
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::Moo;
10
use conjure_cp::into_matrix_expr;
11

            
12
use super::boolean::{tseytin_and, tseytin_iff, tseytin_not, tseytin_or, tseytin_xor};
13

            
14
use conjure_cp::ast::CnfClause;
15

            
16
/// Converts an integer literal to SATInt form
17
///
18
/// ```text
19
///  3
20
///  ~~>
21
///  SATInt([true;int(1..), (3, 3)])
22
///
23
/// ```
24
#[register_rule(("SAT_Direct", 9500))]
25
9
fn literal_sat_direct_int(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
26
    let value = {
27
        if let Expr::Atomic(_, Atom::Literal(Literal::Int(value))) = expr {
28
            *value
29
        } else {
30
9
            return Err(RuleNotApplicable);
31
        }
32
    };
33

            
34
    Ok(Reduction::pure(Expr::SATInt(
35
        Metadata::new(),
36
        SATIntEncoding::Direct,
37
        Moo::new(into_matrix_expr!(vec![Expr::Atomic(
38
            Metadata::new(),
39
            Atom::Literal(Literal::Bool(true)),
40
        )])),
41
        (value, value),
42
    )))
43
9
}
44

            
45
/// This function confirms that all of the input expressions are direct SATInts, and returns vectors for each input of their bits
46
/// This function also normalizes direct SATInt operands to a common value range by zero-padding.
47
pub fn validate_direct_int_operands(
48
    exprs: Vec<Expr>,
49
) -> Result<(Vec<Vec<Expr>>, i32, i32), ApplicationError> {
50
    // TODO: In the future it may be possible to optimize operations between integers with different bit sizes
51
    // Collect inner bit vectors from each SATInt
52

            
53
    // Iterate over all inputs
54
    // Check they are direct and calulate a lower and upper bound
55
    let mut global_min: i32 = i32::MAX;
56
    let mut global_max: i32 = i32::MIN;
57

            
58
    for operand in &exprs {
59
        let Expr::SATInt(_, SATIntEncoding::Direct, _, (local_min, local_max)) = operand else {
60
            return Err(RuleNotApplicable);
61
        };
62
        global_min = global_min.min(*local_min);
63
        global_max = global_max.max(*local_max);
64
    }
65

            
66
    // build out by iterating over each operand and expanding it to match the new bounds
67

            
68
    let out: Vec<Vec<Expr>> = exprs
69
        .into_iter()
70
        .map(|expr| {
71
            let Expr::SATInt(_, SATIntEncoding::Direct, inner, (local_min, local_max)) = expr
72
            else {
73
                return Err(RuleNotApplicable);
74
            };
75

            
76
            let Some(v) = inner.as_ref().clone().unwrap_list() else {
77
                return Err(RuleNotApplicable);
78
            };
79

            
80
            // calulcate how many zeroes to prepend/append
81
            let prefix_len = (local_min - global_min) as usize;
82
            let postfix_len = (global_max - local_max) as usize;
83

            
84
            let mut bits = Vec::with_capacity(v.len() + prefix_len + postfix_len);
85

            
86
            // add 0s to start
87
            bits.extend(std::iter::repeat_n(
88
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
89
                prefix_len,
90
            ));
91

            
92
            bits.extend(v);
93

            
94
            // add 0s to end
95
            bits.extend(std::iter::repeat_n(
96
                Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
97
                postfix_len,
98
            ));
99

            
100
            Ok(bits)
101
        })
102
        .collect::<Result<_, _>>()?;
103

            
104
    Ok((out, global_min, global_max))
105
}
106

            
107
/// Converts a = expression between two direct SATInts to a boolean expression in cnf
108
///
109
/// ```text
110
/// SATInt(a) = SATInt(b) ~> Bool
111
/// ```
112
/// NOTE: This rule reduces to AND_i (a[i] ≡ b[i]) and does not enforce one-hotness.
113
#[register_rule(("SAT_Direct", 9100))]
114
9
fn eq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
115
    // TODO: this could be optimized by just going over the sections of both vectors where the ranges intersect
116
    // this does require enforcing structure separately
117
9
    let Expr::Eq(_, lhs, rhs) = expr else {
118
9
        return Err(RuleNotApplicable);
119
    };
120

            
121
    let (binding, _, _) =
122
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
123
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
124
        return Err(RuleNotApplicable);
125
    };
126

            
127
    let bit_count = lhs_bits.len();
128

            
129
    let mut output = true.into();
130
    let mut new_symbols = symbols.clone();
131
    let mut new_clauses = vec![];
132
    let mut comparison;
133

            
134
    for i in 0..bit_count {
135
        comparison = tseytin_iff(
136
            lhs_bits[i].clone(),
137
            rhs_bits[i].clone(),
138
            &mut new_clauses,
139
            &mut new_symbols,
140
        );
141
        output = tseytin_and(
142
            &vec![comparison, output],
143
            &mut new_clauses,
144
            &mut new_symbols,
145
        );
146
    }
147

            
148
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
149
9
}
150

            
151
/// Converts a != expression between two direct SATInts to a boolean expression in cnf
152
///
153
/// ```text
154
/// SATInt(a) != SATInt(b) ~> Bool
155
///
156
/// ```
157
///
158
/// True iff at least one value position differs.
159
#[register_rule(("SAT_Direct", 9100))]
160
9
fn neq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
161
9
    let Expr::Neq(_, lhs, rhs) = expr else {
162
9
        return Err(RuleNotApplicable);
163
    };
164

            
165
    let (binding, _, _) =
166
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
167
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
168
        return Err(RuleNotApplicable);
169
    };
170

            
171
    let bit_count = lhs_bits.len();
172

            
173
    let mut output = false.into();
174
    let mut new_symbols = symbols.clone();
175
    let mut new_clauses = vec![];
176
    let mut comparison;
177

            
178
    for i in 0..bit_count {
179
        comparison = tseytin_xor(
180
            lhs_bits[i].clone(),
181
            rhs_bits[i].clone(),
182
            &mut new_clauses,
183
            &mut new_symbols,
184
        );
185
        output = tseytin_or(
186
            &vec![comparison, output],
187
            &mut new_clauses,
188
            &mut new_symbols,
189
        );
190
    }
191

            
192
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
193
9
}
194

            
195
/// Converts a </>/<=/>= expression between two direct SATInts to a boolean expression in cnf
196
///
197
/// ```text
198
/// SATInt(a) </>/<=/>= SATInt(b) ~> Bool
199
///
200
/// ```
201
/// Note: < and <= are rewritten by swapping operands to reuse lt logic.
202
#[register_rule(("SAT", 9100))]
203
9
fn ineq_sat_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
204
9
    let (lhs, rhs, negate) = match expr {
205
        // A < B -> sat_direct_lt(A, B)
206
        Expr::Lt(_, x, y) => (x, y, false),
207
        // A > B -> sat_direct_lt(B, A)
208
        Expr::Gt(_, x, y) => (y, x, false),
209
        // A <= B -> NOT (B < A)
210
        Expr::Leq(_, x, y) => (y, x, true),
211
        // A >= B -> NOT (A < B)
212
        Expr::Geq(_, x, y) => (x, y, true),
213
9
        _ => return Err(RuleNotApplicable),
214
    };
215

            
216
    let (binding, _, _) =
217
        validate_direct_int_operands(vec![lhs.as_ref().clone(), rhs.as_ref().clone()])?;
218
    let [lhs_bits, rhs_bits] = binding.as_slice() else {
219
        return Err(RuleNotApplicable);
220
    };
221

            
222
    let mut new_symbols = symbols.clone();
223
    let mut new_clauses = vec![];
224

            
225
    let mut output = sat_direct_lt(
226
        lhs_bits.clone(),
227
        rhs_bits.clone(),
228
        &mut new_clauses,
229
        &mut new_symbols,
230
    );
231

            
232
    if negate {
233
        output = tseytin_not(output, &mut new_clauses, &mut new_symbols);
234
    }
235

            
236
    Ok(Reduction::cnf(output, new_clauses, new_symbols))
237
9
}
238

            
239
/// Encodes a < b for one-hot direct integers using prefix OR logic.
240
fn sat_direct_lt(
241
    a: Vec<Expr>,
242
    b: Vec<Expr>,
243
    clauses: &mut Vec<CnfClause>,
244
    symbols: &mut SymbolTable,
245
) -> Expr {
246
    let mut b_or = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
247
    let mut cum_result = Expr::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false)));
248

            
249
    for (a_i, b_i) in a.iter().zip(b.iter()) {
250
        // b_or is prefix_or of b up to index i: B_i = b_0 | ... | b_i
251
        b_or = tseytin_or(&vec![b_or, b_i.clone()], clauses, symbols);
252

            
253
        // a < b if there exists i such that a=i and b > i.
254
        // b > i is equivalent to NOT(B_i) assuming one-hotness.
255
        let not_b_or = tseytin_not(b_or.clone(), clauses, symbols);
256
        let a_i_and_not_b_i = tseytin_and(&vec![a_i.clone(), not_b_or], clauses, symbols);
257

            
258
        cum_result = tseytin_or(&vec![cum_result, a_i_and_not_b_i], clauses, symbols);
259
    }
260

            
261
    cum_result
262
}
263

            
264
/// Converts a - expression for a SATInt to a new SATInt
265
///
266
/// ```text
267
/// -SATInt(a) ~> SATInt(b)
268
///
269
/// ```
270
#[register_rule(("SAT_Direct", 9100))]
271
9
fn neg_sat_direct(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
272
9
    let Expr::Neg(_, value) = expr else {
273
9
        return Err(RuleNotApplicable);
274
    };
275

            
276
    let (binding, old_min, old_max) = validate_direct_int_operands(vec![value.as_ref().clone()])?;
277
    let [val_bits] = binding.as_slice() else {
278
        return Err(RuleNotApplicable);
279
    };
280

            
281
    let new_min = -old_max;
282
    let new_max = -old_min;
283

            
284
    let mut out = val_bits.clone();
285
    out.reverse();
286

            
287
    Ok(Reduction::pure(Expr::SATInt(
288
        Metadata::new(),
289
        SATIntEncoding::Direct,
290
        Moo::new(into_matrix_expr!(out)),
291
        (new_min, new_max),
292
    )))
293
9
}