1
use std::fmt::{Display, Formatter};
2

            
3
use serde::{Deserialize, Serialize};
4

            
5
use enum_compatability_macro::document_compatibility;
6
use uniplate::derive::Uniplate;
7
use uniplate::{Biplate, Uniplate};
8

            
9
use crate::ast::constants::Constant;
10
use crate::ast::symbol_table::{Name, SymbolTable};
11
use crate::ast::ReturnType;
12
use crate::metadata::Metadata;
13

            
14
use super::{Domain, Range};
15

            
16
/// A node in a model's expression tree.
///
/// Every variant except [`Expression::Nothing`] carries [`Metadata`] as its first field.
/// The `#[compatible(...)]` markers list targets per variant; presumably they are read by the
/// `document_compatibility` macro to generate compatibility docs — TODO confirm.
#[document_compatibility]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[uniplate(walk_into=[])]
#[biplate(to=Constant)]
#[non_exhaustive]
pub enum Expression {
    /// Represents an empty expression.
    ///
    /// NB: we only expect this at the top level of a model (if there are no constraints).
    Nothing,

    /// An expression representing "A is valid as long as B is true"
    /// Turns into a conjunction when it reaches a boolean context
    Bubble(Metadata, Box<Expression>, Box<Expression>),

    /// A literal constant value.
    #[compatible(Minion, JsonInput)]
    Constant(Metadata, Constant),

    /// A reference to a declared variable, resolved by [`Name`] in a [`SymbolTable`].
    #[compatible(Minion, JsonInput, SAT)]
    Reference(Metadata, Name),

    /// Sum of the integer sub-expressions.
    #[compatible(Minion, JsonInput)]
    Sum(Metadata, Vec<Expression>),

    /// Minimum of the integer sub-expressions.
    #[compatible(JsonInput)]
    Min(Metadata, Vec<Expression>),

    /// Boolean negation.
    #[compatible(JsonInput, SAT)]
    Not(Metadata, Box<Expression>),

    /// Boolean disjunction of the sub-expressions.
    #[compatible(JsonInput, SAT)]
    Or(Metadata, Vec<Expression>),

    /// Boolean conjunction of the sub-expressions.
    #[compatible(JsonInput, SAT)]
    And(Metadata, Vec<Expression>),

    /// Equality: `a = b`.
    #[compatible(JsonInput)]
    Eq(Metadata, Box<Expression>, Box<Expression>),

    /// Disequality: `a != b`.
    #[compatible(JsonInput)]
    Neq(Metadata, Box<Expression>, Box<Expression>),

    /// `a >= b`.
    #[compatible(JsonInput)]
    Geq(Metadata, Box<Expression>, Box<Expression>),

    /// `a <= b`.
    #[compatible(JsonInput)]
    Leq(Metadata, Box<Expression>, Box<Expression>),

    /// `a > b`.
    #[compatible(JsonInput)]
    Gt(Metadata, Box<Expression>, Box<Expression>),

    /// `a < b`.
    #[compatible(JsonInput)]
    Lt(Metadata, Box<Expression>, Box<Expression>),

    /// Division after preventing division by zero, usually with a bubble
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Division with a possibly undefined value (division by 0)
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Flattened SumEq.
    ///
    /// Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
    /// This is NOT a valid expression in either Essence or minion.
    ///
    /// ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
    SumEq(Metadata, Vec<Expression>, Box<Expression>),

    // Flattened Constraints
    /// Flattened sum constraint over the operand list and a right-hand side;
    /// presumably `sum(operands) >= rhs` — TODO confirm.
    #[compatible(Minion)]
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),

    /// Flattened sum constraint over the operand list and a right-hand side;
    /// presumably `sum(operands) <= rhs` — TODO confirm.
    #[compatible(Minion)]
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),

    /// Flattened three-operand division constraint; presumably maps to Minion's `div`
    /// constraint — NOTE(review): confirm operand order.
    #[compatible(Minion)]
    DivEq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),

    /// Flattened three-operand inequality; presumably Minion's `ineq(x, y, k)` —
    /// NOTE(review): confirm operand meaning.
    #[compatible(Minion)]
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),

    /// All-different constraint over the sub-expressions.
    #[compatible(Minion)]
    AllDiff(Metadata, Vec<Expression>),
}
109

            
110
fn expr_vec_to_domain_i32(
111
    exprs: &Vec<Expression>,
112
    op: fn(i32, i32) -> Option<i32>,
113
    vars: &SymbolTable,
114
) -> Option<Domain> {
115
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
116
    domains
117
        .into_iter()
118
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
119
        .flatten()
120
}
121

            
122
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
123
    let mut min = i32::MAX;
