1
use std::fmt::{Display, Formatter};
2
use std::sync::Arc;
3

            
4
use serde::{Deserialize, Serialize};
5

            
6
use enum_compatability_macro::document_compatibility;
7
use uniplate::derive::Uniplate;
8
use uniplate::Biplate;
9

            
10
use crate::ast::literals::Literal;
11
use crate::ast::symbol_table::{Name, SymbolTable};
12
use crate::ast::Factor;
13
use crate::ast::ReturnType;
14
use crate::metadata::Metadata;
15

            
16
use super::{Domain, Range};
17

            
18
/// Represents different types of expressions used to define rules and constraints in the model.
///
/// The `Expression` enum includes operations, constants, and variable references
/// used to build rules and conditions for the model.
///
/// Every variant stores a [`Metadata`] as its first field.
///
/// The per-variant `#[compatible(...)]` attributes are presumably consumed by
/// `#[document_compatibility]` to record which backends / input formats accept each
/// variant — confirm against `enum_compatability_macro`.
#[document_compatibility]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[uniplate(walk_into=[Factor])]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Factor)]
#[biplate(to=Name)]
pub enum Expression {
    /// An expression representing "A is valid as long as B is true".
    ///
    /// Turns into a conjunction when it reaches a boolean context.
    Bubble(Metadata, Box<Expression>, Box<Expression>),

    /// A leaf expression holding a single [`Factor`]: either a literal or a reference to a
    /// named variable.
    FactorE(Metadata, Factor),

    /// The sum of the given expressions.
    #[compatible(Minion, JsonInput)]
    Sum(Metadata, Vec<Expression>),

    /// The minimum of the given expressions.
    #[compatible(JsonInput)]
    Min(Metadata, Vec<Expression>),

    /// The maximum of the given expressions.
    #[compatible(JsonInput)]
    Max(Metadata, Vec<Expression>),

    /// Logical negation of the inner expression.
    #[compatible(JsonInput, SAT)]
    Not(Metadata, Box<Expression>),

    /// Disjunction of the given expressions.
    #[compatible(JsonInput, SAT)]
    Or(Metadata, Vec<Expression>),

    /// Conjunction of the given expressions.
    #[compatible(JsonInput, SAT)]
    And(Metadata, Vec<Expression>),

    /// `a = b`
    #[compatible(JsonInput)]
    Eq(Metadata, Box<Expression>, Box<Expression>),

    /// `a != b`
    #[compatible(JsonInput)]
    Neq(Metadata, Box<Expression>, Box<Expression>),

    /// `a >= b`
    #[compatible(JsonInput)]
    Geq(Metadata, Box<Expression>, Box<Expression>),

    /// `a <= b`
    #[compatible(JsonInput)]
    Leq(Metadata, Box<Expression>, Box<Expression>),

    /// `a > b`
    #[compatible(JsonInput)]
    Gt(Metadata, Box<Expression>, Box<Expression>),

    /// `a < b`
    #[compatible(JsonInput)]
    Lt(Metadata, Box<Expression>, Box<Expression>),

    /// Division after preventing division by zero, usually with a bubble
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Division with a possibly undefined value (division by 0)
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Flattened SumEq.
    ///
    /// Note: this is an intermediary step that's used in the process of converting from a
    /// Conjure model to Minion. This is NOT a valid expression in either Essence or Minion.
    ///
    /// ToDo: This is a stop gap solution. Eventually it may be better to have multiple
    /// constraints instead? (gs248)
    SumEq(Metadata, Vec<Expression>, Box<Expression>),

    // Flattened constraints (low-level; marked as Minion-compatible below).
    /// Flattened `sum(exprs) >= rhs` — presumably Minion's `sumgeq`; confirm in the Minion
    /// translation.
    #[compatible(Minion)]
    SumGeq(Metadata, Vec<Expression>, Box<Expression>),

    /// Flattened `sum(exprs) <= rhs` — presumably Minion's `sumleq`; confirm in the Minion
    /// translation.
    #[compatible(Minion)]
    SumLeq(Metadata, Vec<Expression>, Box<Expression>),

