1
use std::fmt::{Display, Formatter};
2
use std::sync::Arc;
3

            
4
use serde::{Deserialize, Serialize};
5

            
6
use enum_compatability_macro::document_compatibility;
7
use uniplate::derive::Uniplate;
8
use uniplate::Biplate;
9

            
10
use crate::ast::literals::Literal;
11
use crate::ast::symbol_table::{Name, SymbolTable};
12
use crate::ast::Factor;
13
use crate::ast::ReturnType;
14
use crate::metadata::Metadata;
15

            
16
use super::{Domain, Range};
17

            
18
/// Represents different types of expressions used to define rules and constraints in the model.
19
///
20
/// The `Expression` enum includes operations, constants, and variable references
21
/// used to build rules and conditions for the model.
22
#[document_compatibility]
23
2940
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
24
#[uniplate(walk_into=[Factor])]
25
#[biplate(to=Literal)]
26
#[biplate(to=Metadata)]
27
#[biplate(to=Factor)]
28
#[biplate(to=Name)]
29
pub enum Expression {
30
    /// An expression representing "A is valid as long as B is true"
31
    /// Turns into a conjunction when it reaches a boolean context
32
    Bubble(Metadata, Box<Expression>, Box<Expression>),
33

            
34
    FactorE(Metadata, Factor),
35

            
36
    #[compatible(Minion, JsonInput)]
37
    Sum(Metadata, Vec<Expression>),
38

            
39
    // /// Division after preventing division by zero, usually with a top-level constraint
40
    // #[compatible(Minion)]
41
    // SafeDiv(Metadata, Box<Expression>, Box<Expression>),
42
    // /// Division with a possibly undefined value (division by 0)
43
    // #[compatible(Minion, JsonInput)]
44
    // Div(Metadata, Box<Expression>, Box<Expression>),
45
    #[compatible(JsonInput)]
46
    Min(Metadata, Vec<Expression>),
47

            
48
    #[compatible(JsonInput)]
49
    Max(Metadata, Vec<Expression>),
50

            
51
    #[compatible(JsonInput, SAT)]
52
    Not(Metadata, Box<Expression>),
53

            
54
    #[compatible(JsonInput, SAT)]
55
    Or(Metadata, Vec<Expression>),
56

            
57
    #[compatible(JsonInput, SAT)]
58
    And(Metadata, Vec<Expression>),
59

            
60
    #[compatible(JsonInput)]
61
    Eq(Metadata, Box<Expression>, Box<Expression>),
62

            
63
    #[compatible(JsonInput)]
64
    Neq(Metadata, Box<Expression>, Box<Expression>),
65

            
66
    #[compatible(JsonInput)]
67
    Geq(Metadata, Box<Expression>, Box<Expression>),
68

            
69
    #[compatible(JsonInput)]
70
    Leq(Metadata, Box<Expression>, Box<Expression>),
71

            
72
    #[compatible(JsonInput)]
73
    Gt(Metadata, Box<Expression>, Box<Expression>),
74

            
75
    #[compatible(JsonInput)]
76
    Lt(Metadata, Box<Expression>, Box<Expression>),
77

            
78
    /// Division after preventing division by zero, usually with a bubble
79
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
80

            
81
    /// Division with a possibly undefined value (division by 0)
82
    #[compatible(JsonInput)]
83
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
84

            
85
    /* Flattened SumEq.
86
     *
87
     * Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
88
     * This is NOT a valid expression in either Essence or minion.
89
     *
90
     * ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
91
     */
92
    SumEq(Metadata, Vec<Expression>, Box<Expression>),
93

            
94
    // Flattened Constraints
95
    #[compatible(Minion)]
96
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),
97

            
98
    #[compatible(Minion)]
99
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),
100

            
101
    #[compatible(Minion)]
102
    DivEq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
103

            
104
    #[compatible(Minion)]
105
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
106

            
107
    #[compatible(Minion)]
108
    AllDiff(Metadata, Vec<Expression>),
109

            
110
    /// w-literal(x,k) is SAT iff x == k, where x is a variable and k a constant.
111
    ///
112
    /// This is a low-level Minion constraint and you should (probably) use Eq instead. The main
113
    /// use of w-literal is to convert boolean variables to constraints so that they can be used
114
    /// inside watched-and and watched-or.
115
    ///
116
    /// See `rules::minion::boolean_literal_to_wliteral`.
117
    ///
118
    ///
119
    #[compatible(Minion)]
120
    WatchedLiteral(Metadata, Name, Literal),
121

            
122
    #[compatible(Minion)]
123
    Reify(Metadata, Box<Expression>, Box<Expression>),
124
}
125

            
126
170
fn expr_vec_to_domain_i32(
127
170
    exprs: &[Expression],
128
170
    op: fn(i32, i32) -> Option<i32>,
129
170
    vars: &SymbolTable,
130
170
) -> Option<Domain> {
131
338
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
132
170
    domains
133
170
        .into_iter()
134
170
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
135
170
        .flatten()
136
170
}
137

            
138
241
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
139
241
    let mut min = i32::MAX;
140
241
    let mut max = i32::MIN;
