1
use std::collections::VecDeque;
2
use std::fmt::{Display, Formatter};
3
use std::sync::Arc;
4

            
5
use serde::{Deserialize, Serialize};
6

            
7
use enum_compatability_macro::document_compatibility;
8
use uniplate::derive::Uniplate;
9
use uniplate::{Biplate, Uniplate as _};
10

            
11
use crate::ast::literals::Literal;
12
use crate::ast::pretty::pretty_vec;
13
use crate::ast::symbol_table::{Name, SymbolTable};
14
use crate::ast::Atom;
15
use crate::ast::ReturnType;
16
use crate::metadata::Metadata;
17

            
18
use super::{Domain, Range};
19

            
20
/// Represents different types of expressions used to define rules and constraints in the model.
21
///
22
/// The `Expression` enum includes operations, constants, and variable references
23
/// used to build rules and conditions for the model.
24
#[document_compatibility]
25
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
26
#[uniplate(walk_into=[Atom])]
27
#[biplate(to=Literal)]
28
#[biplate(to=Metadata)]
29
#[biplate(to=Atom)]
30
#[biplate(to=Name)]
31
#[biplate(to=Vec<Expression>)]
32
pub enum Expression {
33
    /// An expression representing "A is valid as long as B is true"
34
    /// Turns into a conjunction when it reaches a boolean context
35
    Bubble(Metadata, Box<Expression>, Box<Expression>),
36

            
37
    Atomic(Metadata, Atom),
38

            
39
    /// `|x|` - absolute value of `x`
40
    #[compatible(JsonInput)]
41
    Abs(Metadata, Box<Expression>),
42

            
43
    #[compatible(JsonInput)]
44
    Sum(Metadata, Vec<Expression>),
45

            
46
    #[compatible(JsonInput)]
47
    Product(Metadata, Vec<Expression>),
48

            
49
    #[compatible(JsonInput)]
50
    Min(Metadata, Vec<Expression>),
51

            
52
    #[compatible(JsonInput)]
53
    Max(Metadata, Vec<Expression>),
54

            
55
    #[compatible(JsonInput, SAT)]
56
    Not(Metadata, Box<Expression>),
57

            
58
    #[compatible(JsonInput, SAT)]
59
    Or(Metadata, Vec<Expression>),
60

            
61
    #[compatible(JsonInput, SAT)]
62
    And(Metadata, Vec<Expression>),
63

            
64
    /// Ensures that `a->b` (material implication).
65
    #[compatible(JsonInput)]
66
    Imply(Metadata, Box<Expression>, Box<Expression>),
67

            
68
    #[compatible(JsonInput)]
69
    Eq(Metadata, Box<Expression>, Box<Expression>),
70

            
71
    #[compatible(JsonInput)]
72
    Neq(Metadata, Box<Expression>, Box<Expression>),
73

            
74
    #[compatible(JsonInput)]
75
    Geq(Metadata, Box<Expression>, Box<Expression>),
76

            
77
    #[compatible(JsonInput)]
78
    Leq(Metadata, Box<Expression>, Box<Expression>),
79

            
80
    #[compatible(JsonInput)]
81
    Gt(Metadata, Box<Expression>, Box<Expression>),
82

            
83
    #[compatible(JsonInput)]
84
    Lt(Metadata, Box<Expression>, Box<Expression>),
85

            
86
    /// Division after preventing division by zero, usually with a bubble
87
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
88

            
89
    /// Division with a possibly undefined value (division by 0)
90
    #[compatible(JsonInput)]
91
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
92

            
93
    /// Modulo after preventing mod 0, usually with a bubble
94
    SafeMod(Metadata, Box<Expression>, Box<Expression>),
95

            
96
    /// Modulo with a possibly undefined value (mod 0)
97
    #[compatible(JsonInput)]
98
    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),
99

            
100
    /// Negation: `-x`
101
    #[compatible(JsonInput)]
102
    Neg(Metadata, Box<Expression>),
103

            
104
    /// Unsafe power`x**y` (possibly undefined)
105
    ///
106
    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
107
    #[compatible(JsonInput)]
108
    UnsafePow(Metadata, Box<Expression>, Box<Expression>),