    /// Flattened three-operand division constraint — presumably Minion's `div(x, y, z)`;
    /// confirm operand order in the Minion translation.
    #[compatible(Minion)]
    DivEq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),

    /// Flattened three-operand inequality — presumably Minion's `ineq(x, y, k)`
    /// (`x <= y + k`); confirm operand order in the Minion translation.
    #[compatible(Minion)]
    Ineq(Metadata, Box<Expression>, Box<Expression>, Box<Expression>),

    /// The given expressions take pairwise distinct values — presumably Minion's `alldiff`.
    #[compatible(Minion)]
    AllDiff(Metadata, Vec<Expression>),

    /// w-literal(x,k) is SAT iff x == k, where x is a variable and k a constant.
    ///
    /// This is a low-level Minion constraint and you should (probably) use Eq instead. The main
    /// use of w-literal is to convert boolean variables to constraints so that they can be used
    /// inside watched-and and watched-or.
    ///
    /// See `rules::minion::boolean_literal_to_wliteral`.
    #[compatible(Minion)]
    WatchedLiteral(Metadata, Name, Literal),

    /// Reification — presumably Minion's `reify(constraint, r)`, tying a boolean to the truth
    /// of a constraint; confirm operand order in the Minion translation.
    #[compatible(Minion)]
    Reify(Metadata, Box<Expression>, Box<Expression>),

    /// Declaration of an auxiliary variable.
    ///
    /// As with Savile Row, we semantically distinguish this from `Eq`.
    #[compatible(Minion)]
    AuxDeclaration(Metadata, Name, Box<Expression>),
}
131

            
132
170
fn expr_vec_to_domain_i32(
133
170
    exprs: &[Expression],
134
170
    op: fn(i32, i32) -> Option<i32>,
135
170
    vars: &SymbolTable,
136
170
) -> Option<Domain> {
137
338
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
138
170
    domains
139
170
        .into_iter()
140
170
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
141
170
        .flatten()
142
170
}
143

            
144
241
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
145
241
    let mut min = i32::MAX;
146
241
    let mut max = i32::MIN;
147
3785
    for r in ranges {
148
3544
        match r {
149
3544
            Range::Single(i) => {
150
3544
                if *i < min {
151
316
                    min = *i;
152
3228
                }
153
3544
                if *i > max {
154
903
                    max = *i;
155
2643
                }
156
            }
157
            Range::Bounded(i, j) => {
158
                if *i < min {
159
                    min = *i;
160
                }
161
                if *j > max {
162
                    max = *j;
163
                }
164
            }
165
        }
166
    }
167
241
    (min, max)
168
241
}
169

            
170
impl Expression {
171
    /// Returns the possible values of the expression, recursing to leaf expressions
172
780
    pub fn domain_of(&self, vars: &SymbolTable) -> Option<Domain> {
173
779
        let ret = match self {
174
456
            Expression::FactorE(_, Factor::Reference(name)) => Some(vars.get(name)?.domain.clone()),
175
63
            Expression::FactorE(_, Factor::Literal(Literal::Int(n))) => {
176
63
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
177
            }
178
1
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
179
6
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), vars),
180
105
            Expression::Min(_, exprs) => {
181
1500
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), vars)
182
            }
183
60
            Expression::Max(_, exprs) => {
184
780
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), vars)
185
            }
186
90
            Expression::UnsafeDiv(_, a, b) | Expression::SafeDiv(_, a, b) => {
187
90
                a.domain_of(vars)?.apply_i32(
188
1545
                    |x, y| if y != 0 { Some(x / y) } else { None },
189
90
                    &b.domain_of(vars)?,
190
                )
191
            }
192
            _ => todo!("Calculate domain of {:?}", self),
193
            // TODO: (flm8) Add support for calculating the domains of more expression types
194
        };
195
776
        match ret {
196
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
197
            // Once they support a full domain as we define it, we can remove this conversion
198
776
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
199
241
                let (min, max) = range_vec_bounds_i32(&ranges);
200
241
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
201
            }
202
538
            _ => ret,
203
        }
204
780
    }
205

            
206
3885
    pub fn get_meta(&self) -> Metadata {
207
3885
        <Expression as Biplate<Metadata>>::children_bi(self)[0].clone()
208
3885
    }
209

            
210
    pub fn set_meta(&self, meta: Metadata) {
211
        <Expression as Biplate<Metadata>>::transform_bi(self, Arc::new(move |_| meta.clone()));
212
    }
213

            
214
30
    pub fn can_be_undefined(&self) -> bool {
215
30
        // TODO: there will be more false cases but we are being conservative
216
30
        match self {
217
15
            Expression::FactorE(_, _) => false,
218
15
            _ => true,
219
        }
220
30
    }
