1
use std::fmt::{Display, Formatter};
2
use std::sync::Arc;
3

            
4
use serde::{Deserialize, Serialize};
5

            
6
use enum_compatability_macro::document_compatibility;
7
use uniplate::derive::Uniplate;
8
use uniplate::Biplate;
9

            
10
use crate::ast::literals::Literal;
11
use crate::ast::symbol_table::{Name, SymbolTable};
12
use crate::ast::Atom;
13
use crate::ast::ReturnType;
14
use crate::bug;
15
use crate::metadata::Metadata;
16

            
17
use super::{Domain, Range};
18

            
19
/// Represents different types of expressions used to define rules and constraints in the model.
20
///
21
/// The `Expression` enum includes operations, constants, and variable references
22
/// used to build rules and conditions for the model.
23
#[document_compatibility]
24
3864
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
25
#[uniplate(walk_into=[Atom])]
26
#[biplate(to=Literal)]
27
#[biplate(to=Metadata)]
28
#[biplate(to=Atom)]
29
#[biplate(to=Name)]
30
pub enum Expression {
31
    /// An expression representing "A is valid as long as B is true"
32
    /// Turns into a conjunction when it reaches a boolean context
33
    Bubble(Metadata, Box<Expression>, Box<Expression>),
34

            
35
    Atomic(Metadata, Atom),
36

            
37
    #[compatible(Minion, JsonInput)]
38
    Sum(Metadata, Vec<Expression>),
39

            
40
    // /// Division after preventing division by zero, usually with a top-level constraint
41
    // #[compatible(Minion)]
42
    // SafeDiv(Metadata, Box<Expression>, Box<Expression>),
43
    // /// Division with a possibly undefined value (division by 0)
44
    // #[compatible(Minion, JsonInput)]
45
    // Div(Metadata, Box<Expression>, Box<Expression>),
46
    #[compatible(JsonInput)]
47
    Min(Metadata, Vec<Expression>),
48

            
49
    #[compatible(JsonInput)]
50
    Max(Metadata, Vec<Expression>),
51

            
52
    #[compatible(JsonInput, SAT)]
53
    Not(Metadata, Box<Expression>),
54

            
55
    #[compatible(JsonInput, SAT)]
56
    Or(Metadata, Vec<Expression>),
57

            
58
    #[compatible(JsonInput, SAT)]
59
    And(Metadata, Vec<Expression>),
60

            
61
    #[compatible(JsonInput)]
62
    Eq(Metadata, Box<Expression>, Box<Expression>),
63

            
64
    #[compatible(JsonInput)]
65
    Neq(Metadata, Box<Expression>, Box<Expression>),
66

            
67
    #[compatible(JsonInput)]
68
    Geq(Metadata, Box<Expression>, Box<Expression>),
69

            
70
    #[compatible(JsonInput)]
71
    Leq(Metadata, Box<Expression>, Box<Expression>),
72

            
73
    #[compatible(JsonInput)]
74
    Gt(Metadata, Box<Expression>, Box<Expression>),
75

            
76
    #[compatible(JsonInput)]
77
    Lt(Metadata, Box<Expression>, Box<Expression>),
78

            
79
    /// Division after preventing division by zero, usually with a bubble
80
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
81

            
82
    /// Division with a possibly undefined value (division by 0)
83
    #[compatible(JsonInput)]
84
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
85

            
86
    /* Flattened SumEq.
87
     *
88
     * Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
89
     * This is NOT a valid expression in either Essence or minion.
90
     *
91
     * ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
92
     */
93
    SumEq(Metadata, Vec<Expression>, Box<Expression>),
94

            
95
    // Flattened Constraints
96
    #[compatible(Minion)]
97
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),
98

            
99
    #[compatible(Minion)]
100
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),
101

            
102
    /// `a / b = c`
103
    #[compatible(Minion)]
104
    DivEq(Metadata, Atom, Atom, Atom),
105

            
106
    #[compatible(Minion)]
107
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
108

            
109
    #[compatible(Minion)]
110
    AllDiff(Metadata, Vec<Expression>),
111

            
112
    /// w-literal(x,k) is SAT iff x == k, where x is a variable and k a constant.
113
    ///
114
    /// This is a low-level Minion constraint and you should (probably) use Eq instead. The main
115
    /// use of w-literal is to convert boolean variables to constraints so that they can be used
