1
use std::collections::VecDeque;
2
use std::fmt::{Display, Formatter};
3
use std::sync::Arc;
4

            
5
use serde::{Deserialize, Serialize};
6

            
7
use enum_compatability_macro::document_compatibility;
8
use uniplate::derive::Uniplate;
9
use uniplate::{Biplate, Uniplate as _};
10

            
11
use crate::ast::literals::Literal;
12
use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
13
use crate::ast::symbol_table::SymbolTable;
14
use crate::ast::Atom;
15
use crate::ast::Name;
16
use crate::ast::ReturnType;
17
use crate::metadata::Metadata;
18

            
19
use super::{Domain, Range};
20

            
21
/// Represents different types of expressions used to define rules and constraints in the model.
///
/// The `Expression` enum includes operations, constants, and variable references
/// used to build rules and conditions for the model.
///
/// Every variant stores its `Metadata` as the first field.
#[document_compatibility]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[uniplate(walk_into=[Atom])]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Atom)]
#[biplate(to=Name)]
#[biplate(to=Vec<Expression>)]
pub enum Expression {
    /// The top of the model: holds the model's top-level expressions.
    Root(Metadata, Vec<Expression>),

    /// An expression representing "A is valid as long as B is true"
    /// Turns into a conjunction when it reaches a boolean context
    Bubble(Metadata, Box<Expression>, Box<Expression>),

    /// A single atom: a literal constant or a reference to a named variable.
    Atomic(Metadata, Atom),

    /// `|x|` - absolute value of `x`
    #[compatible(JsonInput)]
    Abs(Metadata, Box<Expression>),

    /// Integer sum of the given expressions.
    #[compatible(JsonInput)]
    Sum(Metadata, Vec<Expression>),

    /// Integer product of the given expressions.
    #[compatible(JsonInput)]
    Product(Metadata, Vec<Expression>),

    /// Smallest value among the given integer expressions.
    #[compatible(JsonInput)]
    Min(Metadata, Vec<Expression>),

    /// Largest value among the given integer expressions.
    #[compatible(JsonInput)]
    Max(Metadata, Vec<Expression>),

    /// Boolean negation.
    #[compatible(JsonInput, SAT)]
    Not(Metadata, Box<Expression>),

    /// Boolean disjunction of the given expressions.
    #[compatible(JsonInput, SAT)]
    Or(Metadata, Vec<Expression>),

    /// Boolean conjunction of the given expressions.
    #[compatible(JsonInput, SAT)]
    And(Metadata, Vec<Expression>),

    /// Ensures that `a->b` (material implication).
    #[compatible(JsonInput)]
    Imply(Metadata, Box<Expression>, Box<Expression>),

    /// `a = b`
    #[compatible(JsonInput)]
    Eq(Metadata, Box<Expression>, Box<Expression>),

    /// `a != b`
    #[compatible(JsonInput)]
    Neq(Metadata, Box<Expression>, Box<Expression>),

    /// `a >= b`
    #[compatible(JsonInput)]
    Geq(Metadata, Box<Expression>, Box<Expression>),

    /// `a <= b`
    #[compatible(JsonInput)]
    Leq(Metadata, Box<Expression>, Box<Expression>),

    /// `a > b`
    #[compatible(JsonInput)]
    Gt(Metadata, Box<Expression>, Box<Expression>),

    /// `a < b`
    #[compatible(JsonInput)]
    Lt(Metadata, Box<Expression>, Box<Expression>),