221

            
222
195
    pub fn return_type(&self) -> Option<ReturnType> {
223
        match self {
224
            Expression::FactorE(_, Factor::Literal(Literal::Int(_))) => Some(ReturnType::Int),
225
            Expression::FactorE(_, Factor::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
226
            Expression::FactorE(_, Factor::Reference(_)) => None,
227
            Expression::Sum(_, _) => Some(ReturnType::Int),
228
            Expression::Min(_, _) => Some(ReturnType::Int),
229
            Expression::Max(_, _) => Some(ReturnType::Int),
230
            Expression::Not(_, _) => Some(ReturnType::Bool),
231
            Expression::Or(_, _) => Some(ReturnType::Bool),
232
            Expression::And(_, _) => Some(ReturnType::Bool),
233
90
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
234
15
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
235
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
236
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
237
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
238
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
239
90
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
240
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
241
            Expression::SumEq(_, _, _) => Some(ReturnType::Bool),
242
            Expression::SumGeq(_, _, _) => Some(ReturnType::Bool),
243
            Expression::SumLeq(_, _, _) => Some(ReturnType::Bool),
244
            Expression::DivEq(_, _, _, _) => Some(ReturnType::Bool),
245
            Expression::Ineq(_, _, _, _) => Some(ReturnType::Bool),
246
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
247
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
248
            Expression::WatchedLiteral(_, _, _) => Some(ReturnType::Bool),
249
            Expression::Reify(_, _, _) => Some(ReturnType::Bool),
250
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
251
        }
252
195
    }
253

            
254
    pub fn is_clean(&self) -> bool {
255
        let metadata = self.get_meta();
256
        metadata.clean
257
    }
258

            
259
    pub fn set_clean(&mut self, bool_value: bool) {
260
        let mut metadata = self.get_meta();
261
        metadata.clean = bool_value;
262
        self.set_meta(metadata);
263
    }
264

            
265
    pub fn as_factor(&self) -> Option<Factor> {
266
        if let Expression::FactorE(_m, f) = self {
267
            Some(f.clone())
268
        } else {
269
            None
270
        }
271
    }
272
}
273

            
274
fn display_expressions(expressions: &[Expression]) -> String {
275
    // if expressions.len() <= 3 {
276
    format!(
277
        "[{}]",
278
        expressions
279
            .iter()
280
            .map(|e| e.to_string())
281
            .collect::<Vec<String>>()
282
            .join(", ")
283
    )
284
    // } else {
285
    //     format!(
286
    //         "[{}..{}]",
287
    //         expressions[0],
288
    //         expressions[expressions.len() - 1]
289
    //     )
290
    // }
291
}
292

            
293
impl From<i32> for Expression {
294
90
    fn from(i: i32) -> Self {
295
90
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(i)))
296
90
    }
297
}
298

            
299
impl From<bool> for Expression {
300
    fn from(b: bool) -> Self {
301
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(b)))
302
    }
