1
use super::{FromConjureModel, SolverError};
2
use crate::Solver;
3
use std::collections::HashMap;
4
use thiserror::Error;
5

            
6
use crate::ast::{
7
    Domain as ConjureDomain, Expression as ConjureExpression, Model as ConjureModel,
8
    Name as ConjureName,
9
};
10

            
11
/**
 * Solver identity attached to every `SolverError::NotSupported` produced by this adapter
 */
const SOLVER: Solver = Solver::KissSAT;
12

            
13
struct CNF {
14
    pub clauses: Vec<Vec<i32>>,
15
    variables: HashMap<ConjureName, i32>,
16
    next_ind: i32,
17
}
18

            
19
/**
20
 * Error type for CNF adapter
21
 */
22
#[derive(Error, Debug)]
23
pub enum CNFError {
24
    #[error("Variable with name `{0}` not found")]
25
    VariableNameNotFound(ConjureName),
26

            
27
    #[error("Clause with index `{0}` not found")]
28
    ClauseIndexNotFound(i32),
29

            
30
    #[error("Unexpected Expression `{0}` inside Not(). Only Not(Reference) allowed!")]
31
    UnexpectedExpressionInsideNot(ConjureExpression),
32

            
33
    #[error(
34
        "Unexpected Expression `{0}` found. Only Reference, Not(Reference) and Or(...) allowed!"
35
    )]
36
    UnexpectedExpression(ConjureExpression),
37

            
38
    #[error("Unexpected nested And: {0}")]
39
    NestedAnd(ConjureExpression),