109

            
110
    /// `UnsafePow` after preventing undefinedness
111
    SafePow(Metadata, Box<Expression>, Box<Expression>),
112

            
113
    #[compatible(JsonInput)]
114
    AllDiff(Metadata, Vec<Expression>),
115

            
116
    /// Binary subtraction operator
117
    ///
118
    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
119
    #[compatible(JsonInput)]
120
    Minus(Metadata, Box<Expression>, Box<Expression>),
121

            
122
    /// Ensures that x=|y| i.e. x is the absolute value of y.
123
    ///
124
    /// Low-level Minion constraint.
125
    ///
126
    /// # See also
127
    ///
128
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
129
    #[compatible(Minion)]
130
    FlatAbsEq(Metadata, Atom, Atom),
131

            
132
    /// Ensures that sum(vec) >= x.
133
    ///
134
    /// Low-level Minion constraint.
135
    ///
136
    /// # See also
137
    ///
138
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
139
    #[compatible(Minion)]
140
    FlatSumGeq(Metadata, Vec<Atom>, Atom),
141

            
142
    /// Ensures that sum(vec) <= x.
143
    ///
144
    /// Low-level Minion constraint.
145
    ///
146
    /// # See also
147
    ///
148
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
149
    #[compatible(Minion)]
150
    FlatSumLeq(Metadata, Vec<Atom>, Atom),
151

            
152
    /// `ineq(x,y,k)` ensures that x <= y + k.
153
    ///
154
    /// Low-level Minion constraint.
155
    ///
156
    /// # See also
157
    ///
158
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
159
    #[compatible(Minion)]
160
    FlatIneq(Metadata, Atom, Atom, Literal),
161

            
162
    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
163
    ///
164
    /// Low-level Minion constraint.
165
    ///
166
    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
167
    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
168
    /// watched-and and watched-or.
169
    ///
170
    /// # See also
171
    ///
172
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
173
    /// + `rules::minion::boolean_literal_to_wliteral`.
174
    #[compatible(Minion)]
175
    FlatWatchedLiteral(Metadata, Name, Literal),
176

            
177
    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
178
    /// product of cs and xs.
179
    ///
180
    /// Low-level Minion constraint.
181
    ///
182
    /// Represents a weighted sum of the form `ax + by + cz + ...`
183
    ///
184
    /// # See also
185
    ///
186
    /// + [Minion
187
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
188
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
189

            
190
    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
191
    /// product of cs and xs.
192
    ///
193
    /// Low-level Minion constraint.
194
    ///
195
    /// Represents a weighted sum of the form `ax + by + cz + ...`
196
    ///
197
    /// # See also
198
    ///
199
    /// + [Minion
200
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
201
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
202

            
203
    /// Ensures that x =-y, where x and y are atoms.
204
    ///
205
    /// Low-level Minion constraint.
206
    ///
207
    /// # See also
208
    ///
209
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
210
    #[compatible(Minion)]
211
    FlatMinusEq(Metadata, Atom, Atom),
212

            
213
    /// Ensures that x*y=z.
214
    ///
215
    /// Low-level Minion constraint.
216
    ///
217
    /// # See also
218
    ///
219
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
220
    #[compatible(Minion)]
221
    FlatProductEq(Metadata, Atom, Atom, Atom),
222

            
223
    /// Ensures that floor(x/y)=z. Always true when y=0.
224
    ///
225
    /// Low-level Minion constraint.
226
    ///
227
    /// # See also
228
    ///
229
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
230
    #[compatible(Minion)]
231
    MinionDivEqUndefZero(Metadata, Atom, Atom, Atom),
232

            
233
    /// Ensures that x%y=z. Always true when y=0.
234
    ///
235
    /// Low-level Minion constraint.
236
    ///
237
    /// # See also
238
    ///
239
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
240
    #[compatible(Minion)]
241
    MinionModuloEqUndefZero(Metadata, Atom, Atom, Atom),
242

            
243
    /// Ensures that `x**y = z`.
244
    ///
245
    /// Low-level Minion constraint.
246
    ///
247
    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
248
    /// is odd and z is -1 if y is even).
249
    ///
250
    /// # See also
