1
use std::fmt::{Display, Formatter};
2
use std::sync::Arc;
3

            
4
use serde::{Deserialize, Serialize};
5

            
6
use enum_compatability_macro::document_compatibility;
7
use uniplate::derive::Uniplate;
8
use uniplate::Biplate;
9

            
10
use crate::ast::literals::Literal;
11
use crate::ast::symbol_table::{Name, SymbolTable};
12
use crate::ast::Factor;
13
use crate::ast::ReturnType;
14
use crate::bug;
15
use crate::metadata::Metadata;
16

            
17
use super::{Domain, Range};
18

            
19
/// Represents different types of expressions used to define rules and constraints in the model.
///
/// The `Expression` enum includes operations, constants, and variable references
/// used to build rules and conditions for the model.
///
/// Every variant stores its [`Metadata`] as its first field.
#[document_compatibility]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[uniplate(walk_into=[Factor])]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Factor)]
#[biplate(to=Name)]
pub enum Expression {
    /// An expression representing "A is valid as long as B is true".
    /// Turns into a conjunction when it reaches a boolean context.
    Bubble(Metadata, Box<Expression>, Box<Expression>),

    /// A single factor: either a literal or a reference to a variable.
    FactorE(Metadata, Factor),

    /// Sum of a list of integer expressions.
    #[compatible(Minion, JsonInput)]
    Sum(Metadata, Vec<Expression>),

    /// Minimum of a list of integer expressions.
    #[compatible(JsonInput)]
    Min(Metadata, Vec<Expression>),

    /// Maximum of a list of integer expressions.
    #[compatible(JsonInput)]
    Max(Metadata, Vec<Expression>),

    /// Boolean negation.
    #[compatible(JsonInput, SAT)]
    Not(Metadata, Box<Expression>),

    /// Boolean disjunction of a list of expressions.
    #[compatible(JsonInput, SAT)]
    Or(Metadata, Vec<Expression>),

    /// Boolean conjunction of a list of expressions.
    #[compatible(JsonInput, SAT)]
    And(Metadata, Vec<Expression>),

    /// `a = b`
    #[compatible(JsonInput)]
    Eq(Metadata, Box<Expression>, Box<Expression>),

    /// `a != b`
    #[compatible(JsonInput)]
    Neq(Metadata, Box<Expression>, Box<Expression>),

    /// `a >= b`
    #[compatible(JsonInput)]
    Geq(Metadata, Box<Expression>, Box<Expression>),

    /// `a <= b`
    #[compatible(JsonInput)]
    Leq(Metadata, Box<Expression>, Box<Expression>),

    /// `a > b`
    #[compatible(JsonInput)]
    Gt(Metadata, Box<Expression>, Box<Expression>),

    /// `a < b`
    #[compatible(JsonInput)]
    Lt(Metadata, Box<Expression>, Box<Expression>),

    /// Division after preventing division by zero, usually with a bubble
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Division with a possibly undefined value (division by 0)
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Flattened SumEq.
    ///
    /// Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
    /// This is NOT a valid expression in either Essence or minion.
    ///
    /// ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
    SumEq(Metadata, Vec<Expression>, Box<Expression>),