141
3785
    for r in ranges {
142
3544
        match r {
143
3544
            Range::Single(i) => {
144
3544
                if *i < min {
145
316
                    min = *i;
146
3228
                }
147
3544
                if *i > max {
148
903
                    max = *i;
149
2643
                }
150
            }
151
            Range::Bounded(i, j) => {
152
                if *i < min {
153
                    min = *i;
154
                }
155
                if *j > max {
156
                    max = *j;
157
                }
158
            }
159
        }
160
    }
161
241
    (min, max)
162
241
}
163

            
164
impl Expression {
165
    /// Returns the possible values of the expression, recursing to leaf expressions
166
780
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
167
779
        let ret = match self {
168
456
            Expression::FactorE(_, Factor::Reference(name)) => Some(vars.get(name)?.domain.clone()),
169
63
            Expression::FactorE(_, Factor::Literal(Literal::Int(n))) => {
170
63
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
171
            }
172
1
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
173
6
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
174
105
            Expression::Min(_, exprs) => {
175
1500
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
176
            }
177
60
            Expression::Max(_, exprs) => {
178
780
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), vars)
179
            }
180
90
            Expression::UnsafeDiv(_, a, b) | Expression::SafeDiv(_, a, b) => {
181
90
                a.domain_of(vars)?.apply_i32(
182
1545
                    |x, y| if y != 0 { Some(x / y) } else { None },
183
90
                    &b.domain_of(vars)?,
184
                )
185
            }
186
            _ => todo!("Calculate domain of {:?}", self),
187
            // TODO: (flm8) Add support for calculating the domains of more expression types
188
        };
189
776
        match ret {
190
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
191
            // Once they support a full domain as we define it, we can remove this conversion
192
776
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
193
241
                let (min, max) = range_vec_bounds_i32(&ranges);
194
241
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
195
            }
196
538
            _ => ret,
197
        }
198
780
    }
199

            
200
3885
    pub fn get_meta(&self) -> Metadata {
201
3885
        <Expression as Biplate<Metadata>>::children_bi(self)[0].clone()
202
3885
    }
203

            
204
    pub fn set_meta(&self, meta: Metadata) {
205
        <Expression as Biplate<Metadata>>::transform_bi(self, Arc::new(move |_| meta.clone()));
206
    }
207

            
208
30
    pub fn can_be_undefined(&self) -> bool {
209
30
        // TODO: there will be more false cases but we are being conservative
210
30
        match self {
211
15
            Expression::FactorE(_, _) => false,
212
15
            _ => true,
213
        }
214
30
    }
215

            
216
195
    pub fn return_type(&self) -> Option<ReturnType> {
217
        match self {
218
            Expression::FactorE(_, Factor::Literal(Literal::Int(_))) => Some(ReturnType::Int),
219
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
220
            Expression::FactorE(_, Factor::Reference(_)) => None,
221
            Expression::Sum(_, _) => Some(ReturnType::Int),
222
            Expression::Min(_, _) => Some(ReturnType::Int),
223
            Expression::Max(_, _) => Some(ReturnType::Int),
224
            Expression::Not(_, _) => Some(ReturnType::Bool),
225
            Expression::Or(_, _) => Some(ReturnType::Bool),
226
            Expression::And(_, _) => Some(ReturnType::Bool),
227
90
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
228
15
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
229
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
230
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
231
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
232
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
233
90
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
234
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
235
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
236
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
237
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
238
            Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
239
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
240
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
241
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
242
            Expression::WatchedLiteral(_, _, _) => Some(ReturnType::Bool),
243
            Expression::Reify(_, _, _) => Some(ReturnType::Bool),
244
        }
245
195
    }
246

            
247
    pub fn is_clean(&self) -> bool {
248
        let metadata = self.get_meta();
249
        metadata.clean
250
    }
251

            
252
    pub fn set_clean(&mut self, bool_value: bool) {
253
        let mut metadata = self.get_meta();
254
        metadata.clean = bool_value;
255
        self.set_meta(metadata);
256
    }
257

            
258
    pub fn as_factor(&self) -> Option<Factor> {
259
        if let Expression::FactorE(_m, f) = self {
260
            Some(f.clone())
261
        } else {
262
            None
263
        }
264
    }
265
}
266

            
267
fn display_expressions(expressions: &[Expression]) -> String {
268
    // if expressions.len() <= 3 {
269
    format!(
270
        "[{}]",
271
        expressions
272
            .iter()
273
            .map(|e| e.to_string())
274
            .collect::<Vec<String>>()
275
            .join(", ")
276
    )
277
    // } else {
278
    //     format!(
279
    //         "[{}..{}]",
280
    //         expressions[0],
281
    //         expressions[expressions.len() - 1]
282
    //     )
283
    // }
284
}
285

            
286
impl From<i32> for Expression {
287
90
    fn from(i: i32) -> Self {
288
90
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(i)))
289
90
    }
290
}
291

            
292
impl From<bool> for Expression {
293
    fn from(b: bool) -> Self {
294
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(b)))
295
    }