251
    ///
252
    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
253
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
254
    MinionPow(Metadata, Atom, Atom, Atom),
255

            
256
    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
257
    /// variable.
258
    ///
259
    /// Low-level Minion constraint.
260
    ///
261
    /// # See also
262
    ///
263
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
264
    #[compatible(Minion)]
265
    MinionReify(Metadata, Box<Expression>, Atom),
266

            
267
    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
268
    /// variable.
269
    ///
270
    /// Low-level Minion constraint.
271
    ///
272
    /// # See also
273
    ///
274
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
275
    #[compatible(Minion)]
276
    MinionReifyImply(Metadata, Box<Expression>, Atom),
277

            
278
    /// Declaration of an auxiliary variable.
279
    ///
280
    /// As with Savile Row, we semantically distinguish this from `Eq`.
281
    #[compatible(Minion)]
282
    AuxDeclaration(Metadata, Name, Box<Expression>),
283
}
284

            
285
413
fn expr_vec_to_domain_i32(
286
413
    exprs: &[Expression],
287
413
    op: fn(i32, i32) -> Option<i32>,
288
413
    vars: &SymbolTable,
289
413
) -> Option<Domain> {
290
943
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
291
413
    domains
292
413
        .into_iter()
293
532
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
294
413
        .flatten()
295
413
}
296

            
297
1888
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
298
1888
    let mut min = i32::MAX;
299
1888
    let mut max = i32::MIN;
300
83203
    for r in ranges {
301
81315
        match r {
302
81315
            Range::Single(i) => {
303
81315
                if *i < min {
304
3401
                    min = *i;
305
77914
                }
306
81315
                if *i > max {
307
11019
                    max = *i;
308
70298
                }
309
            }
310
            Range::Bounded(i, j) => {
311
                if *i < min {
312
                    min = *i;
313
                }
314
                if *j > max {
315
                    max = *j;
316
                }
317
            }
318
        }
319
    }
320
1888
    (min, max)
321
1888
}
322

            
323
impl Expression {
324
    /// Returns the possible values of the expression, recursing to leaf expressions
325
6084
    pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
326
6015
        let ret = match self {
327
3236
            Expression::Atomic(_, Atom::Reference(name)) => Some(syms.domain_of(name)?.clone()),
328
462
            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
329
462
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
330
            }
331
1
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
332
21035
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), syms),
333
85
            Expression::Product(_, exprs) => {
334
2550
                expr_vec_to_domain_i32(exprs, |x, y| Some(x * y), syms)
335
            }
336
102
            Expression::Min(_, exprs) => {
337
1683
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), syms)
338
            }
339
51
            Expression::Max(_, exprs) => {
340
612
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), syms)
341
            }
342
            Expression::UnsafeDiv(_, a, b) => a.domain_of(syms)?.apply_i32(
343
                // rust integer division is truncating; however, we want to always round down,
344
                // including for negative numbers.
345
                |x, y| {
346
                    if y != 0 {
347
                        Some((x as f32 / y as f32).floor() as i32)
348
                    } else {
349
                        None
350
                    }
351
                },
352
                &b.domain_of(syms)?,
353
            ),
354
799
            Expression::SafeDiv(_, a, b) => {
355
                // rust integer division is truncating; however, we want to always round down
356
                // including for negative numbers.
357
799
                let domain = a.domain_of(syms)?.apply_i32(
358
24157
                    |x, y| {
359
24157
                        if y != 0 {
360
21318
                            Some((x as f32 / y as f32).floor() as i32)
361
                        } else {
362
2839
                            None
363
                        }
364
24157
                    },
365
799
                    &b.domain_of(syms)?,
366
                );
367

            
368
799
                match domain {
369
799
                    Some(Domain::IntDomain(ranges)) => {
370
799
                        let mut ranges = ranges;
371
799
                        ranges.push(Range::Single(0));
372
799
                        Some(Domain::IntDomain(ranges))
373
                    }
374
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
375
                    _ => None,
376
                }
377
            }
378
            Expression::UnsafeMod(_, a, b) => a.domain_of(syms)?.apply_i32(
379
                |x, y| if y != 0 { Some(x % y) } else { None },
380
                &b.domain_of(syms)?,
381
            ),