    // Flattened Constraints
    /// Minion `sumgeq(vec, c)`: the sum of `vec` is at least `c`.
    #[compatible(Minion)]
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),

    /// Minion `sumleq(vec, c)`: the sum of `vec` is at most `c`.
    #[compatible(Minion)]
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),

    /// `a / b = c`
    #[compatible(Minion)]
    DivEq(Metadata, Factor, Factor, Factor),

    /// Minion `ineq(x, y, k)`, i.e. `x <= y + k`.
    #[compatible(Minion)]
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),

    /// Minion `alldiff(vec)`: all elements take pairwise distinct values.
    #[compatible(Minion)]
    AllDiff(Metadata, Vec<Expression>),

    /// w-literal(x,k) is SAT iff x == k, where x is a variable and k a constant.
    ///
    /// This is a low-level Minion constraint and you should (probably) use Eq instead. The main
    /// use of w-literal is to convert boolean variables to constraints so that they can be used
    /// inside watched-and and watched-or.
    ///
    /// See `rules::minion::boolean_literal_to_wliteral`.
    #[compatible(Minion)]
    WatchedLiteral(Metadata, Name, Literal),

    /// Reification of a constraint (presumably Minion `reify(constraint, r)`, where `r` is
    /// true iff the constraint holds — TODO confirm field order).
    #[compatible(Minion)]
    Reify(Metadata, Box<Expression>, Box<Expression>),

    /// Declaration of an auxiliary variable.
    ///
    /// As with Savile Row, we semantically distinguish this from `Eq`.
    #[compatible(Minion)]
    AuxDeclaration(Metadata, Name, Box<Expression>),
}
133

            
134
192
fn expr_vec_to_domain_i32(
135
192
    exprs: &[Expression],
136
192
    op: fn(i32, i32) -> Option<i32>,
137
192
    vars: &SymbolTable,
138
192
) -> Option<Domain> {
139
382
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
140
192
    domains
141
192
        .into_iter()
142
192
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
143
192
        .flatten()
144
192
}
145

            
146
511
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
147
511
    let mut min = i32::MAX;
148
511
    let mut max = i32::MIN;
149
16019
    for r in ranges {
150
15508
        match r {
151
15508
            Range::Single(i) => {
152
15508
                if *i < min {
153
800
                    min = *i;
154
14708
                }
155
15508
                if *i > max {
156
3097
                    max = *i;
157
12413
                }
158
            }
159
            Range::Bounded(i, j) => {
160
                if *i < min {
161
                    min = *i;
162
                }
163
                if *j > max {
164
                    max = *j;
165
                }
166
            }
167
        }
168
    }
169
511
    (min, max)
170
511
}
171

            
172
impl Expression {
173
    /// Returns the possible values of the expression, recursing to leaf expressions
174
1545
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
175
1544
        let ret = match self {
176
941
            Expression::FactorE(_, Factor::Reference(name)) => Some(vars.get(name)?.domain.clone()),
177
71
            Expression::FactorE(_, Factor::Literal(Literal::Int(n))) => {
178
71
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
179
            }
180
1
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
181
6
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
182
119
            Expression::Min(_, exprs) => {
183
1700
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
184
            }
185
68
            Expression::Max(_, exprs) => {
186
884
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), vars)
187
            }
188
            Expression::UnsafeDiv(_, a, b) => a.domain_of(vars)?.apply_i32(
189
                |x, y| if y != 0 { Some(x / y) } else { None },
190
                &b.domain_of(vars)?,
191
            ),
192
340
            Expression::SafeDiv(_, a, b) => {
193
340
                let domain = a.domain_of(vars)?.apply_i32(
194
14977
                    |x, y| if y != 0 { Some(x / y) } else { None },
195
340
                    &b.domain_of(vars)?,
196
                );
197

            
198
340
                match domain {
199
340
                    Some(Domain::IntDomain(ranges)) => {
200
340
                        let mut ranges = ranges;
201
340
                        ranges.push(Range::Single(0));
202
340
                        Some(Domain::IntDomain(ranges))
203
                    }
204
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
205
                    _ => None,
206
                }
207
            }
208

            
209
            Expression::Bubble(_, _, _) => None,
210
            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
211
            Expression::And(_, _) => Some(Domain::BoolDomain),
212
            _ => bug!("Cannot calculate domain of {:?}", self),
213
            // TODO: (flm8) Add support for calculating the domains of more expression types
214
        };
215
1541
        match ret {
216
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
217
            // Once they support a full domain as we define it, we can remove this conversion
218
1541
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
219
511
                let (min, max) = range_vec_bounds_i32(&ranges);
220
511
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
221
            }
222
1033
            _ => ret,
223
        }
224
1545
    }
225

            
226
17
    pub fn get_meta(&self) -> Metadata {
227
17
        <Expression as Biplate<Metadata>>::children_bi(self)[0].clone()
228
17
    }