116
    /// inside watched-and and watched-or.
117
    ///
118
    /// See `rules::minion::boolean_literal_to_wliteral`.
119
    ///
120
    ///
121
    #[compatible(Minion)]
122
    WatchedLiteral(Metadata, Name, Literal),
123

            
124
    #[compatible(Minion)]
125
    Reify(Metadata, Box<Expression>, Box<Expression>),
126

            
127
    /// Declaration of an auxiliary variable.
128
    ///
129
    /// As with Savile Row, we semantically distinguish this from `Eq`.
130
    #[compatible(Minion)]
131
    AuxDeclaration(Metadata, Name, Box<Expression>),
132
}
133

            
134
192
fn expr_vec_to_domain_i32(
135
192
    exprs: &[Expression],
136
192
    op: fn(i32, i32) -> Option<i32>,
137
192
    vars: &SymbolTable,
138
192
) -> Option<Domain> {
139
382
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
140
192
    domains
141
192
        .into_iter()
142
192
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
143
192
        .flatten()
144
192
}
145

            
146
511
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
147
511
    let mut min = i32::MAX;
148
511
    let mut max = i32::MIN;
149
16019
    for r in ranges {
150
15508
        match r {
151
15508
            Range::Single(i) => {
152
15508
                if *i < min {
153
800
                    min = *i;
154
14708
                }
155
15508
                if *i > max {
156
3097
                    max = *i;
157
12413
                }
158
            }
159
            Range::Bounded(i, j) => {
160
                if *i < min {
161
                    min = *i;
162
                }
163
                if *j > max {
164
                    max = *j;
165
                }
166
            }
167
        }
168
    }
169
511
    (min, max)
170
511
}
171

            
172
impl Expression {
173
    /// Returns the possible values of the expression, recursing to leaf expressions
174
1545
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
175
1544
        let ret = match self {
176
941
            Expression::Atomic(_, Atom::Reference(name)) => Some(vars.get(name)?.domain.clone()),
177
71
            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
178
71
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
179
            }
180
1
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
181
6
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
182
119
            Expression::Min(_, exprs) => {
183
1700
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
184
            }
185
68
            Expression::Max(_, exprs) => {
186
884
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), vars)
187
            }
188
            Expression::UnsafeDiv(_, a, b) => a.domain_of(vars)?.apply_i32(
189
                |x, y| if y != 0 { Some(x / y) } else { None },
190
                &b.domain_of(vars)?,
191
            ),
192
340
            Expression::SafeDiv(_, a, b) => {
193
340
                let domain = a.domain_of(vars)?.apply_i32(
194
14977
                    |x, y| if y != 0 { Some(x / y) } else { None },
195
340
                    &b.domain_of(vars)?,
196
                );
197

            
198
340
                match domain {
199
340
                    Some(Domain::IntDomain(ranges)) => {
200
340
                        let mut ranges = ranges;
201
340
                        ranges.push(Range::Single(0));
202
340
                        Some(Domain::IntDomain(ranges))
203
                    }
204
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
205
                    _ => None,
206
                }
207
            }
208
            Expression::SafeDiv(_, a, b) => {
209
                let domain = a.domain_of(vars)?.apply_i32(
210
                    |x, y| if y != 0 { Some(x / y) } else { None },
211
                    &b.domain_of(vars)?,
212
                );
213
                match domain {
214
                    Some(Domain::IntDomain(v)) if v.len() == 1 => {
215
                        let range = match v[0] {
216
                            Range::Single(a) if a > 0 => Range::Bounded(0, a),
217
                            Range::Single(a) if a < 0 => Range::Bounded(a, 0),
218
                            Range::Single(0) => Range::Single(0),
219
                            Range::Bounded(a, b) if a < 0 => Range::Bounded(a, b),
220
                            Range::Bounded(_, b) => Range::Bounded(0, b),
221
                            _ => unreachable!(),
222
                        };
223

            
224
                        Some(Domain::IntDomain(vec![range]))
225
                    }
226
                    _ => None,
227
                }
228
            }
229
            Expression::Bubble(_, _, _) => None,
230
            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
231
            Expression::And(_, _) => Some(Domain::BoolDomain),
232
            _ => bug!("Cannot calculate domain of {:?}", self),
233
            // TODO: (flm8) Add support for calculating the domains of more expression types