382

            
383
459
            Expression::SafeMod(_, a, b) => {
384
459
                let domain = a.domain_of(syms)?.apply_i32(
385
11067
                    |x, y| if y != 0 { Some(x % y) } else { None },
386
459
                    &b.domain_of(syms)?,
387
                );
388

            
389
459
                match domain {
390
459
                    Some(Domain::IntDomain(ranges)) => {
391
459
                        let mut ranges = ranges;
392
459
                        ranges.push(Range::Single(0));
393
459
                        Some(Domain::IntDomain(ranges))
394
                    }
395
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
396
                    _ => None,
397
                }
398
            }
399

            
400
102
            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => {
401
102
                a.domain_of(syms)?.apply_i32(
402
8636
                    |x, y| {
403
8636
                        if (x != 0 || y != 0) && y >= 0 {
404
8296
                            Some(x ^ y)
405
                        } else {
406
340
                            None
407
                        }
408
8636
                    },
409
102
                    &b.domain_of(syms)?,
410
                )
411
            }
412

            
413
            Expression::Bubble(_, _, _) => None,
414
            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
415
51
            Expression::And(_, _) => Some(Domain::BoolDomain),
416
            Expression::Not(_, _) => Some(Domain::BoolDomain),
417
            Expression::Or(_, _) => Some(Domain::BoolDomain),
418
            Expression::Imply(_, _, _) => Some(Domain::BoolDomain),
419
153
            Expression::Eq(_, _, _) => Some(Domain::BoolDomain),
420
            Expression::Neq(_, _, _) => Some(Domain::BoolDomain),
421
            Expression::Geq(_, _, _) => Some(Domain::BoolDomain),
422
119
            Expression::Leq(_, _, _) => Some(Domain::BoolDomain),
423
            Expression::Gt(_, _, _) => Some(Domain::BoolDomain),
424
            Expression::Lt(_, _, _) => Some(Domain::BoolDomain),
425
            Expression::FlatAbsEq(_, _, _) => Some(Domain::BoolDomain),
426
            Expression::FlatSumGeq(_, _, _) => Some(Domain::BoolDomain),
427
            Expression::FlatSumLeq(_, _, _) => Some(Domain::BoolDomain),
428
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
429
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
430
            Expression::FlatIneq(_, _, _, _) => Some(Domain::BoolDomain),
431
            Expression::AllDiff(_, _) => Some(Domain::BoolDomain),
432
            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::BoolDomain),
433
            Expression::MinionReify(_, _, _) => Some(Domain::BoolDomain),
434
            Expression::MinionReifyImply(_, _, _) => Some(Domain::BoolDomain),
435
136
            Expression::Neg(_, x) => {
436
136
                let Some(Domain::IntDomain(mut ranges)) = x.domain_of(syms) else {
437
                    return None;
438
                };
439

            
440
136
                for range in ranges.iter_mut() {
441
136
                    *range = match range {
442
                        Range::Single(x) => Range::Single(-*x),
443
136
                        Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
444
                    };
445
                }
446

            
447
136
                Some(Domain::IntDomain(ranges))
448
            }
449
            Expression::Minus(_, a, b) => a
450
                .domain_of(syms)?
451
                .apply_i32(|x, y| Some(x - y), &b.domain_of(syms)?),
452

            
453
            Expression::FlatMinusEq(_, _, _) => Some(Domain::BoolDomain),
454
            Expression::FlatProductEq(_, _, _, _) => Some(Domain::BoolDomain),
455
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::BoolDomain),
456
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::BoolDomain),
457
153
            Expression::Abs(_, a) => a
458
153
                .domain_of(syms)?
459
17765
                .apply_i32(|a, _| Some(a.abs()), &a.domain_of(syms)?),
460
            Expression::MinionPow(_, _, _, _) => Some(Domain::BoolDomain),
461
        };
462
5655
        match ret {
463
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
464
            // Once they support a full domain as we define it, we can remove this conversion
465
5655
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
466
1888
                let (min, max) = range_vec_bounds_i32(&ranges);
467
1888
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
468
            }
469
4127
            _ => ret,
470
        }
