1
use std::fmt::{Display, Formatter};
2
use std::sync::Arc;
3

            
4
use serde::{Deserialize, Serialize};
5

            
6
use enum_compatability_macro::document_compatibility;
7
use uniplate::derive::Uniplate;
8
use uniplate::Biplate;
9

            
10
use crate::ast::literals::Literal;
11
use crate::ast::symbol_table::{Name, SymbolTable};
12
use crate::ast::Atom;
13
use crate::ast::ReturnType;
14
use crate::bug;
15
use crate::metadata::Metadata;
16

            
17
use super::{Domain, Range};
18

            
19
/// Represents different types of expressions used to define rules and constraints in the model.
20
///
21
/// The `Expression` enum includes operations, constants, and variable references
22
/// used to build rules and conditions for the model.
23
#[document_compatibility]
24
6
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
25
#[uniplate(walk_into=[Atom])]
26
#[biplate(to=Literal)]
27
#[biplate(to=Metadata)]
28
#[biplate(to=Atom)]
29
#[biplate(to=Name)]
30
pub enum Expression {
31
    /// An expression representing "A is valid as long as B is true"
32
    /// Turns into a conjunction when it reaches a boolean context
33
    Bubble(Metadata, Box<Expression>, Box<Expression>),
34

            
35
    Atomic(Metadata, Atom),
36

            
37
    #[compatible(Minion, JsonInput)]
38
    Sum(Metadata, Vec<Expression>),
39

            
40
    // /// Division after preventing division by zero, usually with a top-level constraint
41
    // #[compatible(Minion)]
42
    // SafeDiv(Metadata, Box<Expression>, Box<Expression>),
43
    // /// Division with a possibly undefined value (division by 0)
44
    // #[compatible(Minion, JsonInput)]
45
    // Div(Metadata, Box<Expression>, Box<Expression>),
46
    #[compatible(JsonInput)]
47
    Min(Metadata, Vec<Expression>),
48

            
49
    #[compatible(JsonInput)]
50
    Max(Metadata, Vec<Expression>),
51

            
52
    #[compatible(JsonInput, SAT)]
53
    Not(Metadata, Box<Expression>),
54

            
55
    #[compatible(JsonInput, SAT)]
56
    Or(Metadata, Vec<Expression>),
57

            
58
    #[compatible(JsonInput, SAT)]
59
    And(Metadata, Vec<Expression>),
60

            
61
    #[compatible(JsonInput)]
62
    Eq(Metadata, Box<Expression>, Box<Expression>),
63

            
64
    #[compatible(JsonInput)]
65
    Neq(Metadata, Box<Expression>, Box<Expression>),
66

            
67
    #[compatible(JsonInput)]
68
    Geq(Metadata, Box<Expression>, Box<Expression>),
69

            
70
    #[compatible(JsonInput)]
71
    Leq(Metadata, Box<Expression>, Box<Expression>),
72

            
73
    #[compatible(JsonInput)]
74
    Gt(Metadata, Box<Expression>, Box<Expression>),
75

            
76
    #[compatible(JsonInput)]
77
    Lt(Metadata, Box<Expression>, Box<Expression>),
78

            
79
    /// Division after preventing division by zero, usually with a bubble
80
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
81

            
82
    /// Division with a possibly undefined value (division by 0)
83
    #[compatible(JsonInput)]
84
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
85

            
86
    /// Modulo after preventing mod 0, usually with a bubble
87
    SafeMod(Metadata, Box<Expression>, Box<Expression>),
88

            
89
    /// Modulo with a possibly undefined value (mod 0)
90
    #[compatible(JsonInput)]
91
    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),
92

            
93
    /* Flattened SumEq.
94
     *
95
     * Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
96
     * This is NOT a valid expression in either Essence or minion.
97
     *
98
     * ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
99
     */
100
    SumEq(Metadata, Vec<Expression>, Box<Expression>),
101

            
102
    // Flattened Constraints
103
    #[compatible(Minion)]
104
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),
105

            
106
    #[compatible(Minion)]
107
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),
108

            
109
    /// `a / b = c`
110
    #[compatible(Minion)]
111
    DivEqUndefZero(Metadata, Atom, Atom, Atom),
112

            
113
    /// `a % b = c`
114
    #[compatible(Minion)]
115
    ModuloEqUndefZero(Metadata, Atom, Atom, Atom),
116

            
117
    #[compatible(Minion)]
118
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
119

            
120
    #[compatible(Minion)]
121
    AllDiff(Metadata, Vec<Expression>),
122

            
123
    /// w-literal(x,k) is SAT iff x == k, where x is a variable and k a constant.
124
    ///
125
    /// This is a low-level Minion constraint and you should (probably) use Eq instead. The main