296
}
297

            
298
impl From<Factor> for Expression {
299
    fn from(value: Factor) -> Self {
300
        Expression::FactorE(Metadata::new(), value)
301
    }
302
}
303
impl Display for Expression {
304
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
305
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
306
        match &self {
307
            Expression::FactorE(_, Factor::Literal(c)) => match c {
308
                Literal::Bool(b) => write!(f, "{}", b),
309
                Literal::Int(i) => write!(f, "{}", i),
310
            },
311
            Expression::FactorE(_, Factor::Reference(name)) => match name {
312
                Name::MachineName(n) => write!(f, "_{}", n),
313
                Name::UserName(s) => write!(f, "{}", s),
314
            },
315
            Expression::Sum(_, expressions) => {
316
                write!(f, "Sum({})", display_expressions(expressions))
317
            }
318
            Expression::Min(_, expressions) => {
319
                write!(f, "Min({})", display_expressions(expressions))
320
            }
321
            Expression::Max(_, expressions) => {
322
                write!(f, "Max({})", display_expressions(expressions))
323
            }
324
            Expression::Not(_, expr_box) => {
325
                write!(f, "Not({})", expr_box.clone())
326
            }
327
            Expression::Or(_, expressions) => {
328
                write!(f, "Or({})", display_expressions(expressions))
329
            }
330
            Expression::And(_, expressions) => {
331
                write!(f, "And({})", display_expressions(expressions))
332
            }
333
            Expression::Eq(_, box1, box2) => {
334
                write!(f, "({} = {})", box1.clone(), box2.clone())
335
            }
336
            Expression::Neq(_, box1, box2) => {
337
                write!(f, "({} != {})", box1.clone(), box2.clone())
338
            }
339
            Expression::Geq(_, box1, box2) => {
340
                write!(f, "({} >= {})", box1.clone(), box2.clone())
341
            }
342
            Expression::Leq(_, box1, box2) => {
343
                write!(f, "({} <= {})", box1.clone(), box2.clone())
344
            }
345
            Expression::Gt(_, box1, box2) => {
346
                write!(f, "({} > {})", box1.clone(), box2.clone())
347
            }
348
            Expression::Lt(_, box1, box2) => {
349
                write!(f, "({} < {})", box1.clone(), box2.clone())
350
            }
351
            Expression::SumEq(_, expressions, expr_box) => {
352
                write!(
353
                    f,
354
                    "SumEq({}, {})",
355
                    display_expressions(expressions),
356
                    expr_box.clone()
357
                )
358
            }
359
            Expression::SumGeq(_, box1, box2) => {
360
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
361
            }
362
            Expression::SumLeq(_, box1, box2) => {
363
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
364
            }
365
            Expression::Ineq(_, box1, box2, box3) => write!(
366
                f,
367
                "Ineq({}, {}, {})",
368
                box1.clone(),
369
                box2.clone(),
370
                box3.clone()
371
            ),
372
            Expression::AllDiff(_, expressions) => {
373
                write!(f, "AllDiff({})", display_expressions(expressions))
374
            }
375
            Expression::Bubble(_, box1, box2) => {
376
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
377
            }
378
            Expression::SafeDiv(_, box1, box2) => {
379
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
380
            }
381
            Expression::UnsafeDiv(_, box1, box2) => {
382
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
383
            }
384
            Expression::DivEq(_, box1, box2, box3) => {
385
                write!(
386
                    f,
387
                    "DivEq({}, {}, {})",
388
                    box1.clone(),
389
                    box2.clone(),
390
                    box3.clone()
391
                )
392
            }
393
            #[allow(unreachable_patterns)]
394
            other => todo!("Implement display for {:?}", other),
395
        }
396
    }
397
}
398

            
399
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// An integer literal expression with fresh metadata.
    fn int_expr(value: i32) -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(value)))
    }

    /// A reference to the machine-named variable `_0`.
    fn ref_to_var_0() -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Reference(Name::MachineName(0)))
    }

    /// A symbol table in which `_0` is a decision variable over `domain`.
    fn table_with_var_0(domain: Domain) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(Name::MachineName(0), DecisionVariable::new(domain));
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![int_expr(1), int_expr(2)]);
        let expected = Domain::IntDomain(vec![Range::Single(3)]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), Some(expected));
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let bool_expr = Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), vec![int_expr(1), bool_expr]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Vec::new());
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let vars = table_with_var_0(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Domain::IntDomain(vec![Range::Single(1)]);
        assert_eq!(ref_to_var_0().domain_of(&vars), Some(expected));
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(ref_to_var_0().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let vars = table_with_var_0(Domain::IntDomain(vec![Range::Single(1)]));
        let sum = Expression::Sum(Metadata::new(), vec![ref_to_var_0(), ref_to_var_0()]);
        let expected = Domain::IntDomain(vec![Range::Single(2)]);
        assert_eq!(sum.domain_of(&vars), Some(expected));
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let vars = table_with_var_0(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let sum = Expression::Sum(Metadata::new(), vec![ref_to_var_0(), ref_to_var_0()]);
        let expected = Domain::IntDomain(vec![Range::Bounded(2, 4)]);
        assert_eq!(sum.domain_of(&vars), Some(expected));
    }
}