471
6084
    }
472

            
473
    pub fn get_meta(&self) -> Metadata {
474
        let metas: VecDeque<Metadata> = self.children_bi();
475
        metas[0].clone()
476
    }
477

            
478
    pub fn set_meta(&self, meta: Metadata) {
479
        self.transform_bi(Arc::new(move |_| meta.clone()));
480
    }
481

            
482
    /// Checks whether this expression is safe.
483
    ///
484
    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
485
    ///
486
    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
487
    /// safe through the use of bubble rules.
488
4556
    pub fn is_safe(&self) -> bool {
489
        // TODO: memoise in Metadata
490
9877
        for expr in self.universe() {
491
9877
            match expr {
492
                Expression::UnsafeDiv(_, _, _)
493
                | Expression::UnsafeMod(_, _, _)
494
                | Expression::UnsafePow(_, _, _) => {
495
238
                    return false;
496
                }
497
9639
                _ => {}
498
            }
499
        }
500
4318
        true
501
4556
    }
502

            
503
17544
    pub fn return_type(&self) -> Option<ReturnType> {
504
4658
        match self {
505
4641
            Expression::Atomic(_, Atom::Literal(Literal::Int(_))) => Some(ReturnType::Int),
506
17
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
507
2363
            Expression::Atomic(_, Atom::Reference(_)) => None,
508
170
            Expression::Abs(_, _) => Some(ReturnType::Int),
509
119
            Expression::Sum(_, _) => Some(ReturnType::Int),
510
102
            Expression::Product(_, _) => Some(ReturnType::Int),
511
            Expression::Min(_, _) => Some(ReturnType::Int),
512
            Expression::Max(_, _) => Some(ReturnType::Int),
513
            Expression::Not(_, _) => Some(ReturnType::Bool),
514
            Expression::Or(_, _) => Some(ReturnType::Bool),
515
            Expression::Imply(_, _, _) => Some(ReturnType::Bool),
516
51
            Expression::And(_, _) => Some(ReturnType::Bool),
517
1258
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
518
136
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
519
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
520
221
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
521
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
522
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
523
4233
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
524
68
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
525
            Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
526
            Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
527
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
528
            Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
529
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
530
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
531
            Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
532
            Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
533
            Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
534
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
535
51
            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
536
3740
            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
537
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
538
170
            Expression::Neg(_, _) => Some(ReturnType::Int),
539
17
            Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
540
187
            Expression::SafePow(_, _, _) => Some(ReturnType::Int),
541
            Expression::Minus(_, _, _) => Some(ReturnType::Int),
542
            Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
543
            Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
544
            Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
545
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
546
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
547
            Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
548
        }
549
17544
    }
550

            
551
    pub fn is_clean(&self) -> bool {
552
        let metadata = self.get_meta();
553
        metadata.clean
554
    }
555

            
556
    pub fn set_clean(&mut self, bool_value: bool) {
557
        let mut metadata = self.get_meta();
558
        metadata.clean = bool_value;
559
        self.set_meta(metadata);
560
    }
561

            
562
    /// True if the expression is an associative and commutative operator
563
339830
    pub fn is_associative_commutative_operator(&self) -> bool {
564
316353
        matches!(
565
339830
            self,
566
            Expression::Sum(_, _)
567
                | Expression::Or(_, _)
568
                | Expression::And(_, _)
569
                | Expression::Product(_, _)
570
        )
571
339830
    }
572

            
573
    /// True iff self and other are both atomic and identical.
574
    ///
575
    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
576
    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
577
    /// entire subtrees of `self` and `other`.
578
49776
    pub fn identical_atom_to(&self, other: &Expression) -> bool {
579
49776
        let atom1: Result<&Atom, _> = self.try_into();
580
49776
        let atom2: Result<&Atom, _> = other.try_into();
581

            
582
49776
        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
583
714
            atom2 == atom1
584
        } else {
585
49062
            false
586
        }
587
49776
    }