126
    /// use of w-literal is to convert boolean variables to constraints so that they can be used
127
    /// inside watched-and and watched-or.
128
    ///
129
    /// See `rules::minion::boolean_literal_to_wliteral`.
130
    ///
131
    ///
132
    #[compatible(Minion)]
133
    WatchedLiteral(Metadata, Name, Literal),
134

            
135
    #[compatible(Minion)]
136
    Reify(Metadata, Box<Expression>, Box<Expression>),
137

            
138
    /// Declaration of an auxiliary variable.
139
    ///
140
    /// As with Savile Row, we semantically distinguish this from `Eq`.
141
    #[compatible(Minion)]
142
    AuxDeclaration(Metadata, Name, Box<Expression>),
143
}
144

            
145
5
fn expr_vec_to_domain_i32(
146
5
    exprs: &[Expression],
147
5
    op: fn(i32, i32) -> Option<i32>,
148
5
    vars: &SymbolTable,
149
5
) -> Option<Domain> {
150
8
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
151
5
    domains
152
5
        .into_iter()
153
5
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
154
5
        .flatten()
155
5
}
156

            
157
1
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
158
1
    let mut min = i32::MAX;
159
1
    let mut max = i32::MIN;
160
5
    for r in ranges {
161
4
        match r {
162
4
            Range::Single(i) => {
163
4
                if *i < min {
164
1
                    min = *i;
165
3
                }
166
4
                if *i > max {
167
3
                    max = *i;
168
3
                }
169
            }
170
            Range::Bounded(i, j) => {
171
                if *i < min {
172
                    min = *i;
173
                }
174
                if *j > max {
175
                    max = *j;
176
                }
177
            }
178
        }
179
    }
180
1
    (min, max)
181
1
}
182

            
183
impl Expression {
184
    /// Returns the possible values of the expression, recursing to leaf expressions
185
15
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
186
14
        let ret = match self {
187
6
            Expression::Atomic(_, Atom::Reference(name)) => Some(vars.get(name)?.domain.clone()),
188
3
            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
189
3
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
190
            }
191
1
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
192
6
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
193
            Expression::Min(_, exprs) => {
194
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
195
            }
196
            Expression::Max(_, exprs) => {
197
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), vars)
198
            }
199
            Expression::UnsafeDiv(_, a, b) => a.domain_of(vars)?.apply_i32(
200
                |x, y| if y != 0 { Some(x / y) } else { None },
201
                &b.domain_of(vars)?,
202
            ),
203
            Expression::SafeDiv(_, a, b) => {
204
                let domain = a.domain_of(vars)?.apply_i32(
205
                    |x, y| if y != 0 { Some(x / y) } else { None },
206
                    &b.domain_of(vars)?,
207
                );
208

            
209
                match domain {
210
                    Some(Domain::IntDomain(ranges)) => {
211
                        let mut ranges = ranges;
212
                        ranges.push(Range::Single(0));
213
                        Some(Domain::IntDomain(ranges))
214
                    }
215
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
216
                    _ => None,
217
                }
218
            }
219
            Expression::UnsafeMod(_, a, b) => a.domain_of(vars)?.apply_i32(
220
                |x, y| if y != 0 { Some(x % y) } else { None },
221
                &b.domain_of(vars)?,
222
            ),
223

            
224
            Expression::SafeMod(_, a, b) => {
225
                let domain = a.domain_of(vars)?.apply_i32(
226
                    |x, y| if y != 0 { Some(x % y) } else { None },
227
                    &b.domain_of(vars)?,
228
                );
229

            
230
                match domain {
231
                    Some(Domain::IntDomain(ranges)) => {
232
                        let mut ranges = ranges;
233
                        ranges.push(Range::Single(0));
234
                        Some(Domain::IntDomain(ranges))
235
                    }
236
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
237
                    _ => None,
238
                }
239
            }
240
            Expression::Bubble(_, _, _) => None,
241
            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
242
            Expression::And(_, _) => Some(Domain::BoolDomain),
243
            _ => bug!("Cannot calculate domain of {:?}", self),
244
            // TODO: (flm8) Add support for calculating the domains of more expression types
245
        };
246
11
        match ret {
247
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
248
            // Once they support a full domain as we define it, we can remove this conversion
249
11
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
250
1
                let (min, max) = range_vec_bounds_i32(&ranges);
251
1
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
252
            }
253
13
            _ => ret,
254
        }
255
15
    }
256

            
257
    pub fn get_meta(&self) -> Metadata {
258
        <Expression as Biplate<Metadata>>::children_bi(self)[0].clone()
259
    }
260

            
261
    pub fn set_meta(&self, meta: Metadata) {
262
        <Expression as Biplate<Metadata>>::transform_bi(self, Arc::new(move |_| meta.clone()));
263
    }