229

            
230
    pub fn set_meta(&self, meta: Metadata) {
231
        <Expression as Biplate<Metadata>>::transform_bi(self, Arc::new(move |_| meta.clone()));
232
    }
233

            
234
5355
    pub fn can_be_undefined(&self) -> bool {
235
5355
        // TODO: there will be more false cases but we are being conservative
236
5355
        match self {
237
731
            Expression::FactorE(_, _) => false,
238
408
            Expression::SafeDiv(_, _, _) => false,
239
4216
            _ => true,
240
        }
241
5355
    }
242

            
243
765
    pub fn return_type(&self) -> Option<ReturnType> {
244
        match self {
245
            Expression::FactorE(_, Factor::Literal(Literal::Int(_))) => Some(ReturnType::Int),
246
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
247
            Expression::FactorE(_, Factor::Reference(_)) => None,
248
            Expression::Sum(_, _) => Some(ReturnType::Int),
249
            Expression::Min(_, _) => Some(ReturnType::Int),
250
            Expression::Max(_, _) => Some(ReturnType::Int),
251
            Expression::Not(_, _) => Some(ReturnType::Bool),
252
            Expression::Or(_, _) => Some(ReturnType::Bool),
253
            Expression::And(_, _) => Some(ReturnType::Bool),
254
255
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
255
68
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
256
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
257
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
258
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
259
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
260
221
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
261
51
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
262
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
263
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
264
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
265
            Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
266
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
267
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
268
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
269
            Expression::WatchedLiteral(_, _, _) => Some(ReturnType::Bool),
270
            Expression::Reify(_, _, _) => Some(ReturnType::Bool),
271
170
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
272
        }
273
765
    }
274

            
275
    pub fn is_clean(&self) -> bool {
276
        let metadata = self.get_meta();
277
        metadata.clean
278
    }
279

            
280
    pub fn set_clean(&mut self, bool_value: bool) {
281
        let mut metadata = self.get_meta();
282
        metadata.clean = bool_value;
283
        self.set_meta(metadata);
284
    }
285

            
286
8874
    pub fn as_factor(&self) -> Option<Factor> {
287
8874
        if let Expression::FactorE(_m, f) = self {
288
5134
            Some(f.clone())
289
        } else {
290
3740
            None
291
        }
292
8874
    }
293
}
294

            
295
fn display_expressions(expressions: &[Expression]) -> String {
296
    // if expressions.len() <= 3 {
297
    format!(
298
        "[{}]",
299
        expressions
300
            .iter()
301
            .map(|e| e.to_string())
302
            .collect::<Vec<String>>()
303
            .join(", ")
304
    )
305
    // } else {
306
    //     format!(
307
    //         "[{}..{}]",
308
    //         expressions[0],
309
    //         expressions[expressions.len() - 1]
310
    //     )
311
    // }
312
}
313

            
314
impl From<i32> for Expression {
315
221
    fn from(i: i32) -> Self {
316
221
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(i)))
317
221
    }
318
}
319

            
320
impl From<bool> for Expression {
321
    fn from(b: bool) -> Self {
322
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(b)))
323
    }