234
        };
235
1541
        match ret {
236
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
237
            // Once they support a full domain as we define it, we can remove this conversion
238
1541
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
239
511
                let (min, max) = range_vec_bounds_i32(&ranges);
240
511
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
241
            }
242
1033
            _ => ret,
243
        }
244
1545
    }
245

            
246
17
    pub fn get_meta(&self) -> Metadata {
247
17
        <Expression as Biplate<Metadata>>::children_bi(self)[0].clone()
248
17
    }
249

            
250
    pub fn set_meta(&self, meta: Metadata) {
251
        <Expression as Biplate<Metadata>>::transform_bi(self, Arc::new(move |_| meta.clone()));
252
    }
253

            
254
5355
    pub fn can_be_undefined(&self) -> bool {
255
5355
        // TODO: there will be more false cases but we are being conservative
256
5355
        match self {
257
731
            Expression::Atomic(_, _) => false,
258
408
            Expression::SafeDiv(_, _, _) => false,
259
4216
            _ => true,
260
        }
261
5355
    }
262

            
263
765
    pub fn return_type(&self) -> Option<ReturnType> {
264
        match self {
265
            Expression::Atomic(_, Atom::Literal(Literal::Int(_))) => Some(ReturnType::Int),
266
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
267
            Expression::Atomic(_, Atom::Reference(_)) => None,
268
            Expression::Sum(_, _) => Some(ReturnType::Int),
269
            Expression::Min(_, _) => Some(ReturnType::Int),
270
            Expression::Max(_, _) => Some(ReturnType::Int),
271
            Expression::Not(_, _) => Some(ReturnType::Bool),
272
            Expression::Or(_, _) => Some(ReturnType::Bool),
273
            Expression::And(_, _) => Some(ReturnType::Bool),
274
255
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
275
68
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
276
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
277
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
278
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
279
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
280
221
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
281
51
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
282
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
283
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
284
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
285
            Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
286
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
287
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
288
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
289
            Expression::WatchedLiteral(_, _, _) => Some(ReturnType::Bool),
290
            Expression::Reify(_, _, _) => Some(ReturnType::Bool),
291
170
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
292
        }
293
765
    }
294

            
295
    pub fn is_clean(&self) -> bool {
296
        let metadata = self.get_meta();
297
        metadata.clean
298
    }
299

            
300
    pub fn set_clean(&mut self, bool_value: bool) {
301
        let mut metadata = self.get_meta();
302
        metadata.clean = bool_value;
303
        self.set_meta(metadata);
304
    }
305

            
306
8874
    pub fn as_atom(&self) -> Option<Atom> {
307
8874
        if let Expression::Atomic(_m, f) = self {
308
5134
            Some(f.clone())
309
        } else {
310
3740
            None
311
        }
312
8874
    }
