1
use conjure_cp::ast::{Expression as Expr, GroundDomain};
2
use conjure_cp::ast::{SATIntEncoding, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError, ApplicationError::RuleNotApplicable, ApplicationResult, Reduction,
5
    register_rule,
6
};
7

            
8
use conjure_cp::ast::Metadata;
9
use conjure_cp::ast::{Atom, Literal, Moo, Range};
10
use conjure_cp::into_matrix_expr;
11

            
12
use conjure_cp::{bug, essence_expr};
13

            
14
/// This function takes a target expression and a vector of ranges and creates an expression representing the ranges with the target expression as the subject
15
///
16
/// E.g. x : int(4), int(10..20), int(30..) ~~> Or(x=4, 10<=x<=20, x>=30)
17
3348
fn int_domain_to_expr(subject: Expr, ranges: &Vec<Range<i32>>) -> Expr {
18
3348
    let mut output = vec![];
19

            
20
3348
    let value = Moo::new(subject);
21

            
22
3588
    for range in ranges {
23
3588
        match range {
24
336
            Range::Single(x) => output.push(essence_expr!(&value = &x)),
25
3252
            Range::Bounded(x, y) => output.push(essence_expr!("&value >= &x /\\ &value <= &y")),
26
            _ => bug!("Unbounded domains not supported for SAT"),
27
        }
28
    }
29

            
30
3348
    Expr::Or(Metadata::new(), Moo::new(into_matrix_expr!(output)))
31
3348
}
32

            
33
/// This function confirms that all of the input expressions are log SATInts, and returns vectors for each input of their bits
34
/// This function also extends all vectors such that they have the same lengths
35
/// The vector lengths is either `n` for bit_count = Some(n), otherwise the length of the longest operand
36
4023
pub fn validate_log_int_operands(
37
4023
    exprs: Vec<Expr>,
38
4023
    bit_count: Option<u32>,
39
4023
) -> Result<Vec<Vec<Expr>>, ApplicationError> {
40
    // TODO: In the future it may be possible to optimize operations between integers with different bit sizes
41
    // Collect inner bit vectors from each SATInt
42

            
43
    // TODO: this file should be encoding agnostic so this needs to moved to the log_int_ops.rs file, do this once the direct ints have been merged to main though
44
4023
    let mut out: Vec<Vec<Expr>> = exprs
45
4023
        .into_iter()
46
7269
        .map(|expr| {
47
6228
            let Expr::SATInt(_, SATIntEncoding::Log, inner, _) = expr else {
48
1041
                return Err(RuleNotApplicable);
49
            };
50
6228
            let Some(bits) = inner.as_ref().clone().unwrap_list() else {
51
                return Err(RuleNotApplicable);
52
            };
53
6228
            Ok(bits)
54
7269
        })
55
4023
        .collect::<Result<_, _>>()?;
56

            
57
    // Determine target length
58
2982
    let max_len = bit_count
59
2982
        .map(|b| b as usize)
60
5538
        .unwrap_or_else(|| out.iter().map(|v| v.len()).max().unwrap_or(0));
61

            
62
    // Extend or crop each vector
63
5910
    for v in &mut out {
64
5910
        if v.len() < max_len {
65
            // pad with the last element
66
1554
            if let Some(last) = v.last().cloned() {
67
1554
                v.resize(max_len, last);
68
1554
            }
69
4356
        } else if v.len() > max_len {
70
            // crop extra elements
71
            v.truncate(max_len);
72
4356
        }
73
    }
74

            
75
2982
    Ok(out)
76
4023
}
77

            
78
/// Converts an integer decision variable to SATInt form, creating a new representation of boolean variables if
79
/// one does not yet exist
80
///
81
/// ```text
82
///  x
83
///  ~~>
84
///  SATInt([x#00, x#01, ...])
85
///
86
///  new variables:
87
///  find x#00: bool
88
///  find x#01: bool
89
///  ...
90
///
91
/// ```
92
#[register_rule("SAT_Direct", 9500, [Atomic])]
93
1374903
fn integer_decision_representation_direct(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
94
    // thing we are representing must be a reference
95
837324
    let Expr::Atomic(_, Atom::Reference(name)) = expr else {
96
589824
        return Err(RuleNotApplicable);
97
    };
98

            
99
    // thing we are representing must be a variable
100
    // symbols
101
    //     .lookup(name)
102
    //     .ok_or(RuleNotApplicable)?
103
    //     .as_find()
104
    //     .ok_or(RuleNotApplicable)?;
105

            
106
    // thing we are representing must be an integer
107
785079
    let dom = name.resolved_domain().ok_or(RuleNotApplicable)?;
108
784791
    let GroundDomain::Int(ranges) = dom.as_ref() else {
109
781245
        return Err(RuleNotApplicable);
110
    };
111

            
112
3546
    let (min, max) = ranges
113
3546
        .iter()
114
3666
        .fold((i32::MAX, i32::MIN), |(min_a, max_b), range| {
115
3666
            (
116
3666
                min_a.min(*range.low().unwrap()),
117
3666
                max_b.max(*range.high().unwrap()),
118
3666
            )
119
3666
        });
120

            
121
3546
    let mut symbols = symbols.clone();
122

            
123
3546
    let new_name = &name.name().to_owned();
124

            
125
3546
    let repr_exists = symbols
126
3546
        .get_representation(new_name, &["sat_direct_int"])
127
3546
        .is_some();
128

            
129
3546
    let representation = symbols
130
3546
        .get_or_add_representation(new_name, &["sat_direct_int"])
131
3546
        .ok_or(RuleNotApplicable)?;
132

            
133
3186
    let bits: Vec<Expr> = representation[0]
134
3186
        .clone()
135
3186
        .expression_down(&symbols)?
136
3186
        .into_values()
137
3186
        .collect();
138

            
139
3186
    let cnf_int = Expr::SATInt(
140
3186
        Metadata::new(),
141
3186
        SATIntEncoding::Direct,
142
3186
        Moo::new(into_matrix_expr!(bits.clone())),
143
3186
        (min, max),
144
3186
    );
145

            
146
3186
    if !repr_exists {
147
        // Domain constraint: the integer must take one of its valid values
148
1962
        let constraints = vec![int_domain_to_expr(cnf_int.clone(), ranges)];
149

            
150
        // At-Most-One constraints: only one bit can be true.
151
1962
        let mut clauses = vec![];
152
15300
        for i in 0..bits.len() {
153
335211
            for j in i + 1..bits.len() {
154
335211
                clauses.push(conjure_cp::ast::CnfClause::new(vec![
155
335211
                    Expr::Not(Metadata::new(), Moo::new(bits[i].clone())),
156
335211
                    Expr::Not(Metadata::new(), Moo::new(bits[j].clone())),
157
335211
                ]));
158
335211
            }
159
        }
160

            
161
1962
        let mut reduction = Reduction::cnf(cnf_int, clauses, symbols);
162
1962
        reduction.new_top = constraints;
163
1962
        Ok(reduction)
164
    } else {
165
1224
        Ok(Reduction::pure(cnf_int))
166
    }
167
1374903
}
168

            
169
#[register_rule("SAT_Order", 9500, [Atomic])]
170
233553
fn integer_decision_representation_order(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
171
    // thing we are representing must be a reference
172
122610
    let Expr::Atomic(_, Atom::Reference(name)) = expr else {
173
121215
        return Err(RuleNotApplicable);
174
    };
175

            
176
    // thing we are representing must be an integer
177
112338
    let dom = name.resolved_domain().ok_or(RuleNotApplicable)?;
178
112338
    let GroundDomain::Int(ranges) = dom.as_ref() else {
179
111342
        return Err(RuleNotApplicable);
180
    };
181

            
182
996
    let (min, max) = ranges
183
996
        .iter()
184
996
        .fold((i32::MAX, i32::MIN), |(min_a, max_b), range| {
185
996
            (
186
996
                min_a.min(*range.low().unwrap()),
187
996
                max_b.max(*range.high().unwrap()),
188
996
            )
189
996
        });
190

            
191
996
    let mut symbols = symbols.clone();
192

            
193
996
    let new_name = &name.name().to_owned();
194

            
195
996
    let repr_exists = symbols
196
996
        .get_representation(new_name, &["sat_order_int"])
197
996
        .is_some();
198

            
199
996
    let representation = symbols
200
996
        .get_or_add_representation(new_name, &["sat_order_int"])
201
996
        .ok_or(RuleNotApplicable)?;
202

            
203
996
    let bits: Vec<Expr> = representation[0]
204
996
        .clone()
205
996
        .expression_down(&symbols)?
206
996
        .into_values()
207
996
        .collect();
208

            
209
996
    let cnf_int = Expr::SATInt(
210
996
        Metadata::new(),
211
996
        SATIntEncoding::Order,
212
996
        Moo::new(into_matrix_expr!(bits.clone())),
213
996
        (min, max),
214
996
    );
215

            
216
996
    if !repr_exists {
217
        // Domain constraint: the integer must take one of its valid values
218
462
        let constraints = vec![int_domain_to_expr(cnf_int.clone(), ranges)];
219

            
220
        // Ordering constraints: b_i -> b_{i-1} which is !b_i or b_{i-1}
221
462
        let mut clauses = vec![];
222
1824
        for i in 1..bits.len() {
223
1824
            clauses.push(conjure_cp::ast::CnfClause::new(vec![
224
1824
                Expr::Not(Metadata::new(), Moo::new(bits[i].clone())),
225
1824
                bits[i - 1].clone(),
226
1824
            ]));
227
1824
        }
228

            
229
462
        if !bits.is_empty() {
230
            // Domain constraint: a >= min, which is b_min.
231
462
            clauses.push(conjure_cp::ast::CnfClause::new(vec![bits[0].clone()]));
232
462
        }
233

            
234
462
        let mut reduction = Reduction::cnf(cnf_int, clauses, symbols);
235
462
        reduction.new_top = constraints;
236
462
        Ok(reduction)
237
    } else {
238
534
        Ok(Reduction::pure(cnf_int))
239
    }
240
233553
}
241

            
242
/// Converts an integer decision variable to SATInt form (Log encoding)
243
#[register_rule("SAT_Log", 9500, [Atomic])]
244
606174
fn integer_decision_representation_log(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
245
    // thing we are representing must be a reference
246
315207
    let Expr::Atomic(_, Atom::Reference(name)) = expr else {
247
359673
        return Err(RuleNotApplicable);
248
    };
249

            
250
    // thing we are representing must be a variable
251
    // symbols
252
    //     .lookup(name)
253
    //     .ok_or(RuleNotApplicable)?
254
    //     .as_find()
255
    //     .ok_or(RuleNotApplicable)?;
256

            
257
    // thing we are representing must be an integer
258
246501
    let dom = name.resolved_domain().ok_or(RuleNotApplicable)?;
259
246501
    let GroundDomain::Int(ranges) = dom.as_ref() else {
260
245277
        return Err(RuleNotApplicable);
261
    };
262

            
263
1224
    let (min, max) = ranges
264
1224
        .iter()
265
1344
        .fold((i32::MAX, i32::MIN), |(min_a, max_b), range| {
266
1344
            (
267
1344
                min_a.min(*range.low().unwrap()),
268
1344
                max_b.max(*range.high().unwrap()),
269
1344
            )
270
1344
        });
271

            
272
1224
    let mut symbols = symbols.clone();
273

            
274
1224
    let new_name = &name.name().to_owned();
275

            
276
1224
    let repr_exists = symbols
277
1224
        .get_representation(new_name, &["sat_log_int"])
278
1224
        .is_some();
279

            
280
1224
    let representation = symbols
281
1224
        .get_or_add_representation(new_name, &["sat_log_int"])
282
1224
        .ok_or(RuleNotApplicable)?;
283

            
284
1224
    let bits = representation[0]
285
1224
        .clone()
286
1224
        .expression_down(&symbols)?
287
1224
        .into_values()
288
1224
        .collect();
289

            
290
1224
    let cnf_int = Expr::SATInt(
291
1224
        Metadata::new(),
292
1224
        SATIntEncoding::Log,
293
1224
        Moo::new(into_matrix_expr!(bits)),
294
1224
        (min, max),
295
1224
    );
296

            
297
1224
    if !repr_exists {
298
        // add domain ranges as constraints if this is the first time the representation is added
299
924
        Ok(Reduction::new(
300
924
            cnf_int.clone(),
301
924
            vec![int_domain_to_expr(cnf_int, ranges)], // contains domain rules
302
924
            symbols,
303
924
        ))
304
    } else {
305
300
        Ok(Reduction::pure(cnf_int))
306
    }
307
606174
}
308

            
309
/// Converts an integer literal to SATInt form
310
///
311
/// ```text
312
///  3
313
///  ~~>
314
///  SATInt([true,true,false,false,false,false,false,false;int(1..)])
315
///
316
/// ```
317
#[register_rule("SAT_Log", 9500, [Atomic])]
318
606174
fn literal_cnf_int(expr: &Expr, _: &SymbolTable) -> ApplicationResult {
319
2400
    let value = {
320
68706
        if let Expr::Atomic(_, Atom::Literal(Literal::Int(v))) = expr {
321
2400
            *v
322
        } else {
323
603774
            return Err(RuleNotApplicable);
324
        }
325
    };
326
    //TODO: Adding constant optimization to all int operations should hopefully make this rule redundant
327

            
328
2400
    let mut binary_encoding = vec![];
329

            
330
2400
    let bit_count = bit_magnitude(value);
331

            
332
2400
    let mut value_mut = value as u32;
333

            
334
7998
    for _ in 0..bit_count {
335
7998
        binary_encoding.push(Expr::Atomic(
336
7998
            Metadata::new(),
337
7998
            Atom::Literal(Literal::Bool((value_mut & 1) != 0)),
338
7998
        ));
339
7998
        value_mut >>= 1;
340
7998
    }
341

            
342
2400
    Ok(Reduction::pure(Expr::SATInt(
343
2400
        Metadata::new(),
344
2400
        SATIntEncoding::Log,
345
2400
        Moo::new(into_matrix_expr!(binary_encoding)),
346
2400
        (value, value),
347
2400
    )))
348
606174
}
349

            
350
/// Returns the minimum number of bits needed to represent `x` in two's complement.
///
/// This is the bit length of `x` (or of `!x` when `x` is negative) plus one sign bit,
/// e.g. `0 -> 1`, `3 -> 3`, `-1 -> 1`, `-4 -> 3`, `i32::MIN -> 32`.
pub fn bit_magnitude(x: i32) -> usize {
    // For negative `x`, `!x` is non-negative and has exactly the significant bits that
    // must be stored below the sign bit.
    let significant = if x < 0 { !x } else { x };
    let bits = 33 - significant.leading_zeros();
    usize::try_from(bits).expect("bit count is at most 32 and always fits in usize")
}
360

            
361
/// Extends the shorter of two vectors by repeating its last element until both have the
/// same length; vectors that already have equal lengths are returned unchanged.
///
/// For least-significant-bit-first two's-complement bit vectors, repeating the last
/// element is sign extension.
///
/// # Panics
///
/// Panics if the shorter vector is empty (and the other is not), since there is no last
/// element to repeat.
pub fn match_bits_length<T: Clone>(mut a: Vec<T>, mut b: Vec<T>) -> (Vec<T>, Vec<T>) {
    let target = a.len().max(b.len());

    for v in [&mut a, &mut b] {
        if v.len() < target {
            let last = v
                .last()
                .cloned()
                .expect("cannot extend an empty bit vector: it has no last element to repeat");
            v.resize(target, last);
        }
    }

    (a, b)
}