588
}
589

            
590
impl From<i32> for Expression {
591
816
    fn from(i: i32) -> Self {
592
816
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
593
816
    }
594
}
595

            
596
impl From<bool> for Expression {
597
102
    fn from(b: bool) -> Self {
598
102
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
599
102
    }
600
}
601

            
602
impl From<Atom> for Expression {
603
272
    fn from(value: Atom) -> Self {
604
272
        Expression::Atomic(Metadata::new(), value)
605
272
    }
606
}
607
impl Display for Expression {
608
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
609
183770
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
610
183770
        match &self {
611
95455
            Expression::Atomic(_, atom) => atom.fmt(f),
612
1530
            Expression::Abs(_, a) => write!(f, "|{}|", a),
613
5151
            Expression::Sum(_, expressions) => {
614
5151
                write!(f, "Sum({})", pretty_vec(expressions))
615
            }
616
3570
            Expression::Product(_, expressions) => {
617
3570
                write!(f, "Product({})", pretty_vec(expressions))
618
            }
619
238
            Expression::Min(_, expressions) => {
620
238
                write!(f, "Min({})", pretty_vec(expressions))
621
            }
622
102
            Expression::Max(_, expressions) => {
623
102
                write!(f, "Max({})", pretty_vec(expressions))
624
            }
625
1428
            Expression::Not(_, expr_box) => {
626
1428
                write!(f, "Not({})", expr_box.clone())
627
            }
628
2601
            Expression::Or(_, expressions) => {
629
2601
                write!(f, "Or({})", pretty_vec(expressions))
630
            }
631
9707
            Expression::And(_, expressions) => {
632
9707
                write!(f, "And({})", pretty_vec(expressions))
633
            }
634
2074
            Expression::Imply(_, box1, box2) => {
635
2074
                write!(f, "({}) -> ({})", box1, box2)
636
            }
637
9605
            Expression::Eq(_, box1, box2) => {
638
9605
                write!(f, "({} = {})", box1.clone(), box2.clone())
639
            }
640
9350
            Expression::Neq(_, box1, box2) => {
641
9350
                write!(f, "({} != {})", box1.clone(), box2.clone())
642
            }
643
1530
            Expression::Geq(_, box1, box2) => {
644
1530
                write!(f, "({} >= {})", box1.clone(), box2.clone())
645
            }
646
2907
            Expression::Leq(_, box1, box2) => {
647
2907
                write!(f, "({} <= {})", box1.clone(), box2.clone())
648
            }
649
68
            Expression::Gt(_, box1, box2) => {
650
68
                write!(f, "({} > {})", box1.clone(), box2.clone())
651
            }
652
714
            Expression::Lt(_, box1, box2) => {
653
714
                write!(f, "({} < {})", box1.clone(), box2.clone())
654
            }
655
799
            Expression::FlatSumGeq(_, box1, box2) => {
656
799
                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
657
            }
658
731
            Expression::FlatSumLeq(_, box1, box2) => {
659
731
                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
660
            }
661
1972
            Expression::FlatIneq(_, box1, box2, box3) => write!(
662
1972
                f,
663
1972
                "Ineq({}, {}, {})",
664
1972
                box1.clone(),
665
1972
                box2.clone(),
666
1972
                box3.clone()
667
1972
            ),
668
68
            Expression::AllDiff(_, expressions) => {
669
68
                write!(f, "AllDiff({})", pretty_vec(expressions))
670
            }
671
4250
            Expression::Bubble(_, box1, box2) => {
672
4250
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
673
            }
674
8211
            Expression::SafeDiv(_, box1, box2) => {
675
8211
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
676
            }
677
2414
            Expression::UnsafeDiv(_, box1, box2) => {
678
2414
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
679
            }
680
340
            Expression::UnsafePow(_, box1, box2) => {
681
340
                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
682
            }
683
714
            Expression::SafePow(_, box1, box2) => {
684
714
                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
685
            }
686
1122
            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
687
1122
                write!(
688
1122
                    f,
689
1122
                    "DivEq({}, {}, {})",
690
1122
                    box1.clone(),
691
1122
                    box2.clone(),
692
1122
                    box3.clone()
693
1122
                )
694
            }
695
782
            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
696
782
                write!(
697
782
                    f,
698
782
                    "ModEq({}, {}, {})",
699
782
                    box1.clone(),
700
782
                    box2.clone(),
701
782
                    box3.clone()
702
782
                )
703
            }