40
}
41

            
42
impl CNF {
43
9
    pub fn new() -> CNF {
44
9
        CNF {
45
9
            clauses: Vec::new(),
46
9
            variables: HashMap::new(),
47
9
            next_ind: 1,
48
9
        }
49
9
    }
50

            
51
    /**
52
     * Get all the Conjure variables in the CNF
53
     */
54
    pub fn get_variables(&self) -> Vec<&ConjureName> {
55
        let mut ans: Vec<&ConjureName> = Vec::new();
56

            
57
        for key in self.variables.keys() {
58
            ans.push(key);
59
        }
60

            
61
        ans
62
    }
63

            
64
    /**
65
     * Get the index of a Conjure variable
66
     */
67
67
    pub fn get_index(&self, var: &ConjureName) -> Option<i32> {
68
67
        return self.variables.get(var).copied();
69
67
    }
70

            
71
    /**
72
     * Get the Conjure variable from its index
73
     */
74
27
    pub fn get_name(&self, ind: i32) -> Option<&ConjureName> {
75
41
        for key in self.variables.keys() {
76
41
            let idx = self.get_index(key)?;
77
41
            if idx == ind {
78
27
                return Some(key);
79
14
            }
80
        }
81

            
82
        None
83
27
    }
84

            
85
    /**
86
     * Add a new Conjure variable to the CNF
87
     */
88
16
    pub fn add_variable(&mut self, var: &ConjureName) {
89
16
        self.variables.insert(var.clone(), self.next_ind);
90
16
        self.next_ind += 1;
91
16
    }
92

            
93
    /**
94
     * Check if a Conjure variable or index is present in the CNF
95
     */
96
13
    pub fn has_variable<T: HasVariable>(&self, value: T) -> bool {
97
13
        value.has_variable(self)
98
13
    }
99

            
100
    /**
101
     * Add a new clause to the CNF. Must be a vector of indices in CNF format
102
     */
103
9
    pub fn add_clause(&mut self, vec: &Vec<i32>) -> Result<(), CNFError> {
104
22
        for idx in vec {
105
13
            if !self.has_variable(idx.abs()) {
106
                return Err(CNFError::ClauseIndexNotFound(*idx));
107
13
            }
108
        }
109
9
        self.clauses.push(vec.clone());
110
9
        Ok(())
111
9
    }
112

            
113
    /**
114
     * Add a new Conjure expression to the CNF. Must be a logical expression in CNF form
115
     */
116
10
    pub fn add_expression(&mut self, expr: &ConjureExpression) -> Result<(), CNFError> {
117
10
        for row in self.handle_expression(expr)? {
118
9
            self.add_clause(&row)?;
119
        }
120
9
        Ok(())
121
10
    }
122

            
123
    /**
124
     * Convert the CNF to a Conjure expression
125
     */
126
6
    pub fn as_expression(&self) -> Result<ConjureExpression, CNFError> {
127
6
        let mut expr_clauses: Vec<ConjureExpression> = Vec::new();
128

            
129
14
        for clause in &self.clauses {
130
8
            expr_clauses.push(self.clause_to_expression(clause)?);
131
        }
132

            
133
6
        Ok(ConjureExpression::And(expr_clauses))
134
6
    }
135

            
136
    /**
137
     * Convert a single clause to a Conjure expression
138
     */
139
8
    fn clause_to_expression(&self, clause: &Vec<i32>) -> Result<ConjureExpression, CNFError> {
140
8
        let mut ans: Vec<ConjureExpression> = Vec::new();
141

            
142
20
        for idx in clause {
143
12
            match self.get_name(idx.abs()) {
144
                None => return Err(CNFError::ClauseIndexNotFound(*idx)),
145
12
                Some(name) => {
146
12
                    if *idx > 0 {
147
10
                        ans.push(ConjureExpression::Reference(name.clone()))
148
                    } else {
149
2
                        let expression: ConjureExpression =
150
2
                            ConjureExpression::Reference(name.clone());
151
2
                        ans.push(ConjureExpression::Not(Box::from(expression)))
152
                    }
153
                }
154
            }
155
        }
156

            
157
8
        Ok(ConjureExpression::Or(ans))
158
8
    }
159

            
160
    /**
161
     * Get the index for a Conjure Reference or return an error
162
     * @see get_index
163
     * @see ConjureExpression::Reference
164
     */
165
13
    fn get_reference_index(&self, name: &ConjureName) -> Result<i32, CNFError> {
166
13
        match self.get_index(name) {
167
            None => Err(CNFError::VariableNameNotFound(name.clone())),
168
13
            Some(ind) => Ok(ind),
169
        }
170
13
    }
171

            
172
    /**
173
     * Convert the contents of a single Reference to a row of the CNF format
174
     * @see get_reference_index
175
     * @see ConjureExpression::Reference
176
     */
177
11
    fn handle_reference(&self, name: &ConjureName) -> Result<Vec<i32>, CNFError> {
178
11
        Ok(vec![self.get_reference_index(name)?])
179
11
    }
180

            
181
    /**
182
     * Convert the contents of a single Not() to CNF
183
     */
184
2
    fn handle_not(&self, expr_box: &Box<ConjureExpression>) -> Result<Vec<i32>, CNFError> {
185
2
        let expr = expr_box.as_ref();
186
2
        match expr {
187
            // Expression inside the Not()
188
2
            ConjureExpression::Reference(name) => Ok(vec![-self.get_reference_index(name)?]),
189
            _ => Err(CNFError::UnexpectedExpressionInsideNot(expr.clone())),
190
        }
191
2
    }
192

            
193
    /**
194
     * Convert the contents of a single Or() to a row of the CNF format
195
     */
196
4
    fn handle_or(&self, expressions: &Vec<ConjureExpression>) -> Result<Vec<i32>, CNFError> {
197
4
        let mut ans: Vec<i32> = Vec::new();
198

            
199
12
        for expr in expressions {
200
8
            let ret = self.handle_flat_expression(expr)?;
201
17
            for ind in ret {
202
9
                ans.push(ind);
203
9
            }
204
        }
205

            
206
4
        Ok(ans)
207
4
    }
208

            
209
    /**
210
     * Convert a single Reference, Not or Or into a clause of the CNF format
211
     */
212
18
    fn handle_flat_expression(&self, expression: &ConjureExpression) -> Result<Vec<i32>, CNFError> {
213
18
        match expression {
214
11
            ConjureExpression::Reference(name) => self.handle_reference(name),
215
2
            ConjureExpression::Not(var_box) => self.handle_not(var_box),
216
4
            ConjureExpression::Or(expressions) => self.handle_or(expressions),
217
1
            _ => Err(CNFError::UnexpectedExpression(expression.clone())),
218
        }
219
18
    }
220

            
221
    /**
222
     * Convert a single And() into a vector of clauses in the CNF format
223
     */
224
    fn handle_and(&self, expressions: &Vec<ConjureExpression>) -> Result<Vec<Vec<i32>>, CNFError> {
225
        let mut ans: Vec<Vec<i32>> = Vec::new();
226

            
227
        for expression in expressions {
228
            match expression {
229
                ConjureExpression::And(_expressions) => {
230
                    return Err(CNFError::NestedAnd(expression.clone()));
231
                }
232
                _ => {
233
                    ans.push(self.handle_flat_expression(expression)?);
234
                }
235
            }
236
        }
237

            
238
        Ok(ans)
239
    }
240

            
241
    /**
242
     * Convert a single Conjure expression into a vector of clauses of the CNF format
243
     */
244
10
    fn handle_expression(&self, expression: &ConjureExpression) -> Result<Vec<Vec<i32>>, CNFError> {
245
10
        match expression {
246
            ConjureExpression::And(expressions) => self.handle_and(expressions),
247
10
            _ => Ok(vec![self.handle_flat_expression(expression)?]),
248
        }
249
10
    }
250
}
251

            
252
/**
253
 * Helper trait for checking if a variable is present in the CNF polymorphically (i32 or ConjureName)
254
 */