264

            
265
    pub fn can_be_undefined(&self) -> bool {
266
        // TODO: there will be more false cases but we are being conservative
267
        match self {
268
            Expression::Atomic(_, _) => false,
269
            Expression::SafeDiv(_, _, _) => false,
270
            Expression::SafeMod(_, _, _) => false,
271
            _ => true,
272
        }
273
    }
274

            
275
    pub fn return_type(&self) -> Option<ReturnType> {
276
        match self {
277
            Expression::Atomic(_, Atom::Literal(Literal::Int(_))) => Some(ReturnType::Int),
278
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
279
            Expression::Atomic(_, Atom::Reference(_)) => None,
280
            Expression::Sum(_, _) => Some(ReturnType::Int),
281
            Expression::Min(_, _) => Some(ReturnType::Int),
282
            Expression::Max(_, _) => Some(ReturnType::Int),
283
            Expression::Not(_, _) => Some(ReturnType::Bool),
284
            Expression::Or(_, _) => Some(ReturnType::Bool),
285
            Expression::And(_, _) => Some(ReturnType::Bool),
286
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
287
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
288
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
289
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
290
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
291
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
292
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
293
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
294
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
295
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
296
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
297
            Expression::DivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
298
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
299
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
300
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
301
            Expression::WatchedLiteral(_, _, _) => Some(ReturnType::Bool),
302
            Expression::Reify(_, _, _) => Some(ReturnType::Bool),
303
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
304
            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
305
            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
306
            Expression::ModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
307
        }
308
    }
309

            
310
    pub fn is_clean(&self) -> bool {
311
        let metadata = self.get_meta();
312
        metadata.clean
313
    }
314

            
315
    pub fn set_clean(&mut self, bool_value: bool) {
316
        let mut metadata = self.get_meta();
317
        metadata.clean = bool_value;
318
        self.set_meta(metadata);
319
    }
320

            
321
    pub fn as_atom(&self) -> Option<Atom> {
322
        if let Expression::Atomic(_m, f) = self {
323
            Some(f.clone())
324
        } else {
325
            None
326
        }
327
    }
328
}
329

            
330
9
fn display_expressions(expressions: &[Expression]) -> String {
331
9
    // if expressions.len() <= 3 {
332
9
    format!(
333
9
        "[{}]",
334
9
        expressions
335
9
            .iter()
336
9
            .map(|e| e.to_string())
337
9
            .collect::<Vec<String>>()
338
9
            .join(", ")
339
9
    )
340
9
    // } else {
341
9
    //     format!(
342
9
    //         "[{}..{}]",
343
9
    //         expressions[0],
344
9
    //         expressions[expressions.len() - 1]
345
9
    //     )
346
9
    // }
347
9
}
348

            
349
impl From<i32> for Expression {
350
    fn from(i: i32) -> Self {
351
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
352
    }
353
}
354

            
355
impl From<bool> for Expression {
356
    fn from(b: bool) -> Self {
357
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
358
    }
359
}
360

            
361
impl From<Atom> for Expression {
362
    fn from(value: Atom) -> Self {
363
        Expression::Atomic(Metadata::new(), value)
364
    }
365
}
366
impl Display for Expression {
367
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
368
18
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
369
18
        match &self {
370
9
            Expression::Atomic(_, atom) => atom.fmt(f),
371
            Expression::Sum(_, expressions) => {
372
                write!(f, "Sum({})", display_expressions(expressions))
373
            }
374
            Expression::Min(_, expressions) => {
375
                write!(f, "Min({})", display_expressions(expressions))
376
            }
377
            Expression::Max(_, expressions) => {
378
                write!(f, "Max({})", display_expressions(expressions))
379
            }
380
            Expression::Not(_, expr_box) => {
381
                write!(f, "Not({})", expr_box.clone())
382
            }
383
9
            Expression::Or(_, expressions) => {
384
9
                write!(f, "Or({})", display_expressions(expressions))
385
            }
386
            Expression::And(_, expressions) => {
387
                write!(f, "And({})", display_expressions(expressions))
388
            }
389
            Expression::Eq(_, box1, box2) => {
390
                write!(f, "({} = {})", box1.clone(), box2.clone())
391
            }
392
            Expression::Neq(_, box1, box2) => {
393
                write!(f, "({} != {})", box1.clone(), box2.clone())
394
            }
395
            Expression::Geq(_, box1, box2) => {
396
                write!(f, "({} >= {})", box1.clone(), box2.clone())
397
            }
398
            Expression::Leq(_, box1, box2) => {
399
                write!(f, "({} <= {})", box1.clone(), box2.clone())
400
            }
401
            Expression::Gt(_, box1, box2) => {
402
                write!(f, "({} > {})", box1.clone(), box2.clone())
403
            }
404
            Expression::Lt(_, box1, box2) => {
405
                write!(f, "({} < {})", box1.clone(), box2.clone())
406
            }
407
            Expression::SumEq(_, expressions, expr_box) => {
408
                write!(
409
                    f,
410
                    "SumEq({}, {})",
411
                    display_expressions(expressions),
412
                    expr_box.clone()
413
                )
414
            }
415
            Expression::SumGeq(_, box1, box2) => {
416
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
417
            }
418
            Expression::SumLeq(_, box1, box2) => {
419
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
420
            }
421
            Expression::Ineq(_, box1, box2, box3) => write!(
422
                f,
423
                "Ineq({}, {}, {})",
424
                box1.clone(),
425
                box2.clone(),
426
                box3.clone()
427
            ),
428
            Expression::AllDiff(_, expressions) => {
429
                write!(f, "AllDiff({})", display_expressions(expressions))
430
            }
431
            Expression::Bubble(_, box1, box2) => {
432
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
433
            }
434
            Expression::SafeDiv(_, box1, box2) => {
435
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
436
            }
437
            Expression::UnsafeDiv(_, box1, box2) => {
438
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
439
            }
440
            Expression::DivEqUndefZero(_, box1, box2, box3) => {
441
                write!(
442
                    f,
443
                    "DivEq({}, {}, {})",
444
                    box1.clone(),
445
                    box2.clone(),
446
                    box3.clone()
447
                )
448
            }
449
            Expression::ModuloEqUndefZero(_, box1, box2, box3) => {
450
                write!(
451
                    f,
452
                    "ModEq({}, {}, {})",
453
                    box1.clone(),
454
                    box2.clone(),
455
                    box3.clone()
456
                )
457
            }
458

            
459
            Expression::WatchedLiteral(_, x, l) => {
460
                write!(f, "WatchedLiteral({},{})", x, l)
461
            }
462
            Expression::Reify(_, box1, box2) => {
463
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
464
            }
465
            Expression::AuxDeclaration(_, n, e) => {
466
                write!(f, "{} =aux {}", n, e.clone())
467
            }
468
            Expression::UnsafeMod(_, a, b) => {
469
                write!(f, "{} % {}", a.clone(), b.clone())
470
            }
471
            Expression::SafeMod(_, a, b) => {
472
                write!(f, "SafeMod({},{})", a.clone(), b.clone())
473
            }
474
        }
475
18
    }