    /// Division after preventing division by zero, usually with a bubble.
    ///
    /// Integer division rounds down (towards negative infinity), not towards zero.
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Division with a possibly undefined value (division by 0).
    ///
    /// Integer division rounds down (towards negative infinity), not towards zero.
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),

    /// Modulo after preventing mod 0, usually with a bubble
    SafeMod(Metadata, Box<Expression>, Box<Expression>),

    /// Modulo with a possibly undefined value (mod 0)
    #[compatible(JsonInput)]
    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),

    /// Negation: `-x`
    #[compatible(JsonInput)]
    Neg(Metadata, Box<Expression>),

    /// Unsafe power `x**y` (possibly undefined).
    ///
    /// Defined when `(x != 0 \/ y != 0) /\ y >= 0`.
    #[compatible(JsonInput)]
    UnsafePow(Metadata, Box<Expression>, Box<Expression>),

    /// `UnsafePow` after preventing undefinedness
    SafePow(Metadata, Box<Expression>, Box<Expression>),

    /// All of the given expressions take pairwise different values.
    #[compatible(JsonInput)]
    AllDiff(Metadata, Vec<Expression>),

    /// Binary subtraction operator
    ///
    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
    #[compatible(JsonInput)]
    Minus(Metadata, Box<Expression>, Box<Expression>),

    /// Ensures that x=|y| i.e. x is the absolute value of y.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
    #[compatible(Minion)]
    FlatAbsEq(Metadata, Atom, Atom),

    /// Ensures that sum(vec) >= x.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
    #[compatible(Minion)]
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    /// Ensures that sum(vec) <= x.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
    #[compatible(Minion)]
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    /// `ineq(x,y,k)` ensures that x <= y + k.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
    #[compatible(Minion)]
    FlatIneq(Metadata, Atom, Atom, Literal),

    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
    ///
    /// Low-level Minion constraint.
    ///
    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
    /// watched-and and watched-or.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-literal)
    /// + `rules::minion::boolean_literal_to_wliteral`.
    #[compatible(Minion)]
    FlatWatchedLiteral(Metadata, Name, Literal),

    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
    /// product of cs and xs.
    ///
    /// Low-level Minion constraint.
    ///
    /// Represents a weighted sum of the form `ax + by + cz + ...`
    ///
    /// # See also
    ///
    /// + [Minion
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),

    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
    /// product of cs and xs.
    ///
    /// Low-level Minion constraint.
    ///
    /// Represents a weighted sum of the form `ax + by + cz + ...`
    ///
    /// # See also
    ///
    /// + [Minion
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumgeq)
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),

    /// Ensures that x =-y, where x and y are atoms.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
    #[compatible(Minion)]
    FlatMinusEq(Metadata, Atom, Atom),

    /// Ensures that x*y=z.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
    #[compatible(Minion)]
    FlatProductEq(Metadata, Atom, Atom, Atom),

    /// Ensures that floor(x/y)=z. Always true when y=0.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
    #[compatible(Minion)]
    MinionDivEqUndefZero(Metadata, Atom, Atom, Atom),

    /// Ensures that x%y=z. Always true when y=0.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
    #[compatible(Minion)]
    MinionModuloEqUndefZero(Metadata, Atom, Atom, Atom),

    /// Ensures that `x**y = z`.
    ///
    /// Low-level Minion constraint.
    ///
    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
    /// is even and z is -1 if y is odd).
    ///
    /// # See also
    ///
    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
    MinionPow(Metadata, Atom, Atom, Atom),

    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
    /// variable.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
    #[compatible(Minion)]
    MinionReify(Metadata, Box<Expression>, Atom),

    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
    #[compatible(Minion)]
    MinionReifyImply(Metadata, Box<Expression>, Atom),

    /// Declaration of an auxiliary variable.
    ///
    /// As with Savile Row, we semantically distinguish this from `Eq`.
    #[compatible(Minion)]
    AuxDeclaration(Metadata, Name, Box<Expression>),
}
288

            
289
379
fn expr_vec_to_domain_i32(
290
379
    exprs: &[Expression],
291
379
    op: fn(i32, i32) -> Option<i32>,
292
379
    vars: &SymbolTable,
293
379
) -> Option<Domain> {
294
858
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
295
379
    domains
296
379
        .into_iter()
297
481
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
298
379
        .flatten()
299
379
}
300

            
301
1888
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> (i32, i32) {
302
1888
    let mut min = i32::MAX;
303
1888
    let mut max = i32::MIN;
304
83203
    for r in ranges {
305
81315
        match r {
306
81315
            Range::Single(i) => {
307
81315
                if *i < min {
308
3401
                    min = *i;
309
77914
                }
310
81315
                if *i > max {
311
11019
                    max = *i;
312
70298
                }
313
            }
314
            Range::Bounded(i, j) => {
315
                if *i < min {
316
                    min = *i;
317
                }
318
                if *j > max {
319
                    max = *j;
320
                }
321
            }
322
        }
323
    }
324
1888
    (min, max)
325
1888
}
326

            
327
impl Expression {
328
    /// Returns the possible values of the expression, recursing to leaf expressions
329
5999
    pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
330
5998
        let ret = match self {
331
3168
            Expression::Atomic(_, Atom::Reference(name)) => Some(syms.domain(name)?),
332
445
            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
333
445
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
334
            }
335
1
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
336
21035
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), syms),
337
85
            Expression::Product(_, exprs) => {
338
2550
                expr_vec_to_domain_i32(exprs, |x, y| Some(x * y), syms)
339
            }