255
trait HasVariable {
256
    fn has_variable(self, cnf: &CNF) -> bool;
257
}
258

            
259
impl HasVariable for i32 {
260
13
    fn has_variable(self, cnf: &CNF) -> bool {
261
13
        return cnf.get_name(self).is_some();
262
13
    }
263
}
264

            
265
impl HasVariable for &ConjureName {
266
    fn has_variable(self, cnf: &CNF) -> bool {
267
        cnf.get_index(self).is_some()
268
    }
269
}
270

            
271
/**
272
* Expects Model to be in the Conjunctive Normal Form:
273
* - All variables must be boolean
274
* - Expressions must be Reference, Not(Reference), or Or(Reference1, Not(Reference2), ...)
275
* - The top level And() may contain nested Or()s. Any other nested expressions are not allowed.
276
*/
277
impl FromConjureModel for CNF {
278
    /**
279
     * Convert a Conjure model to a CNF
280
     */
281
9
    fn from_conjure(conjure_model: ConjureModel) -> Result<Self, SolverError> {
282
9
        let mut ans: CNF = CNF::new();
283

            
284
17
        for var in conjure_model.variables.keys() {
285
            // Check that domain has the correct type
286
17
            let decision_var = conjure_model.variables.get(var).unwrap();
287
17
            if decision_var.domain != ConjureDomain::BoolDomain {
288
1
                return Err(SolverError::NotSupported(
289
1
                    SOLVER,
290
1
                    format!("variable {:?} is not BoolDomain", decision_var),
291
1
                ));
292
16
            }
293
16

            
294
16
            ans.add_variable(var);
295
        }
296

            
297
10
        for expr in conjure_model.get_constraints_vec() {
298
10
            match ans.add_expression(&expr) {
299
9
                Ok(_) => {}
300
1
                Err(error) => {
301
1
                    let message = format!("{:?}", error);
302
1
                    return Err(SolverError::NotSupported(SOLVER, message));
303
                }
304
            }
305
        }
306

            
307
7
        Ok(ans)
308
9
    }
309
}
310

            
311
#[cfg(test)]
mod tests {
    use crate::ast::Domain::{BoolDomain, IntDomain};
    use crate::ast::Expression::{And, Not, Or, Reference};
    use crate::ast::{DecisionVariable, Model};
    use crate::ast::{Expression, Name};
    use crate::solvers::kissat::CNFError;
    use crate::solvers::kissat::CNF;
    use crate::solvers::{FromConjureModel, SolverError};
    use crate::utils::{assert_eq_any_order, if_ok};

    #[test]
    fn test_single_var() {
        // x -> [[1]]

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_constraint(Reference(x.clone()));

        let res: Result<CNF, SolverError> = CNF::from_conjure(model);
        assert!(res.is_ok());

        let cnf = res.unwrap();

        assert_eq!(cnf.get_index(&x), Some(1));
        assert!(cnf.get_name(1).is_some());
        assert_eq!(cnf.get_name(1).unwrap(), &x);

        assert_eq!(cnf.clauses, vec![vec![1]]);
    }