476
}
477

            
478
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Shorthand for an integer-literal expression.
    fn int(n: i32) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)))
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2)));
        let sum = Expression::Sum(Metadata::new(), vec![c1.clone(), c2.clone()]);
        assert_eq!(
            sum.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(3)]))
        );
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), vec![c1.clone(), c2.clone()]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(
            Name::MachineName(0),
            DecisionVariable::new(Domain::IntDomain(vec![Range::Single(1)])),
        );
        assert_eq!(
            reference.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        assert_eq!(reference.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(
            Name::MachineName(0),
            DecisionVariable::new(Domain::IntDomain(vec![Range::Single(1)])),
        );
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference.clone()]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(2)]))
        );
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(
            Name::MachineName(0),
            DecisionVariable::new(Domain::IntDomain(vec![Range::Bounded(1, 2)])),
        );
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference.clone()]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]))
        );
    }

    #[test]
    fn test_domain_of_constant_min() {
        let min = Expression::Min(Metadata::new(), vec![int(1), int(2)]);
        assert_eq!(
            min.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_constant_max() {
        let max = Expression::Max(Metadata::new(), vec![int(1), int(2)]);
        assert_eq!(
            max.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(2)]))
        );
    }

    #[test]
    fn test_domain_of_safe_div_includes_zero() {
        // 7 / 2 = 3, plus the 0 added for the safe variant, collapsed to bounds.
        let div = Expression::SafeDiv(Metadata::new(), Box::new(int(7)), Box::new(int(2)));
        assert_eq!(
            div.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Bounded(0, 3)]))
        );
    }

    #[test]
    fn test_domain_of_constant_unsafe_mod() {
        let modulo = Expression::UnsafeMod(Metadata::new(), Box::new(int(7)), Box::new(int(3)));
        assert_eq!(
            modulo.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_and_is_bool() {
        let and = Expression::And(Metadata::new(), vec![]);
        assert_eq!(and.domain_of(&SymbolTable::new()), Some(Domain::BoolDomain));
    }

    #[test]
    fn test_range_vec_bounds_mixed() {
        let ranges = vec![Range::Single(3), Range::Bounded(-2, 1), Range::Single(7)];
        assert_eq!(range_vec_bounds_i32(&ranges), (-2, 7));
    }
}