340
102
            Expression::Min(_, exprs) => {
341
1683
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x < y { x } else { y }), syms)
342
            }
343
51
            Expression::Max(_, exprs) => {
344
612
                expr_vec_to_domain_i32(exprs, |x, y| Some(if x > y { x } else { y }), syms)
345
            }
346
            Expression::UnsafeDiv(_, a, b) => a.domain_of(syms)?.apply_i32(
347
                // rust integer division is truncating; however, we want to always round down,
348
                // including for negative numbers.
349
                |x, y| {
350
                    if y != 0 {
351
                        Some((x as f32 / y as f32).floor() as i32)
352
                    } else {
353
                        None
354
                    }
355
                },
356
                &b.domain_of(syms)?,
357
            ),
358
799
            Expression::SafeDiv(_, a, b) => {
359
                // rust integer division is truncating; however, we want to always round down
360
                // including for negative numbers.
361
799
                let domain = a.domain_of(syms)?.apply_i32(
362
24157
                    |x, y| {
363
24157
                        if y != 0 {
364
21318
                            Some((x as f32 / y as f32).floor() as i32)
365
                        } else {
366
2839
                            None
367
                        }
368
24157
                    },
369
799
                    &b.domain_of(syms)?,
370
                );
371

            
372
799
                match domain {
373
799
                    Some(Domain::IntDomain(ranges)) => {
374
799
                        let mut ranges = ranges;
375
799
                        ranges.push(Range::Single(0));
376
799
                        Some(Domain::IntDomain(ranges))
377
                    }
378
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
379
                    _ => None,
380
                }
381
            }
382
            Expression::UnsafeMod(_, a, b) => a.domain_of(syms)?.apply_i32(
383
                |x, y| if y != 0 { Some(x % y) } else { None },
384
                &b.domain_of(syms)?,
385
            ),
386

            
387
459
            Expression::SafeMod(_, a, b) => {
388
459
                let domain = a.domain_of(syms)?.apply_i32(
389
11067
                    |x, y| if y != 0 { Some(x % y) } else { None },
390
459
                    &b.domain_of(syms)?,
391
                );
392

            
393
459
                match domain {
394
459
                    Some(Domain::IntDomain(ranges)) => {
395
459
                        let mut ranges = ranges;
396
459
                        ranges.push(Range::Single(0));
397
459
                        Some(Domain::IntDomain(ranges))
398
                    }
399
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
400
                    _ => None,
401
                }
402
            }
403

            
404
102
            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => {
405
102
                a.domain_of(syms)?.apply_i32(
406
8636
                    |x, y| {
407
8636
                        if (x != 0 || y != 0) && y >= 0 {
408
8296
                            Some(x ^ y)
409
                        } else {
410
340
                            None
411
                        }
412
8636
                    },
413
102
                    &b.domain_of(syms)?,
414
                )
415
            }
416

            
417
            Expression::Root(_, _) => None,
418
            Expression::Bubble(_, _, _) => None,
419
            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