    #[test]
    fn test_single_not() {
        // Not(x) -> [[-1]]

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_constraint(Not(Box::from(Reference(x.clone()))));

        let cnf: CNF = CNF::from_conjure(model).unwrap();
        assert_eq!(cnf.get_index(&x), Some(1));
        assert_eq!(cnf.clauses, vec![vec![-1]]);

        assert_eq!(
            if_ok(cnf.as_expression()),
            And(vec![Or(vec![Not(Box::from(Reference(x.clone())))])])
        )
    }

    #[test]
    fn test_single_or() {
        // Or(x, y) -> [[1, 2]]

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(y.clone(), DecisionVariable { domain: BoolDomain });

        model.add_constraint(Or(vec![Reference(x.clone()), Reference(y.clone())]));

        let cnf: CNF = CNF::from_conjure(model).unwrap();

        let xi = cnf.get_index(&x).unwrap();
        let yi = cnf.get_index(&y).unwrap();
        assert_eq_any_order(&cnf.clauses, &vec![vec![xi, yi]]);

        assert_eq!(
            if_ok(cnf.as_expression()),
            And(vec![Or(vec![Reference(x.clone()), Reference(y.clone())])])
        )
    }

    #[test]
    fn test_or_not() {
        // Or(x, Not(y)) -> [[1, -2]]

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(y.clone(), DecisionVariable { domain: BoolDomain });

        model.add_constraint(Or(vec![
            Reference(x.clone()),
            Not(Box::from(Reference(y.clone()))),
        ]));

        let cnf: CNF = CNF::from_conjure(model).unwrap();

        let xi = cnf.get_index(&x).unwrap();
        let yi = cnf.get_index(&y).unwrap();
        assert_eq_any_order(&cnf.clauses, &vec![vec![xi, -yi]]);

        assert_eq!(
            if_ok(cnf.as_expression()),
            And(vec![Or(vec![
                Reference(x.clone()),
                Not(Box::from(Reference(y.clone())))
            ])])
        )
    }

    #[test]
    fn test_multiple() {
        // [x, y] - equivalent to And(x, y) -> [[1], [2]]

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(y.clone(), DecisionVariable { domain: BoolDomain });

        model.add_constraint(Reference(x.clone()));
        model.add_constraint(Reference(y.clone()));

        let cnf: CNF = CNF::from_conjure(model).unwrap();

        let xi = cnf.get_index(&x).unwrap();
        let yi = cnf.get_index(&y).unwrap();
        assert_eq_any_order(&cnf.clauses, &vec![vec![xi], vec![yi]]);

        assert_eq!(
            if_ok(cnf.as_expression()),
            And(vec![
                Or(vec![Reference(x.clone())]),
                Or(vec![Reference(y.clone())])
            ])
        )
    }

    #[test]
    fn test_and() {
        // And(x, y) -> [[1], [2]]

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(y.clone(), DecisionVariable { domain: BoolDomain });

        model.add_constraint(And(vec![Reference(x.clone()), Reference(y.clone())]));

        let cnf: CNF = CNF::from_conjure(model).unwrap();

        let xi = cnf.get_index(&x).unwrap();
        let yi = cnf.get_index(&y).unwrap();
        assert_eq_any_order(&cnf.clauses, &vec![vec![xi], vec![yi]]);

        assert_eq!(
            if_ok(cnf.as_expression()),
            And(vec![
                Or(vec![Reference(x.clone())]),
                Or(vec![Reference(y.clone())])
            ])
        )
    }