313
}
314

            
315
fn display_expressions(expressions: &[Expression]) -> String {
316
    // if expressions.len() <= 3 {
317
    format!(
318
        "[{}]",
319
        expressions
320
            .iter()
321
            .map(|e| e.to_string())
322
            .collect::<Vec<String>>()
323
            .join(", ")
324
    )
325
    // } else {
326
    //     format!(
327
    //         "[{}..{}]",
328
    //         expressions[0],
329
    //         expressions[expressions.len() - 1]
330
    //     )
331
    // }
332
}
333

            
334
impl From<i32> for Expression {
335
221
    fn from(i: i32) -> Self {
336
221
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
337
221
    }
338
}
339

            
340
impl From<bool> for Expression {
341
    fn from(b: bool) -> Self {
342
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
343
    }
344
}
345

            
346
impl From<Atom> for Expression {
347
816
    fn from(value: Atom) -> Self {
348
816
        Expression::Atomic(Metadata::new(), value)
349
816
    }
350
}
351
impl Display for Expression {
352
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
353
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
354
        match &self {
355
            Expression::Atomic(_, atom) => atom.fmt(f),
356
            Expression::Sum(_, expressions) => {
357
                write!(f, "Sum({})", display_expressions(expressions))
358
            }
359
            Expression::Min(_, expressions) => {
360
                write!(f, "Min({})", display_expressions(expressions))
361
            }
362
            Expression::Max(_, expressions) => {
363
                write!(f, "Max({})", display_expressions(expressions))
364
            }
365
            Expression::Not(_, expr_box) => {
366
                write!(f, "Not({})", expr_box.clone())
367
            }
368
            Expression::Or(_, expressions) => {
369
                write!(f, "Or({})", display_expressions(expressions))
370
            }
371
            Expression::And(_, expressions) => {
372
                write!(f, "And({})", display_expressions(expressions))
373
            }
374
            Expression::Eq(_, box1, box2) => {
375
                write!(f, "({} = {})", box1.clone(), box2.clone())
376
            }
377
            Expression::Neq(_, box1, box2) => {
378
                write!(f, "({} != {})", box1.clone(), box2.clone())
379
            }
380
            Expression::Geq(_, box1, box2) => {
381
                write!(f, "({} >= {})", box1.clone(), box2.clone())
382
            }
383
            Expression::Leq(_, box1, box2) => {
384
                write!(f, "({} <= {})", box1.clone(), box2.clone())
385
            }
386
            Expression::Gt(_, box1, box2) => {
387
                write!(f, "({} > {})", box1.clone(), box2.clone())
388
            }
389
            Expression::Lt(_, box1, box2) => {
390
                write!(f, "({} < {})", box1.clone(), box2.clone())
391
            }
392
            Expression::SumEq(_, expressions, expr_box) => {
393
                write!(
394
                    f,
395
                    "SumEq({}, {})",
396
                    display_expressions(expressions),
397
                    expr_box.clone()
398
                )
399
            }
400
            Expression::SumGeq(_, box1, box2) => {
401
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
402
            }
403
            Expression::SumLeq(_, box1, box2) => {
404
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
405
            }
406
            Expression::Ineq(_, box1, box2, box3) => write!(
407
                f,
408
                "Ineq({}, {}, {})",
409
                box1.clone(),
410
                box2.clone(),
411
                box3.clone()
412
            ),
413
            Expression::AllDiff(_, expressions) => {
414
                write!(f, "AllDiff({})", display_expressions(expressions))
415
            }
416
            Expression::Bubble(_, box1, box2) => {
417
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
418
            }
419
            Expression::SafeDiv(_, box1, box2) => {
420
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
421
            }
422
            Expression::UnsafeDiv(_, box1, box2) => {
423
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
424
            }
425
            Expression::DivEq(_, box1, box2, box3) => {
426
                write!(
427
                    f,
428
                    "DivEq({}, {}, {})",
429
                    box1.clone(),
430
                    box2.clone(),
431
                    box3.clone()
432
                )
433
            }
434
            Expression::WatchedLiteral(_, x, l) => {
435
                write!(f, "WatchedLiteral({},{})", x, l)
436
            }
437
            Expression::Reify(_, box1, box2) => {
438
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
439
            }
440
            Expression::AuxDeclaration(_, n, e) => {
441
                write!(f, "{} =aux {}", n, e.clone())
442
            }
443
        }
444
    }
445
}
446

            
447
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Integer-literal expression with fresh metadata.
    fn int_lit(n: i32) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(n)))
    }

    /// Boolean-literal expression with fresh metadata.
    fn bool_lit(b: bool) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
    }

    /// Reference to the variable `MachineName(0)`.
    fn var0() -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)))
    }

    /// Symbol table in which `MachineName(0)` has the given domain.
    fn table_with_var0(domain: Domain) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(Name::MachineName(0), DecisionVariable::new(domain));
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let total = Expression::Sum(Metadata::new(), vec![int_lit(1), int_lit(2)]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(3)]));
        assert_eq!(total.domain_of(&SymbolTable::new()), expected);
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let total = Expression::Sum(Metadata::new(), vec![int_lit(1), bool_lit(true)]);
        assert_eq!(total.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let total = Expression::Sum(Metadata::new(), Vec::new());
        assert_eq!(total.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let table = table_with_var0(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Some(Domain::IntDomain(vec![Range::Single(1)]));
        assert_eq!(var0().domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(var0().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let table = table_with_var0(Domain::IntDomain(vec![Range::Single(1)]));
        let total = Expression::Sum(Metadata::new(), vec![var0(), var0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(2)]));
        assert_eq!(total.domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let table = table_with_var0(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let total = Expression::Sum(Metadata::new(), vec![var0(), var0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]));
        assert_eq!(total.domain_of(&table), expected);
    }
}