420
68
            Expression::And(_, _) => Some(Domain::BoolDomain),
421
            Expression::Not(_, _) => Some(Domain::BoolDomain),
422
            Expression::Or(_, _) => Some(Domain::BoolDomain),
423
            Expression::Imply(_, _, _) => Some(Domain::BoolDomain),
424
153
            Expression::Eq(_, _, _) => Some(Domain::BoolDomain),
425
            Expression::Neq(_, _, _) => Some(Domain::BoolDomain),
426
            Expression::Geq(_, _, _) => Some(Domain::BoolDomain),
427
136
            Expression::Leq(_, _, _) => Some(Domain::BoolDomain),
428
            Expression::Gt(_, _, _) => Some(Domain::BoolDomain),
429
            Expression::Lt(_, _, _) => Some(Domain::BoolDomain),
430
            Expression::FlatAbsEq(_, _, _) => Some(Domain::BoolDomain),
431
            Expression::FlatSumGeq(_, _, _) => Some(Domain::BoolDomain),
432
            Expression::FlatSumLeq(_, _, _) => Some(Domain::BoolDomain),
433
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
434
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
435
            Expression::FlatIneq(_, _, _, _) => Some(Domain::BoolDomain),
436
            Expression::AllDiff(_, _) => Some(Domain::BoolDomain),
437
            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::BoolDomain),
438
            Expression::MinionReify(_, _, _) => Some(Domain::BoolDomain),
439
            Expression::MinionReifyImply(_, _, _) => Some(Domain::BoolDomain),
440
136
            Expression::Neg(_, x) => {
441
136
                let Some(Domain::IntDomain(mut ranges)) = x.domain_of(syms) else {
442
                    return None;
443
                };
444

            
445
136
                for range in ranges.iter_mut() {
446
136
                    *range = match range {
447
                        Range::Single(x) => Range::Single(-*x),
448
136
                        Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
449
                    };
450
                }
451

            
452
136
                Some(Domain::IntDomain(ranges))
453
            }
454
            Expression::Minus(_, a, b) => a
455
                .domain_of(syms)?
456
                .apply_i32(|x, y| Some(x - y), &b.domain_of(syms)?),
457

            
458
            Expression::FlatMinusEq(_, _, _) => Some(Domain::BoolDomain),
459
            Expression::FlatProductEq(_, _, _, _) => Some(Domain::BoolDomain),
460
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::BoolDomain),
461
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::BoolDomain),
462
153
            Expression::Abs(_, a) => a
463
153
                .domain_of(syms)?
464
17765
                .apply_i32(|a, _| Some(a.abs()), &a.domain_of(syms)?),
465
            Expression::MinionPow(_, _, _, _) => Some(Domain::BoolDomain),
466
        };
467
5638
        match ret {
468
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
469
            // Once they support a full domain as we define it, we can remove this conversion
470
5638
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
471
1888
                let (min, max) = range_vec_bounds_i32(&ranges);
472
1888
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
473
            }
474
4110
            _ => ret,
475
        }
476
5999
    }
477

            
478
    pub fn get_meta(&self) -> Metadata {
479
        let metas: VecDeque<Metadata> = self.children_bi();
480
        metas[0].clone()
481
    }
482

            
483
    pub fn set_meta(&self, meta: Metadata) {
484
        self.transform_bi(Arc::new(move |_| meta.clone()));
485
    }
486

            
487
    /// Checks whether this expression is safe.
488
    ///
489
    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
490
    ///
491
    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
492
    /// safe through the use of bubble rules.
493
4556
    pub fn is_safe(&self) -> bool {
494
        // TODO: memoise in Metadata
495
9860
        for expr in self.universe() {
496
9860
            match expr {
497
                Expression::UnsafeDiv(_, _, _)
498
                | Expression::UnsafeMod(_, _, _)
499
                | Expression::UnsafePow(_, _, _) => {
500
238
                    return false;
501
                }
502
9622
                _ => {}
503
            }
504
        }
505
4318
        true
506
4556
    }