    #[test]
    fn test_nested_ors() {
        // Or(x, Or(y, z)) -> [[1, 2, 3]]

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));
        let z: Name = Name::UserName(String::from('z'));

        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(y.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(z.clone(), DecisionVariable { domain: BoolDomain });

        model.add_constraint(Or(vec![
            Reference(x.clone()),
            Or(vec![Reference(y.clone()), Reference(z.clone())]),
        ]));

        let cnf: CNF = CNF::from_conjure(model).unwrap();

        let xi = cnf.get_index(&x).unwrap();
        let yi = cnf.get_index(&y).unwrap();
        let zi = cnf.get_index(&z).unwrap();
        assert_eq_any_order(&cnf.clauses, &vec![vec![xi, yi, zi]]);

        assert_eq!(
            if_ok(cnf.as_expression()),
            And(vec![Or(vec![
                Reference(x.clone()),
                Reference(y.clone()),
                Reference(z.clone())
            ])])
        )
    }

    #[test]
    fn test_int() {
        // y is an IntDomain - only booleans should be allowed

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(
            y.clone(),
            DecisionVariable {
                domain: IntDomain(vec![]),
            },
        );

        model.add_constraint(Reference(x.clone()));
        model.add_constraint(Reference(y.clone()));

        let cnf: Result<CNF, SolverError> = CNF::from_conjure(model);
        assert!(cnf.is_err());
    }

    #[test]
    fn test_eq() {
        // Eq(x, y) - this operation is not allowed

        let mut model: Model = Model::new();

        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        model.add_variable(x.clone(), DecisionVariable { domain: BoolDomain });
        model.add_variable(y.clone(), DecisionVariable { domain: BoolDomain });

        model.add_constraint(Expression::Eq(
            Box::from(Reference(x.clone())),
            Box::from(Reference(y.clone())),
        ));

        let cnf: Result<CNF, SolverError> = CNF::from_conjure(model);
        assert!(cnf.is_err());
    }

    #[test]
    fn test_flat_and_expression() {
        // And(x, Not(y)) added directly to the CNF -> [[1], [-2]]
        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        let mut cnf = CNF::new();
        cnf.add_variable(&x);
        cnf.add_variable(&y);

        cnf.add_expression(&And(vec![
            Reference(x.clone()),
            Not(Box::from(Reference(y.clone()))),
        ]))
        .unwrap();

        let xi = cnf.get_index(&x).unwrap();
        let yi = cnf.get_index(&y).unwrap();
        assert_eq!(cnf.clauses, vec![vec![xi], vec![-yi]]);
    }

    #[test]
    fn test_nested_and_is_rejected() {
        // And(x, And(y)) is not flat CNF and must be rejected without adding any clause
        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        let mut cnf = CNF::new();
        cnf.add_variable(&x);
        cnf.add_variable(&y);

        let res = cnf.add_expression(&And(vec![
            Reference(x.clone()),
            And(vec![Reference(y.clone())]),
        ]));
        assert!(matches!(res, Err(CNFError::NestedAnd(_))));
        assert!(cnf.clauses.is_empty());
    }

    #[test]
    fn test_not_of_non_reference_is_rejected() {
        // Not(Or(x)) - only Not(Reference) is allowed
        let x: Name = Name::UserName(String::from('x'));

        let mut cnf = CNF::new();
        cnf.add_variable(&x);

        let res = cnf.add_expression(&Not(Box::from(Or(vec![Reference(x.clone())]))));
        assert!(matches!(res, Err(CNFError::UnexpectedExpressionInsideNot(_))));
        assert!(cnf.clauses.is_empty());
    }

    #[test]
    fn test_unknown_variable_is_rejected() {
        // y was never added to the CNF
        let x: Name = Name::UserName(String::from('x'));
        let y: Name = Name::UserName(String::from('y'));

        let mut cnf = CNF::new();
        cnf.add_variable(&x);

        let res = cnf.add_expression(&Reference(y.clone()));
        assert!(matches!(res, Err(CNFError::VariableNameNotFound(_))));
        assert!(cnf.clauses.is_empty());
    }

    #[test]
    fn test_add_clause_unknown_index() {
        // Only index 1 exists; 2 and 0 are not valid variables
        let x: Name = Name::UserName(String::from('x'));

        let mut cnf = CNF::new();
        cnf.add_variable(&x);

        assert!(matches!(
            cnf.add_clause(&vec![1, -2]),
            Err(CNFError::ClauseIndexNotFound(-2))
        ));
        assert!(matches!(
            cnf.add_clause(&vec![0]),
            Err(CNFError::ClauseIndexNotFound(0))
        ));
        assert!(cnf.clauses.is_empty());

        assert!(cnf.add_clause(&vec![-1]).is_ok());
        assert_eq!(cnf.clauses, vec![vec![-1]]);
        assert!(cnf.get_name(2).is_none());
    }
}