1
use std::fmt::{Display, Formatter};
2
use std::sync::Arc;
3

            
4
use serde::{Deserialize, Serialize};
5

            
6
use enum_compatability_macro::document_compatibility;
7
use uniplate::derive::Uniplate;
8
use uniplate::Biplate;
9

            
10
use crate::ast::literals::Literal;
11
use crate::ast::symbol_table::{Name, SymbolTable};
12
use crate::ast::Factor;
13
use crate::ast::ReturnType;
14
use crate::metadata::Metadata;
15

            
16
use super::{Domain, Range};
17

            
18
/// Represents different types of expressions used to define rules and constraints in the model.
19
///
20
/// The `Expression` enum includes operations, constants, and variable references
21
/// used to build rules and conditions for the model.
22
#[document_compatibility]
23
4116
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
24
#[uniplate(walk_into=[Factor])]
25
#[biplate(to=Literal)]
26
#[biplate(to=Metadata)]
27
#[biplate(to=Factor)]
28
#[biplate(to=Name)]
29
pub enum Expression {
30
    /// An expression representing "A is valid as long as B is true"
31
    /// Turns into a conjunction when it reaches a boolean context
32
    Bubble(Metadata, Box<Expression>, Box<Expression>),
33

            
34
    FactorE(Metadata, Factor),
35

            
36
    #[compatible(Minion, JsonInput)]
37
    Sum(Metadata, Vec<Expression>),
38

            
39
    // /// Division after preventing division by zero, usually with a top-level constraint
40
    // #[compatible(Minion)]
41
    // SafeDiv(Metadata, Box<Expression>, Box<Expression>),
42
    // /// Division with a possibly undefined value (division by 0)
43
    // #[compatible(Minion, JsonInput)]
44
    // Div(Metadata, Box<Expression>, Box<Expression>),
45
    #[compatible(JsonInput)]
46
    Min(Metadata, Vec<Expression>),
47

            
48
    #[compatible(JsonInput)]
49
    Max(Metadata, Vec<Expression>),
50

            
51
    #[compatible(JsonInput, SAT)]
52
    Not(Metadata, Box<Expression>),
53

            
54
    #[compatible(JsonInput, SAT)]
55
    Or(Metadata, Vec<Expression>),
56

            
57
    #[compatible(JsonInput, SAT)]
58
    And(Metadata, Vec<Expression>),
59

            
60
    #[compatible(JsonInput)]
61
    Eq(Metadata, Box<Expression>, Box<Expression>),
62

            
63
    #[compatible(JsonInput)]
64
    Neq(Metadata, Box<Expression>, Box<Expression>),
65

            
66
    #[compatible(JsonInput)]
67
    Geq(Metadata, Box<Expression>, Box<Expression>),
68

            
69
    #[compatible(JsonInput)]
70
    Leq(Metadata, Box<Expression>, Box<Expression>),
71

            
72
    #[compatible(JsonInput)]
73
    Gt(Metadata, Box<Expression>, Box<Expression>),
74

            
75
    #[compatible(JsonInput)]
76
    Lt(Metadata, Box<Expression>, Box<Expression>),
77

            
78
    /// Division after preventing division by zero, usually with a bubble
79
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
80

            
81
    /// Division with a possibly undefined value (division by 0)
82
    #[compatible(JsonInput)]
83
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
84

            
85
    /* Flattened SumEq.
86
     *
87
     * Note: this is an intermediary step that's used in the process of converting from conjure model to minion.
88
     * This is NOT a valid expression in either Essence or minion.
89
     *
90
     * ToDo: This is a stop gap solution. Eventually it may be better to have multiple constraints instead? (gs248)
91
     */
92
    SumEq(Metadata, Vec<Expression>, Box<Expression>),
93

            
94
    // Flattened Constraints
95
    #[compatible(Minion)]
96
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),
97

            
98
    #[compatible(Minion)]
99
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),
100

            
101
    #[compatible(Minion)]
102
    DivEq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