124
    let mut max = i32::MIN;
125
    for r in ranges {
126
        match r {
127
            Range::Single(i) => {
128
                if *i < min {
129
                    min = *i;
130
                }
131
                if *i > max {
132
                    max = *i;
133
                }
134
            }
135
            Range::Bounded(i, j) => {
136
                if *i < min {
137
                    min = *i;
138
                }
139
                if *j > max {
140
                    max = *j;
141
                }
142
            }
143
        }
144
    }
145
    (min, max)
146
}
147

            
148
impl Expression {
149
    /// Returns the possible values of the expression, recursing to leaf expressions
150
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
151
        let ret = match self {
152
            Expression::Reference(_, name) => Some(vars.get(name)?.domain.clone()),
153
            Expression::Constant(_, Constant::Int(n)) => {
154
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
155
            }
156
            Expression::Constant(_, Constant::Bool(_)) => Some(Domain::BoolDomain),
157
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
158
            Expression::Min(_, exprs) => {
159
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
160
            }
161
            Expression::UnsafeDiv(_, a, b) | Expression::SafeDiv(_, a, b) => {
162
                a.domain_of(vars)?.apply_i32(
163
                    |x, y| if y != 0 { Some(x / y) } else { None },
164
                    &b.domain_of(vars)?,
165
                )
166
            }
167
            _ => todo!("Calculate domain of {:?}", self),
168
            // TODO: (flm8) Add support for calculating the domains of more expression types
169
        };
170
        match ret {
171
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
172
            // Once they support a full domain as we define it, we can remove this conversion
173
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
174
                let (min, max) = range_vec_bounds_i32(&ranges);
175
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
176
            }
177
            _ => ret,
178
        }
179
    }
180

            
181
    pub fn can_be_undefined(&self) -> bool {
182
        // TODO: there will be more false cases but we are being conservative
183
        match self {
184
            Expression::Reference(_, _) => false,
185
            Expression::Constant(_, Constant::Bool(_)) => false,
186
            Expression::Constant(_, Constant::Int(_)) => false,
187
            _ => true,
188
        }
189
    }
190

            
191
    pub fn return_type(&self) -> Option<ReturnType> {
192
        match self {
193
            Expression::Constant(_, Constant::Int(_)) => Some(ReturnType::Int),
194
            Expression::Constant(_, Constant::Bool(_)) => Some(ReturnType::Bool),
195
            Expression::Reference(_, _) => None,
196
            Expression::Sum(_, _) => Some(ReturnType::Int),
197
            Expression::Min(_, _) => Some(ReturnType::Int),
198
            Expression::Not(_, _) => Some(ReturnType::Bool),
199
            Expression::Or(_, _) => Some(ReturnType::Bool),
200
            Expression::And(_, _) => Some(ReturnType::Bool),
201
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
202
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
203
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
204
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
205
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
206
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
207
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
208
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
209
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
210
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
211
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
212
            Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
213
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
214
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
215
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
216
            Expression::Nothing => None,
217
        }
218
    }
219

            
220
    pub fn is_clean(&self) -> bool {
221
        match self {
222
            Expression::Nothing => true,
223
            Expression::Constant(metadata, _) => metadata.clean,
224
            Expression::Reference(metadata, _) => metadata.clean,
225
            Expression::Sum(metadata, exprs) => metadata.clean,
226
            Expression::Min(metadata, exprs) => metadata.clean,
227
            Expression::Not(metadata, expr) => metadata.clean,
228
            Expression::Or(metadata, exprs) => metadata.clean,
229
            Expression::And(metadata, exprs) => metadata.clean,
230
            Expression::Eq(metadata, box1, box2) => metadata.clean,
231
            Expression::Neq(metadata, box1, box2) => metadata.clean,
232
            Expression::Geq(metadata, box1, box2) => metadata.clean,
233
            Expression::Leq(metadata, box1, box2) => metadata.clean,
234
            Expression::Gt(metadata, box1, box2) => metadata.clean,
235
            Expression::Lt(metadata, box1, box2) => metadata.clean,
236
            Expression::SumGeq(metadata, box1, box2) => metadata.clean,
237
            Expression::SumLeq(metadata, box1, box2) => metadata.clean,
238
            Expression::Ineq(metadata, box1, box2, box3) => metadata.clean,
239
            Expression::AllDiff(metadata, exprs) => metadata.clean,
240
            Expression::SumEq(metadata, exprs, expr) => metadata.clean,
241
            _ => false,
242
        }
243
    }