507

            
508
18037
    pub fn return_type(&self) -> Option<ReturnType> {
509
4658
        match self {
510
            Expression::Root(_, _) => Some(ReturnType::Bool),
511
4641
            Expression::Atomic(_, Atom::Literal(Literal::Int(_))) => Some(ReturnType::Int),
512
17
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
513
2363
            Expression::Atomic(_, Atom::Reference(_)) => None,
514
170
            Expression::Abs(_, _) => Some(ReturnType::Int),
515
102
            Expression::Sum(_, _) => Some(ReturnType::Int),
516
102
            Expression::Product(_, _) => Some(ReturnType::Int),
517
            Expression::Min(_, _) => Some(ReturnType::Int),
518
            Expression::Max(_, _) => Some(ReturnType::Int),
519
            Expression::Not(_, _) => Some(ReturnType::Bool),
520
            Expression::Or(_, _) => Some(ReturnType::Bool),
521
            Expression::Imply(_, _, _) => Some(ReturnType::Bool),
522
68
            Expression::And(_, _) => Some(ReturnType::Bool),
523
1649
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
524
204
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
525
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
526
255
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
527
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
528
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
529
4233
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
530
68
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
531
            Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
532
            Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
533
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
534
            Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
535
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
536
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
537
            Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
538
            Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
539
            Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
540
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
541
51
            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
542
3740
            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
543
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
544
170
            Expression::Neg(_, _) => Some(ReturnType::Int),
545
17
            Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
546
187
            Expression::SafePow(_, _, _) => Some(ReturnType::Int),
547
            Expression::Minus(_, _, _) => Some(ReturnType::Int),
548
            Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
549
            Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
550
            Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
551
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
552
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
553
            Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
554
        }
555
18037
    }
556

            
557
    pub fn is_clean(&self) -> bool {
558
        let metadata = self.get_meta();
559
        metadata.clean
560
    }
561

            
562
    pub fn set_clean(&mut self, bool_value: bool) {
563
        let mut metadata = self.get_meta();
564
        metadata.clean = bool_value;
565
        self.set_meta(metadata);
566
    }
567

            
568
    /// True if the expression is an associative and commutative operator
569
349775
    pub fn is_associative_commutative_operator(&self) -> bool {
570
325805
        matches!(
571
349775
            self,
572
            Expression::Sum(_, _)
573
                | Expression::Or(_, _)
574
                | Expression::And(_, _)
575
                | Expression::Product(_, _)
576
        )
577
349775
    }
578

            
579
    /// True iff self and other are both atomic and identical.
580
    ///
581
    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
582
    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
583
    /// entire subtrees of `self` and `other`.
584
50150
    pub fn identical_atom_to(&self, other: &Expression) -> bool {
585
50150
        let atom1: Result<&Atom, _> = self.try_into();
586
50150
        let atom2: Result<&Atom, _> = other.try_into();
587

            
588
50150
        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
589
714
            atom2 == atom1
590
        } else {
591
49436
            false
592
        }
593
50150
    }