103

            
104
    #[compatible(Minion)]
105
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),
106

            
107
    #[compatible(Minion)]
108
    AllDiff(Metadata, Vec<Expression>),
109

            
110
    /// w-literal(x,k) is SAT iff x == k, where x is a variable and k a constant.
111
    ///
112
    /// This is a low-level Minion constraint and you should (probably) use Eq instead. The main
113
    /// use of w-literal is to convert boolean variables to constraints so that they can be used
114
    /// inside watched-and and watched-or.
115
    ///
116
    /// See `rules::minion::boolean_literal_to_wliteral`.
117
    ///
118
    ///
119
    #[compatible(Minion)]
120
    WatchedLiteral(Metadata, Name, Literal),
121

            
122
    #[compatible(Minion)]
123
    Reify(Metadata, Box<Expression>, Box<Expression>),
124

            
125
    /// Declaration of an auxiliary variable.
126
    ///
127
    /// As with Savile Row, we semantically distinguish this from `Eq`.
128
    #[compatible(Minion)]
129
    AuxDeclaration(Metadata, Name, Box<Expression>),
130
}
131

            
132
192
fn expr_vec_to_domain_i32(
133
192
    exprs: &[Expression],
134
192
    op: fn(i32, i32) -> Option<i32>,
135
192
    vars: &SymbolTable,
136
192
) -> Option<Domain> {
137
382
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
138
192
    domains
139
192
        .into_iter()
140
192
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
141
192
        .flatten()
142
192
}
143

            
144
273
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
145
273
    let mut min = i32::MAX;
146
273
    let mut max = i32::MIN;
147
4289
    for r in ranges {
148
4016
        match r {
149
4016
            Range::Single(i) => {
150
4016
                if *i < min {
151
358
                    min = *i;
152
3658
                }
153
4016
                if *i > max {
154
1023
                    max = *i;
155
2995
                }
156
            }
157
            Range::Bounded(i, j) => {
158
                if *i < min {
159
                    min = *i;
160
                }
161
                if *j > max {
162
                    max = *j;
163
                }
164
            }
165
        }
166
    }
167
273
    (min, max)
168
273
}
169

            
170
impl Expression {
171
    /// Returns the possible values of the expression, recursing to leaf expressions
172
882
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
173
881
        let ret = match self {
174
516
            Expression::FactorE(_, Factor::Reference(name)) => Some(vars.get(name)?.domain.clone()),
175
71
            Expression::FactorE(_, Factor::Literal(Literal::Int(n))) => {
176
71
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
177
            }
178
1
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
179
6
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
180
119
            Expression::Min(_, exprs) => {
181
1700
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
182
            }
183
68
            Expression::Max(_, exprs) => {
184
884
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), vars)
185
            }
186
102
            Expression::UnsafeDiv(_, a, b) | Expression::SafeDiv(_, a, b) => {
187
102
                a.domain_of(vars)?.apply_i32(
188
1751
                    |x, y| if y != 0 { Some(x / y) } else { None },
189
102
                    &b.domain_of(vars)?,
190
                )
191
            }
192
            _ => todo!("Calculate domain of {:?}", self),
193
            // TODO: (flm8) Add support for calculating the domains of more expression types
194
        };
195
878
        match ret {
196
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
197
            // Once they support a full domain as we define it, we can remove this conversion
198
878
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
199
273
                let (min, max) = range_vec_bounds_i32(&ranges);
200
273
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
201
            }
202
608
            _ => ret,
203
        }
204
882
    }
205

            
206
4403
    pub fn get_meta(&self) -> Metadata {
207
4403
        <Expression as Biplate<Metadata>>::children_bi(self)[0].clone()
208
4403
    }
209

            
210
    pub fn set_meta(&self, meta: Metadata) {
211
        <Expression as Biplate<Metadata>>::transform_bi(self, Arc::new(move |_| meta.clone()));
212
    }