244

            
245
    pub fn set_clean(&mut self, bool_value: bool) {
246
        match self {
247
            Expression::Nothing => {}
248
            Expression::Constant(metadata, _) => metadata.clean = bool_value,
249
            Expression::Reference(metadata, _) => metadata.clean = bool_value,
250
            Expression::Sum(metadata, _) => {
251
                metadata.clean = bool_value;
252
            }
253
            Expression::Min(metadata, _) => {
254
                metadata.clean = bool_value;
255
            }
256
            Expression::Not(metadata, _) => {
257
                metadata.clean = bool_value;
258
            }
259
            Expression::Or(metadata, _) => {
260
                metadata.clean = bool_value;
261
            }
262
            Expression::And(metadata, _) => {
263
                metadata.clean = bool_value;
264
            }
265
            Expression::Eq(metadata, box1, box2) => {
266
                metadata.clean = bool_value;
267
            }
268
            Expression::Neq(metadata, _box1, _box2) => {
269
                metadata.clean = bool_value;
270
            }
271
            Expression::Geq(metadata, _box1, _box2) => {
272
                metadata.clean = bool_value;
273
            }
274
            Expression::Leq(metadata, _box1, _box2) => {
275
                metadata.clean = bool_value;
276
            }
277
            Expression::Gt(metadata, _box1, _box2) => {
278
                metadata.clean = bool_value;
279
            }
280
            Expression::Lt(metadata, _box1, _box2) => {
281
                metadata.clean = bool_value;
282
            }
283
            Expression::SumGeq(metadata, _box1, _box2) => {
284
                metadata.clean = bool_value;
285
            }
286
            Expression::SumLeq(metadata, _box1, _box2) => {
287
                metadata.clean = bool_value;
288
            }
289
            Expression::Ineq(metadata, _box1, _box2, _box3) => {
290
                metadata.clean = bool_value;
291
            }
292
            Expression::AllDiff(metadata, _exprs) => {
293
                metadata.clean = bool_value;
294
            }
295
            Expression::SumEq(metadata, _exprs, _expr) => {
296
                metadata.clean = bool_value;
297
            }
298
            Expression::Bubble(metadata, box1, box2) => {
299
                metadata.clean = bool_value;
300
            }
301
            Expression::SafeDiv(metadata, box1, box2) => {
302
                metadata.clean = bool_value;
303
            }
304
            Expression::UnsafeDiv(metadata, box1, box2) => {
305
                metadata.clean = bool_value;
306
            }
307
            Expression::DivEq(metadata, box1, box2, box3) => {
308
                metadata.clean = bool_value;
309
            }
310
        }
311
    }