594
}
595

            
596
impl From<i32> for Expression {
597
816
    fn from(i: i32) -> Self {
598
816
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
599
816
    }
600
}
601

            
602
impl From<bool> for Expression {
603
170
    fn from(b: bool) -> Self {
604
170
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
605
170
    }
606
}
607

            
608
impl From<Atom> for Expression {
609
272
    fn from(value: Atom) -> Self {
610
272
        Expression::Atomic(Metadata::new(), value)
611
272
    }
612
}
613
impl Display for Expression {
614
    // TODO: (flm8) this will change once we implement a parser (two-way conversion)
615
720800
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
616
720800
        match &self {
617
408
            Expression::Root(_, exprs) => {
618
408
                write!(f, "{}", pretty_expressions_as_top_level(exprs))
619
            }
620
370311
            Expression::Atomic(_, atom) => atom.fmt(f),
621
2669
            Expression::Abs(_, a) => write!(f, "|{}|", a),
622
8755
            Expression::Sum(_, expressions) => {
623
8755
                write!(f, "Sum({})", pretty_vec(expressions))
624
            }
625
6358
            Expression::Product(_, expressions) => {
626
6358
                write!(f, "Product({})", pretty_vec(expressions))
627
            }
628
238
            Expression::Min(_, expressions) => {
629
238
                write!(f, "Min({})", pretty_vec(expressions))
630
            }
631
119
            Expression::Max(_, expressions) => {
632
119
                write!(f, "Max({})", pretty_vec(expressions))
633
            }
634
9775
            Expression::Not(_, expr_box) => {
635
9775
                write!(f, "Not({})", expr_box.clone())
636
            }
637
25619
            Expression::Or(_, expressions) => {
638
25619
                write!(f, "Or({})", pretty_vec(expressions))
639
            }
640
19023
            Expression::And(_, expressions) => {
641
19023
                write!(f, "And({})", pretty_vec(expressions))
642
            }
643
34578
            Expression::Imply(_, box1, box2) => {
644
34578
                write!(f, "({}) -> ({})", box1, box2)
645
            }
646
61557
            Expression::Eq(_, box1, box2) => {
647
61557
                write!(f, "({} = {})", box1.clone(), box2.clone())
648
            }
649
17561
            Expression::Neq(_, box1, box2) => {
650
17561
                write!(f, "({} != {})", box1.clone(), box2.clone())
651
            }
652
3332
            Expression::Geq(_, box1, box2) => {
653
3332
                write!(f, "({} >= {})", box1.clone(), box2.clone())
654
            }
655
33388
            Expression::Leq(_, box1, box2) => {
656
33388
                write!(f, "({} <= {})", box1.clone(), box2.clone())
657
            }
658
68
            Expression::Gt(_, box1, box2) => {
659
68
                write!(f, "({} > {})", box1.clone(), box2.clone())
660
            }
661
5712
            Expression::Lt(_, box1, box2) => {
662
5712
                write!(f, "({} < {})", box1.clone(), box2.clone())
663
            }
664
2618
            Expression::FlatSumGeq(_, box1, box2) => {
665
2618
                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
666
            }
667
2516
            Expression::FlatSumLeq(_, box1, box2) => {
668
2516
                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
669
            }
670
6528
            Expression::FlatIneq(_, box1, box2, box3) => write!(
671
6528
                f,
672
6528
                "Ineq({}, {}, {})",
673
6528
                box1.clone(),
674
6528
                box2.clone(),
675
6528
                box3.clone()
676
6528
            ),
677
68
            Expression::AllDiff(_, expressions) => {
678
68
                write!(f, "AllDiff({})", pretty_vec(expressions))
679
            }
680
6936
            Expression::Bubble(_, box1, box2) => {
681
6936
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
682
            }
683
18173
            Expression::SafeDiv(_, box1, box2) => {
684
18173
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
685
            }
686
10132
            Expression::UnsafeDiv(_, box1, box2) => {
687
10132
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
688
            }
689
493
            Expression::UnsafePow(_, box1, box2) => {
690
493
                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
691
            }
692
1088
            Expression::SafePow(_, box1, box2) => {
693
1088
                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
694
            }
695
3077
            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
696
3077
                write!(
697
3077
                    f,
698
3077
                    "DivEq({}, {}, {})",
699
3077
                    box1.clone(),
700
3077
                    box2.clone(),
701
3077
                    box3.clone()
702
3077
                )
703
            }
704
5372
            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
705
5372
                write!(
706
5372
                    f,
707
5372
                    "ModEq({}, {}, {})",
708
5372
                    box1.clone(),
709
5372
                    box2.clone(),
710
5372
                    box3.clone()
711
5372
                )
712
            }
713

            
714
510
            Expression::FlatWatchedLiteral(_, x, l) => {
715
510
                write!(f, "WatchedLiteral({},{})", x, l)
716
            }
717
13107
            Expression::MinionReify(_, box1, box2) => {
718
13107
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
719
            }