303
}
304

            
305
impl From<Factor> for Expression {
306
    fn from(value: Factor) -> Self {
307
        Expression::FactorE(Metadata::new(), value)
308
    }
309
}
310
impl Display for Expression {
311
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
312
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
313
        match &self {
314
            Expression::FactorE(_, Factor::Literal(c)) => match c {
315
                Literal::Bool(b) => write!(f, "{}", b),
316
                Literal::Int(i) => write!(f, "{}", i),
317
            },
318
            Expression::FactorE(_, Factor::Reference(name)) => match name {
319
                Name::MachineName(n) => write!(f, "_{}", n),
320
                Name::UserName(s) => write!(f, "{}", s),
321
            },
322
            Expression::Sum(_, expressions) => {
323
                write!(f, "Sum({})", display_expressions(expressions))
324
            }
325
            Expression::Min(_, expressions) => {
326
                write!(f, "Min({})", display_expressions(expressions))
327
            }
328
            Expression::Max(_, expressions) => {
329
                write!(f, "Max({})", display_expressions(expressions))
330
            }
331
            Expression::Not(_, expr_box) => {
332
                write!(f, "Not({})", expr_box.clone())
333
            }
334
            Expression::Or(_, expressions) => {
335
                write!(f, "Or({})", display_expressions(expressions))
336
            }
337
            Expression::And(_, expressions) => {
338
                write!(f, "And({})", display_expressions(expressions))
339
            }
340
            Expression::Eq(_, box1, box2) => {
341
                write!(f, "({} = {})", box1.clone(), box2.clone())
342
            }
343
            Expression::Neq(_, box1, box2) => {
344
                write!(f, "({} != {})", box1.clone(), box2.clone())
345
            }
346
            Expression::Geq(_, box1, box2) => {
347
                write!(f, "({} >= {})", box1.clone(), box2.clone())
348
            }
349
            Expression::Leq(_, box1, box2) => {
350
                write!(f, "({} <= {})", box1.clone(), box2.clone())
351
            }
352
            Expression::Gt(_, box1, box2) => {
353
                write!(f, "({} > {})", box1.clone(), box2.clone())
354
            }
355
            Expression::Lt(_, box1, box2) => {
356
                write!(f, "({} < {})", box1.clone(), box2.clone())
357
            }
358
            Expression::SumEq(_, expressions, expr_box) => {
359
                write!(
360
                    f,
361
                    "SumEq({}, {})",
362
                    display_expressions(expressions),
363
                    expr_box.clone()
364
                )
365
            }
366
            Expression::SumGeq(_, box1, box2) => {
367
                write!(f, "SumGeq({}, {})", display_expressions(box1), box2.clone())
368
            }
369
            Expression::SumLeq(_, box1, box2) => {
370
                write!(f, "SumLeq({}, {})", display_expressions(box1), box2.clone())
371
            }
372
            Expression::Ineq(_, box1, box2, box3) => write!(
373
                f,
374
                "Ineq({}, {}, {})",
375
                box1.clone(),
376
                box2.clone(),
377
                box3.clone()
378
            ),
379
            Expression::AllDiff(_, expressions) => {
380
                write!(f, "AllDiff({})", display_expressions(expressions))
381
            }
382
            Expression::Bubble(_, box1, box2) => {
383
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
384
            }
385
            Expression::SafeDiv(_, box1, box2) => {
386
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
387
            }
388
            Expression::UnsafeDiv(_, box1, box2) => {
389
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
390
            }
391
            Expression::DivEq(_, box1, box2, box3) => {
392
                write!(
393
                    f,
394
                    "DivEq({}, {}, {})",
395
                    box1.clone(),
396
                    box2.clone(),
397
                    box3.clone()
398
                )
399
            }
400
            Expression::WatchedLiteral(_, x, l) => {
401
                write!(f, "WatchedLiteral({},{})", x, l)
402
            }
403
            Expression::Reify(_, box1, box2) => {
404
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
405
            }
406
            Expression::AuxDeclaration(_, n, e) => {
407
                write!(f, "{} =aux {}", n, e.clone())
408
            }
409
        }
410
    }
411
}
412

            
413
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Integer literal leaf with fresh metadata.
    fn int(n: i32) -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Int(n)))
    }

    /// Reference to the machine-named variable `_0`.
    fn machine_ref_0() -> Expression {
        Expression::FactorE(Metadata::new(), Factor::Reference(Name::MachineName(0)))
    }

    /// Symbol table declaring only `_0`, with an integer domain made of `ranges`.
    fn table_with_machine_0(ranges: Vec<Range<i32>>) -> SymbolTable {
        let mut table = SymbolTable::new();
        table.insert(
            Name::MachineName(0),
            DecisionVariable::new(Domain::IntDomain(ranges)),
        );
        table
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![int(1), int(2)]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(3)]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), expected);
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let truthy = Expression::FactorE(Metadata::new(), Factor::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), vec![int(1), truthy]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Vec::new());
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let table = table_with_machine_0(vec![Range::Single(1)]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(1)]));
        assert_eq!(machine_ref_0().domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(machine_ref_0().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let table = table_with_machine_0(vec![Range::Single(1)]);
        let sum = Expression::Sum(Metadata::new(), vec![machine_ref_0(), machine_ref_0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(2)]));
        assert_eq!(sum.domain_of(&table), expected);
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let table = table_with_machine_0(vec![Range::Bounded(1, 2)]);
        let sum = Expression::Sum(Metadata::new(), vec![machine_ref_0(), machine_ref_0()]);
        let expected = Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]));
        assert_eq!(sum.domain_of(&table), expected);
    }
}