704

            
705
170
            Expression::FlatWatchedLiteral(_, x, l) => {
706
170
                write!(f, "WatchedLiteral({},{})", x, l)
707
            }
708
714
            Expression::MinionReify(_, box1, box2) => {
709
714
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
710
            }
711
442
            Expression::MinionReifyImply(_, box1, box2) => {
712
442
                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
713
            }
714
3264
            Expression::AuxDeclaration(_, n, e) => {
715
3264
                write!(f, "{} =aux {}", n, e.clone())
716
            }
717
1394
            Expression::UnsafeMod(_, a, b) => {
718
1394
                write!(f, "{} % {}", a.clone(), b.clone())
719
            }
720
5066
            Expression::SafeMod(_, a, b) => {
721
5066
                write!(f, "SafeMod({},{})", a.clone(), b.clone())
722
            }
723
3978
            Expression::Neg(_, a) => {
724
3978
                write!(f, "-({})", a.clone())
725
            }
726
170
            Expression::Minus(_, a, b) => {
727
170
                write!(f, "({} - {})", a.clone(), b.clone())
728
            }
729
272
            Expression::FlatAbsEq(_, a, b) => {
730
272
                write!(f, "AbsEq({},{})", a.clone(), b.clone())
731
            }
732
136
            Expression::FlatMinusEq(_, a, b) => {
733
136
                write!(f, "MinusEq({},{})", a.clone(), b.clone())
734
            }
735
153
            Expression::FlatProductEq(_, a, b, c) => {
736
153
                write!(
737
153
                    f,
738
153
                    "FlatProductEq({},{},{})",
739
153
                    a.clone(),
740
153
                    b.clone(),
741
153
                    c.clone()
742
153
                )
743
            }
744
238
            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
745
238
                write!(
746
238
                    f,
747
238
                    "FlatWeightedSumLeq({},{},{})",
748
238
                    pretty_vec(cs),
749
238
                    pretty_vec(vs),
750
238
                    total.clone()
751
238
                )
752
            }
753

            
754
170
            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
755
170
                write!(
756
170
                    f,
757
170
                    "FlatWeightedSumGeq({},{},{})",
758
170
                    pretty_vec(cs),
759
170
                    pretty_vec(vs),
760
170
                    total.clone()
761
170
                )
762
            }
763
170
            Expression::MinionPow(_, atom, atom1, atom2) => {
764
170
                write!(f, "MinionPow({},{},{})", atom, atom1, atom2)
765
            }
766
        }
767
183770
    }
768
}
769

            
770
#[cfg(test)]
mod tests {
    use crate::ast::DecisionVariable;

    use super::*;

    /// Builds an atomic integer-literal expression.
    fn int_expr(value: i32) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(value)))
    }

    /// Builds an atomic reference to the variable named `MachineName(0)`.
    fn var_zero() -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)))
    }

    /// Builds a symbol table in which `MachineName(0)` is declared with the given domain.
    fn symbols_with_var_zero(domain: Domain) -> SymbolTable {
        let mut symbols = SymbolTable::new();
        symbols.add_var(Name::MachineName(0), DecisionVariable::new(domain));
        symbols
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![int_expr(1), int_expr(2)]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(3)]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), expected);
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let boolean = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), vec![int_expr(1), boolean]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Vec::new());
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let symbols = symbols_with_var_zero(Domain::IntDomain(vec![Range::Single(1)]));
        let expected = Some(Domain::IntDomain(vec![Range::Single(1)]));
        assert_eq!(var_zero().domain_of(&symbols), expected);
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        assert_eq!(var_zero().domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let symbols = symbols_with_var_zero(Domain::IntDomain(vec![Range::Single(1)]));
        let sum = Expression::Sum(Metadata::new(), vec![var_zero(), var_zero()]);
        let expected = Some(Domain::IntDomain(vec![Range::Single(2)]));
        assert_eq!(sum.domain_of(&symbols), expected);
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let symbols = symbols_with_var_zero(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let sum = Expression::Sum(Metadata::new(), vec![var_zero(), var_zero()]);
        let expected = Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]));
        assert_eq!(sum.domain_of(&symbols), expected);
    }
}