720
9588
            Expression::MinionReifyImply(_, box1, box2) => {
721
9588
                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
722
            }
723
5780
            Expression::AuxDeclaration(_, n, e) => {
724
5780
                write!(f, "{} =aux {}", n, e.clone())
725
            }
726
8602
            Expression::UnsafeMod(_, a, b) => {
727
8602
                write!(f, "{} % {}", a.clone(), b.clone())
728
            }
729
17918
            Expression::SafeMod(_, a, b) => {
730
17918
                write!(f, "SafeMod({},{})", a.clone(), b.clone())
731
            }
732
6171
            Expression::Neg(_, a) => {
733
6171
                write!(f, "-({})", a.clone())
734
            }
735
221
            Expression::Minus(_, a, b) => {
736
221
                write!(f, "({} - {})", a.clone(), b.clone())
737
            }
738
561
            Expression::FlatAbsEq(_, a, b) => {
739
561
                write!(f, "AbsEq({},{})", a.clone(), b.clone())
740
            }
741
255
            Expression::FlatMinusEq(_, a, b) => {
742
255
                write!(f, "MinusEq({},{})", a.clone(), b.clone())
743
            }
744
306
            Expression::FlatProductEq(_, a, b, c) => {
745
306
                write!(
746
306
                    f,
747
306
                    "FlatProductEq({},{},{})",
748
306
                    a.clone(),
749
306
                    b.clone(),
750
306
                    c.clone()
751
306
                )
752
            }
753
510
            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
754
510
                write!(
755
510
                    f,
756
510
                    "FlatWeightedSumLeq({},{},{})",
757
510
                    pretty_vec(cs),
758
510
                    pretty_vec(vs),
759
510
                    total.clone()
760
510
                )
761
            }
762

            
763
340
            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
764
340
                write!(
765
340
                    f,
766
340
                    "FlatWeightedSumGeq({},{},{})",
767
340
                    pretty_vec(cs),
768
340
                    pretty_vec(vs),
769
340
                    total.clone()
770
340
                )
771
            }
772
459
            Expression::MinionPow(_, atom, atom1, atom2) => {
773
459
                write!(f, "MinionPow({},{},{})", atom, atom1, atom2)
774
            }
775
        }
776
720800
    }
777
}
778

            
779
#[cfg(test)]
mod tests {
    //! Unit tests for `Expression` domain inference and small helper predicates.

    use std::rc::Rc;

    use crate::ast::declaration::Declaration;

    use super::*;

    #[test]
    fn test_domain_of_constant_sum() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2)));
        let sum = Expression::Sum(Metadata::new(), vec![c1, c2]);
        assert_eq!(
            sum.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(3)]))
        );
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), vec![c1, c2]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Single(1)]),
        )))
        .unwrap();
        assert_eq!(
            reference.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        assert_eq!(reference.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Single(1)]),
        )))
        .unwrap();
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(2)]))
        );
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        // Unwrap like the other tests so a failed insertion is not silently ignored.
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Bounded(1, 2)]),
        )))
        .unwrap();
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]))
        );
    }

    #[test]
    fn test_is_associative_commutative_operator() {
        assert!(Expression::Sum(Metadata::new(), vec![]).is_associative_commutative_operator());
        assert!(Expression::Product(Metadata::new(), vec![]).is_associative_commutative_operator());
        assert!(Expression::And(Metadata::new(), vec![]).is_associative_commutative_operator());
        assert!(Expression::Or(Metadata::new(), vec![]).is_associative_commutative_operator());
        assert!(!Expression::from(1_i32).is_associative_commutative_operator());
    }

    #[test]
    fn test_identical_atom_to() {
        let one = Expression::from(1_i32);
        assert!(one.identical_atom_to(&Expression::from(1_i32)));
        assert!(!one.identical_atom_to(&Expression::from(2_i32)));

        // Non-atomic expressions are never considered identical atoms, even to themselves.
        let sum = Expression::Sum(Metadata::new(), vec![one]);
        assert!(!sum.identical_atom_to(&sum.clone()));
    }
}