324
}
325

            
326
impl From<Factor> for Expression {
327
816
    fn from(value: Factor) -> Self {
328
816
        Expression::FactorE(Metadata::new(), value)
329
816
    }
330
}
331
impl Display for Expression {
332
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
333
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
334
        match &self {
335
            Expression::FactorE(_, factor) => factor.fmt(f),
336
            Expression::Sum(_, expressions) => {
337
                write!(f, "Sum({})", display_expressions(expressions))
338
            }
339
            Expression::Min(_, expressions) => {
340
                write!(f, "Min({})", display_expressions(expressions))
341
            }
342
            Expression::Max(_, expressions) => {
343
                write!(f, "Max({})", display_expressions(expressions))
344
            }
345
            Expression::Not(_, expr_box) => {
346
                write!(f, "Not({})", expr_box.clone())
347
            }
348
            Expression::Or(_, expressions) => {
349
                write!(f, "Or({})", display_expressions(expressions))
350
            }
351
            Expression::And(_, expressions) => {
352
                write!(f, "And({})", display_expressions(expressions))
353
            }
354
            Expression::Eq(_, box1, box2) => {
355
                write!(f, "({} = {})", box1.clone(), box2.clone())
356
            }
357
            Expression::Neq(_, box1, box2) => {
358
                write!(f, "({} != {})", box1.clone(), box2.clone())
359
            }
360
            Expression::Geq(_, box1, box2) => {
361
                write!(f, "({} >= {})", box1.clone(), box2.clone())
362
            }
363
            Expression::Leq(_, box1, box2) => {
364
                write!(f, "({} <= {})", box1.clone(), box2.clone())
365
            }
366
            Expression::Gt(_, box1, box2) => {
367
                write!(f, "({} > {})", box1.clone(), box2.clone())
368
            }
369
            Expression::Lt(_, box1, box2) => {
370
                write!(f, "({} < {})", box1.clone(), box2.clone())
371
            }
372
            Expression::SumEq(_, expressions, expr_box) => {
373
                write!(
374
                    f,
375
                    "SumEq({}, {})",
376
                    display_expressions(expressions),
377
                    expr_box.clone()
378
                )
379
            }
380
            Expression::SumGeq(_, box1, box2) => {
381
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
382
            }
383
            Expression::SumLeq(_, box1, box2) => {
384
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
385
            }
386
            Expression::Ineq(_, box1, box2, box3) => write!(
387
                f,
388
                "Ineq({}, {}, {})",
389
                box1.clone(),
390
                box2.clone(),
391
                box3.clone()
392
            ),
393
            Expression::AllDiff(_, expressions) => {
394
                write!(f, "AllDiff({})", display_expressions(expressions))
395
            }
396
            Expression::Bubble(_, box1, box2) => {
397
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
398
            }
399
            Expression::SafeDiv(_, box1, box2) => {
400
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
401
            }
402
            Expression::UnsafeDiv(_, box1, box2) => {
403
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
404
            }
405
            Expression::DivEq(_, box1, box2, box3) => {
406
                write!(
407
                    f,
408
                    "DivEq({}, {}, {})",
409
                    box1.clone(),
410
                    box2.clone(),
411
                    box3.clone()
412
                )
413
            }
414
            Expression::WatchedLiteral(_, x, l) => {
415
                write!(f, "WatchedLiteral({},{})", x, l)
416
            }
417
            Expression::Reify(_, box1, box2) => {
418
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
419
            }
420
            Expression::AuxDeclaration(_, n, e) => {
421
                write!(f, "{} =aux {}", n, e.clone())
422
            }
423
        }
424
    }
425
}
426

            
427
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Integer literal expression with fresh metadata.
    fn int(n: i32) -> Expression {
        Expression::from(n)
    }

    /// Reference to the variable `MachineName(0)` with fresh metadata.
    fn machine_var_0() -> Expression {
        Expression::from(Factor::Reference(Name::MachineName(0)))
    }

    /// Symbol table declaring only `MachineName(0)`, with the given domain.
    fn table_with_var_0(domain: Domain) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(Name::MachineName(0), DecisionVariable::new(domain));
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![int(1), int(2)]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(3)]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), expected);
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let sum = Expression::Sum(Metadata::new(), vec![int(1), Expression::from(true)]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Vec::new());
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let table = table_with_var_0(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Some(Domain::IntDomain(vec![Range::Single(1)]));
        assert_eq!(machine_var_0().domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(machine_var_0().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let table = table_with_var_0(Domain::IntDomain(vec![Range::Single(1)]));
        let sum = Expression::Sum(Metadata::new(), vec![machine_var_0(), machine_var_0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(2)]));
        assert_eq!(sum.domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let table = table_with_var_0(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let sum = Expression::Sum(Metadata::new(), vec![machine_var_0(), machine_var_0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]));
        assert_eq!(sum.domain_of(&table), expected);
    }
}