213

            
214
34
    pub fn can_be_undefined(&self) -> bool {
215
34
        // TODO: there will be more false cases but we are being conservative
216
34
        match self {
217
17
            Expression::FactorE(_, _) => false,
218
17
            _ => true,
219
        }
220
34
    }
221

            
222
221
    pub fn return_type(&self) -> Option<ReturnType> {
223
        match self {
224
            Expression::FactorE(_, Factor::Literal(Literal::Int(_))) => Some(ReturnType::Int),
225
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
226
            Expression::FactorE(_, Factor::Reference(_)) => None,
227
            Expression::Sum(_, _) => Some(ReturnType::Int),
228
            Expression::Min(_, _) => Some(ReturnType::Int),
229
            Expression::Max(_, _) => Some(ReturnType::Int),
230
            Expression::Not(_, _) => Some(ReturnType::Bool),
231
            Expression::Or(_, _) => Some(ReturnType::Bool),
232
            Expression::And(_, _) => Some(ReturnType::Bool),
233
102
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
234
17
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
235
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
236
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
237
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
238
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
239
102
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
240
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
241
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
242
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
243
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
244
            Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
245
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
246
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
247
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
248
            Expression::WatchedLiteral(_, _, _) => Some(ReturnType::Bool),
249
            Expression::Reify(_, _, _) => Some(ReturnType::Bool),
250
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
251
        }
252
221
    }
253

            
254
    pub fn is_clean(&self) -> bool {
255
        let metadata = self.get_meta();
256
        metadata.clean
257
    }
258

            
259
    pub fn set_clean(&mut self, bool_value: bool) {
260
        let mut metadata = self.get_meta();
261
        metadata.clean = bool_value;
262
        self.set_meta(metadata);
263
    }
264

            
265
    pub fn as_factor(&self) -> Option<Factor> {
266
        if let Expression::FactorE(_m, f) = self {
267
            Some(f.clone())
268
        } else {
269
            None
270
        }
271
    }
272
}
273

            
274
fn display_expressions(expressions: &[Expression]) -> String {
275
    // if expressions.len() <= 3 {
276
    format!(
277
        "[{}]",
278
        expressions
279
            .iter()
280
            .map(|e| e.to_string())
281
            .collect::<Vec<String>>()
282
            .join(", ")
283
    )
284
    // } else {
285
    //     format!(
286
    //         "[{}..{}]",
287
    //         expressions[0],
288
    //         expressions[expressions.len() - 1]
289
    //     )
290
    // }
291
}
292

            
293
impl From<i32> for Expression {
294
102
    fn from(i: i32) -> Self {
295
102
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(i)))
296
102
    }
297
}
298

            
299
impl From<bool> for Expression {
300
    fn from(b: bool) -> Self {
301
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(b)))
302
    }