312
}
313

            
314
fn display_expressions(expressions: &[Expression]) -> String {
315
    // if expressions.len() <= 3 {
316
    format!(
317
        "[{}]",
318
        expressions
319
            .iter()
320
            .map(|e| e.to_string())
321
            .collect::<Vec<String>>()
322
            .join(", ")
323
    )
324
    // } else {
325
    //     format!(
326
    //         "[{}..{}]",
327
    //         expressions[0],
328
    //         expressions[expressions.len() - 1]
329
    //     )
330
    // }
331
}
332

            
333
impl From<i32> for Expression {
334
    fn from(i: i32) -> Self {
335
        Expression::Constant(Metadata::new(), Constant::Int(i))
336
    }
337
}
338

            
339
impl From<bool> for Expression {
340
    fn from(b: bool) -> Self {
341
        Expression::Constant(Metadata::new(), Constant::Bool(b))
342
    }
343
}
344

            
345
impl Display for Expression {
346
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
347
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
348
        match &self {
349
            Expression::Constant(_, c) => match c {
350
                Constant::Bool(b) => write!(f, "{}", b),
351
                Constant::Int(i) => write!(f, "{}", i),
352
            },
353
            Expression::Reference(_, name) => match name {
354
                Name::MachineName(n) => write!(f, "_{}", n),
355
                Name::UserName(s) => write!(f, "{}", s),
356
            },
357
            Expression::Nothing => write!(f, "Nothing"),
358
            Expression::Sum(_, expressions) => {
359
                write!(f, "Sum({})", display_expressions(expressions))
360
            }
361
            Expression::Min(_, expressions) => {
362
                write!(f, "Min({})", display_expressions(expressions))
363
            }
364
            Expression::Not(_, expr_box) => {
365
                write!(f, "Not({})", expr_box.clone())
366
            }
367
            Expression::Or(_, expressions) => {
368
                write!(f, "Or({})", display_expressions(expressions))
369
            }
370
            Expression::And(_, expressions) => {
371
                write!(f, "And({})", display_expressions(expressions))
372
            }
373
            Expression::Eq(_, box1, box2) => {
374
                write!(f, "({} = {})", box1.clone(), box2.clone())
375
            }
376
            Expression::Neq(_, box1, box2) => {
377
                write!(f, "({} != {})", box1.clone(), box2.clone())
378
            }
379
            Expression::Geq(_, box1, box2) => {
380
                write!(f, "({} >= {})", box1.clone(), box2.clone())
381
            }
382
            Expression::Leq(_, box1, box2) => {
383
                write!(f, "({} <= {})", box1.clone(), box2.clone())
384
            }
385
            Expression::Gt(_, box1, box2) => {
386
                write!(f, "({} > {})", box1.clone(), box2.clone())
387
            }
388
            Expression::Lt(_, box1, box2) => {
389
                write!(f, "({} < {})", box1.clone(), box2.clone())
390
            }
391
            Expression::SumEq(_, expressions, expr_box) => {
392
                write!(
393
                    f,
394
                    "SumEq({}, {})",
395
                    display_expressions(expressions),
396
                    expr_box.clone()
397
                )
398
            }
399
            Expression::SumGeq(_, box1, box2) => {
400
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
401
            }
402
            Expression::SumLeq(_, box1, box2) => {
403
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
404
            }
405
            Expression::Ineq(_, box1, box2, box3) => write!(
406
                f,
407
                "Ineq({}, {}, {})",
408
                box1.clone(),
409
                box2.clone(),
410
                box3.clone()
411
            ),
412
            Expression::AllDiff(_, expressions) => {
413
                write!(f, "AllDiff({})", display_expressions(expressions))
414
            }
415
            Expression::Bubble(_, box1, box2) => {
416
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
417
            }
418
            Expression::SafeDiv(_, box1, box2) => {
419
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
420
            }
421
            Expression::UnsafeDiv(_, box1, box2) => {
422
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
423
            }
424
            Expression::DivEq(_, box1, box2, box3) => {
425
                write!(
426
                    f,
427
                    "DivEq({}, {}, {})",
428
                    box1.clone(),
429
                    box2.clone(),
430
                    box3.clone()
431
                )
432
            }
433
            #[allow(unreachable_patterns)]
434
            other => todo!("Implement display for {:?}", other),
435
        }
436
    }
437
}
438

            
439
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Integer constant expression with fresh metadata.
    fn int_const(value: i32) -> Expression {
        Expression::Constant(Metadata::new(), Constant::Int(value))
    }

    /// Reference to the machine-named variable `_0`.
    fn var0() -> Expression {
        Expression::Reference(Metadata::new(), Name::MachineName(0))
    }

    /// Sum over the given operands, with fresh metadata.
    fn sum_of(operands: Vec<Expression>) -> Expression {
        Expression::Sum(Metadata::new(), operands)
    }

    /// Symbol table that declares only `_0`, with the given domain.
    fn table_with_var0(domain: Domain) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(Name::MachineName(0), DecisionVariable::new(domain));
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let expr = sum_of(vec![int_const(1), int_const(2)]);
        let expected = Domain::IntDomain(vec![Range::Single(3)]);
        assert_eq!(expr.domain_of(&SymbolTable::new()), Some(expected));
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let boolean = Expression::Constant(Metadata::new(), Constant::Bool(true));
        let expr = sum_of(vec![int_const(1), boolean]);
        assert_eq!(expr.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let expr = sum_of(Vec::new());
        assert_eq!(expr.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let table = table_with_var0(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Domain::IntDomain(vec![Range::Single(1)]);
        assert_eq!(var0().domain_of(&table), Some(expected));
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(var0().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let table = table_with_var0(Domain::IntDomain(vec![Range::Single(1)]));
        let expr = sum_of(vec![var0(), var0()]);
        let expected = Domain::IntDomain(vec![Range::Single(2)]);
        assert_eq!(expr.domain_of(&table), Some(expected));
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let table = table_with_var0(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let expr = sum_of(vec![var0(), var0()]);
        let expected = Domain::IntDomain(vec![Range::Bounded(2, 4)]);
        assert_eq!(expr.domain_of(&table), Some(expected));
    }
}