303
}
304

            
305
impl From<Factor> for Expression {
306
    fn from(value: Factor) -> Self {
307
        Expression::FactorE(Metadata::new(), value)
308
    }
309
}
310
impl Display for Expression {
311
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
312
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
313
        match &self {
314
            Expression::FactorE(_, factor) => factor.fmt(f),
315
            Expression::Sum(_, expressions) => {
316
                write!(f, "Sum({})", display_expressions(expressions))
317
            }
318
            Expression::Min(_, expressions) => {
319
                write!(f, "Min({})", display_expressions(expressions))
320
            }
321
            Expression::Max(_, expressions) => {
322
                write!(f, "Max({})", display_expressions(expressions))
323
            }
324
            Expression::Not(_, expr_box) => {
325
                write!(f, "Not({})", expr_box.clone())
326
            }
327
            Expression::Or(_, expressions) => {
328
                write!(f, "Or({})", display_expressions(expressions))
329
            }
330
            Expression::And(_, expressions) => {
331
                write!(f, "And({})", display_expressions(expressions))
332
            }
333
            Expression::Eq(_, box1, box2) => {
334
                write!(f, "({} = {})", box1.clone(), box2.clone())
335
            }
336
            Expression::Neq(_, box1, box2) => {
337
                write!(f, "({} != {})", box1.clone(), box2.clone())
338
            }
339
            Expression::Geq(_, box1, box2) => {
340
                write!(f, "({} >= {})", box1.clone(), box2.clone())
341
            }
342
            Expression::Leq(_, box1, box2) => {
343
                write!(f, "({} <= {})", box1.clone(), box2.clone())
344
            }
345
            Expression::Gt(_, box1, box2) => {
346
                write!(f, "({} > {})", box1.clone(), box2.clone())
347
            }
348
            Expression::Lt(_, box1, box2) => {
349
                write!(f, "({} < {})", box1.clone(), box2.clone())
350
            }
351
            Expression::SumEq(_, expressions, expr_box) => {
352
                write!(
353
                    f,
354
                    "SumEq({}, {})",
355
                    display_expressions(expressions),
356
                    expr_box.clone()
357
                )
358
            }
359
            Expression::SumGeq(_, box1, box2) => {
360
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
361
            }
362
            Expression::SumLeq(_, box1, box2) => {
363
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
364
            }
365
            Expression::Ineq(_, box1, box2, box3) => write!(
366
                f,
367
                "Ineq({}, {}, {})",
368
                box1.clone(),
369
                box2.clone(),
370
                box3.clone()
371
            ),
372
            Expression::AllDiff(_, expressions) => {
373
                write!(f, "AllDiff({})", display_expressions(expressions))
374
            }
375
            Expression::Bubble(_, box1, box2) => {
376
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
377
            }
378
            Expression::SafeDiv(_, box1, box2) => {
379
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
380
            }
381
            Expression::UnsafeDiv(_, box1, box2) => {
382
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
383
            }
384
            Expression::DivEq(_, box1, box2, box3) => {
385
                write!(
386
                    f,
387
                    "DivEq({}, {}, {})",
388
                    box1.clone(),
389
                    box2.clone(),
390
                    box3.clone()
391
                )
392
            }
393
            Expression::WatchedLiteral(_, x, l) => {
394
                write!(f, "WatchedLiteral({},{})", x, l)
395
            }
396
            Expression::Reify(_, box1, box2) => {
397
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
398
            }
399
            Expression::AuxDeclaration(_, n, e) => {
400
                write!(f, "{} =aux {}", n, e.clone())
401
            }
402
        }
403
    }
404
}
405

            
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Builds an integer literal expression with fresh metadata.
    fn int_literal(value: i32) -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(value)))
    }

    /// Builds a boolean literal expression with fresh metadata.
    fn bool_literal(value: bool) -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(value)))
    }

    /// Builds a reference to the variable `MachineName(0)`.
    fn first_machine_var() -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Reference(Name::MachineName(0)))
    }

    /// Builds a symbol table that declares only `MachineName(0)`, with the given domain.
    fn table_with_first_machine_var(domain: Domain) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(Name::MachineName(0), DecisionVariable::new(domain));
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![int_literal(1), int_literal(2)]);
        let expected = Domain::IntDomain(vec![Range::Single(3)]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), Some(expected));
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let sum = Expression::Sum(Metadata::new(), vec![int_literal(1), bool_literal(true)]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Vec::new());
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let vars = table_with_first_machine_var(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Domain::IntDomain(vec![Range::Single(1)]);
        assert_eq!(first_machine_var().domain_of(&vars), Some(expected));
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(first_machine_var().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let vars = table_with_first_machine_var(Domain::IntDomain(vec![Range::Single(1)]));
        let sum = Expression::Sum(Metadata::new(), vec![first_machine_var(), first_machine_var()]);
        let expected = Domain::IntDomain(vec![Range::Single(2)]);
        assert_eq!(sum.domain_of(&vars), Some(expected));
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let vars = table_with_first_machine_var(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let sum = Expression::Sum(Metadata::new(), vec![first_machine_var(), first_machine_var()]);
        let expected = Domain::IntDomain(vec![Range::Bounded(2, 4)]);
        assert_eq!(sum.domain_of(&vars), Some(expected));
    }
}