1
use std::collections::VecDeque;
2
use std::fmt::{Display, Formatter};
3
use std::sync::Arc;
4

            
5
use itertools::Itertools;
6
use serde::{Deserialize, Serialize};
7

            
8
use crate::ast::literals::AbstractLiteral;
9
use crate::ast::literals::Literal;
10
use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
11
use crate::ast::symbol_table::SymbolTable;
12
use crate::ast::Atom;
13
use crate::ast::Name;
14
use crate::ast::ReturnType;
15
use crate::bug;
16
use crate::metadata::Metadata;
17
use enum_compatability_macro::document_compatibility;
18
use uniplate::derive::Uniplate;
19
use uniplate::{Biplate, Uniplate as _};
20

            
21
use super::{Domain, Range, SubModel, Typeable};
22

            
23
/// Represents different types of expressions used to define rules and constraints in the model.
24
///
25
/// The `Expression` enum includes operations, constants, and variable references
26
/// used to build rules and conditions for the model.
27
#[document_compatibility]
28
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
29
#[uniplate(walk_into=[Atom,SubModel,AbstractLiteral<Expression>])]
30
#[biplate(to=Metadata)]
31
#[biplate(to=Atom)]
32
#[biplate(to=Name)]
33
#[biplate(to=Vec<Expression>)]
34
#[biplate(to=Option<Expression>)]
35
#[biplate(to=SubModel)]
36
#[biplate(to=AbstractLiteral<Expression>)]
37
#[biplate(to=AbstractLiteral<Literal>,walk_into=[Atom])]
38
#[biplate(to=Literal,walk_into=[Atom])]
39
pub enum Expression {
40
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
41
    /// The top of the model
42
    Root(Metadata, Vec<Expression>),
43

            
44
    /// An expression representing "A is valid as long as B is true"
45
    /// Turns into a conjunction when it reaches a boolean context
46
    Bubble(Metadata, Box<Expression>, Box<Expression>),
47

            
48
    /// Defines dominance ("Solution A is preferred over Solution B")
49
    DominanceRelation(Metadata, Box<Expression>),
50
    /// `fromSolution(name)` - Used in dominance relation definitions
51
    FromSolution(Metadata, Box<Expression>),
52

            
53
    Atomic(Metadata, Atom),
54

            
55
    /// A matrix index.
56
    ///
57
    /// Defined iff the indices are within their respective index domains.
58
    #[compatible(JsonInput)]
59
    UnsafeIndex(Metadata, Box<Expression>, Vec<Expression>),
60

            
61
    /// A safe matrix index.
62
    ///
63
    /// See [`Expression::UnsafeIndex`]
64
    SafeIndex(Metadata, Box<Expression>, Vec<Expression>),
65

            
66
    /// A matrix slice: `a[indices]`.
67
    ///
68
    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
69
    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
70
    /// `Some(1),None,Some(2)`.
71
    ///
72
    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
73
    ///
74
    /// Defined iff the defined indices are within their respective index domains.
75
    #[compatible(JsonInput)]
76
    UnsafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
77

            
78
    /// A safe matrix slice: `a[indices]`.
79
    ///
80
    /// See [`Expression::UnsafeSlice`].
81
    SafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
82

            
83
    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
84
    ///
85
    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
86
    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
87
    /// `SafeSlice` respectively.
88
    InDomain(Metadata, Box<Expression>, Domain),
89

            
90
    Scope(Metadata, Box<SubModel>),
91

            
92
    /// `|x|` - absolute value of `x`
93
    #[compatible(JsonInput)]
94
    Abs(Metadata, Box<Expression>),
95

            
96
    /// `a + b + c + ...`
97
    #[compatible(JsonInput)]
98
    Sum(Metadata, Vec<Expression>),
99

            
100
    /// `a * b * c * ...`
101
    #[compatible(JsonInput)]
102
    Product(Metadata, Vec<Expression>),
103

            
104
    /// `min(<vec_expr>)`
105
    #[compatible(JsonInput)]
106
    Min(Metadata, Box<Expression>),
107

            
108
    /// `max(<vec_expr>)`
109
    #[compatible(JsonInput)]
110
    Max(Metadata, Box<Expression>),
111

            
112
    /// `not(a)`
113
    #[compatible(JsonInput, SAT)]
114
    Not(Metadata, Box<Expression>),
115

            
116
    /// `or(<vec_expr>)`
117
    #[compatible(JsonInput, SAT)]
118
    Or(Metadata, Box<Expression>),
119

            
120
    /// `and(<vec_expr>)`
121
    #[compatible(JsonInput, SAT)]
122
    And(Metadata, Box<Expression>),
123

            
124
    /// Ensures that `a->b` (material implication).
125
    #[compatible(JsonInput)]
126
    Imply(Metadata, Box<Expression>, Box<Expression>),
127

            
128
    #[compatible(JsonInput)]
129
    Eq(Metadata, Box<Expression>, Box<Expression>),
130

            
131
    #[compatible(JsonInput)]
132
    Neq(Metadata, Box<Expression>, Box<Expression>),
133

            
134
    #[compatible(JsonInput)]
135
    Geq(Metadata, Box<Expression>, Box<Expression>),
136

            
137
    #[compatible(JsonInput)]
138
    Leq(Metadata, Box<Expression>, Box<Expression>),
139

            
140
    #[compatible(JsonInput)]
141
    Gt(Metadata, Box<Expression>, Box<Expression>),
142

            
143
    #[compatible(JsonInput)]
144
    Lt(Metadata, Box<Expression>, Box<Expression>),
145

            
146
    /// Division after preventing division by zero, usually with a bubble
147
    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
148

            
149
    /// Division with a possibly undefined value (division by 0)
150
    #[compatible(JsonInput)]
151
    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
152

            
153
    /// Modulo after preventing mod 0, usually with a bubble
154
    SafeMod(Metadata, Box<Expression>, Box<Expression>),
155

            
156
    /// Modulo with a possibly undefined value (mod 0)
157
    #[compatible(JsonInput)]
158
    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),
159

            
160
    /// Negation: `-x`
161
    #[compatible(JsonInput)]
162
    Neg(Metadata, Box<Expression>),
163

            
164
    /// Unsafe power`x**y` (possibly undefined)
165
    ///
166
    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
167
    #[compatible(JsonInput)]
168
    UnsafePow(Metadata, Box<Expression>, Box<Expression>),
169

            
170
    /// `UnsafePow` after preventing undefinedness
171
    SafePow(Metadata, Box<Expression>, Box<Expression>),
172

            
173
    /// `allDiff(<vec_expr>)`
174
    #[compatible(JsonInput)]
175
    AllDiff(Metadata, Box<Expression>),
176

            
177
    /// Binary subtraction operator
178
    ///
179
    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
180
    #[compatible(JsonInput)]
181
    Minus(Metadata, Box<Expression>, Box<Expression>),
182

            
183
    /// Ensures that x=|y| i.e. x is the absolute value of y.
184
    ///
185
    /// Low-level Minion constraint.
186
    ///
187
    /// # See also
188
    ///
189
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
190
    #[compatible(Minion)]
191
    FlatAbsEq(Metadata, Atom, Atom),
192

            
193
    /// Ensures that `alldiff([a,b,...])`.
194
    ///
195
    /// Low-level Minion constraint.
196
    ///
197
    /// # See also
198
    ///
199
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
200
    #[compatible(Minion)]
201
    FlatAllDiff(Metadata, Vec<Atom>),
202

            
203
    /// Ensures that sum(vec) >= x.
204
    ///
205
    /// Low-level Minion constraint.
206
    ///
207
    /// # See also
208
    ///
209
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
210
    #[compatible(Minion)]
211
    FlatSumGeq(Metadata, Vec<Atom>, Atom),
212

            
213
    /// Ensures that sum(vec) <= x.
214
    ///
215
    /// Low-level Minion constraint.
216
    ///
217
    /// # See also
218
    ///
219
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
220
    #[compatible(Minion)]
221
    FlatSumLeq(Metadata, Vec<Atom>, Atom),
222

            
223
    /// `ineq(x,y,k)` ensures that x <= y + k.
224
    ///
225
    /// Low-level Minion constraint.
226
    ///
227
    /// # See also
228
    ///
229
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
230
    #[compatible(Minion)]
231
    FlatIneq(Metadata, Atom, Atom, Literal),
232

            
233
    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
234
    ///
235
    /// Low-level Minion constraint.
236
    ///
237
    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
238
    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
239
    /// watched-and and watched-or.
240
    ///
241
    /// # See also
242
    ///
243
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
244
    /// + `rules::minion::boolean_literal_to_wliteral`.
245
    #[compatible(Minion)]
246
    FlatWatchedLiteral(Metadata, Name, Literal),
247

            
248
    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
249
    /// product of cs and xs.
250
    ///
251
    /// Low-level Minion constraint.
252
    ///
253
    /// Represents a weighted sum of the form `ax + by + cz + ...`
254
    ///
255
    /// # See also
256
    ///
257
    /// + [Minion
258
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
259
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
260

            
261
    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
262
    /// product of cs and xs.
263
    ///
264
    /// Low-level Minion constraint.
265
    ///
266
    /// Represents a weighted sum of the form `ax + by + cz + ...`
267
    ///
268
    /// # See also
269
    ///
270
    /// + [Minion
271
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
272
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
273

            
274
    /// Ensures that x =-y, where x and y are atoms.
275
    ///
276
    /// Low-level Minion constraint.
277
    ///
278
    /// # See also
279
    ///
280
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
281
    #[compatible(Minion)]
282
    FlatMinusEq(Metadata, Atom, Atom),
283

            
284
    /// Ensures that x*y=z.
285
    ///
286
    /// Low-level Minion constraint.
287
    ///
288
    /// # See also
289
    ///
290
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
291
    #[compatible(Minion)]
292
    FlatProductEq(Metadata, Atom, Atom, Atom),
293

            
294
    /// Ensures that floor(x/y)=z. Always true when y=0.
295
    ///
296
    /// Low-level Minion constraint.
297
    ///
298
    /// # See also
299
    ///
300
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
301
    #[compatible(Minion)]
302
    MinionDivEqUndefZero(Metadata, Atom, Atom, Atom),
303

            
304
    /// Ensures that x%y=z. Always true when y=0.
305
    ///
306
    /// Low-level Minion constraint.
307
    ///
308
    /// # See also
309
    ///
310
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
311
    #[compatible(Minion)]
312
    MinionModuloEqUndefZero(Metadata, Atom, Atom, Atom),
313

            
314
    /// Ensures that `x**y = z`.
315
    ///
316
    /// Low-level Minion constraint.
317
    ///
318
    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
319
    /// is odd and z is -1 if y is even).
320
    ///
321
    /// # See also
322
    ///
323
    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
324
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
325
    MinionPow(Metadata, Atom, Atom, Atom),
326

            
327
    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
328
    /// variable.
329
    ///
330
    /// Low-level Minion constraint.
331
    ///
332
    /// # See also
333
    ///
334
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
335
    #[compatible(Minion)]
336
    MinionReify(Metadata, Box<Expression>, Atom),
337

            
338
    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
339
    /// variable.
340
    ///
341
    /// Low-level Minion constraint.
342
    ///
343
    /// # See also
344
    ///
345
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
346
    #[compatible(Minion)]
347
    MinionReifyImply(Metadata, Box<Expression>, Atom),
348

            
349
    /// Declaration of an auxiliary variable.
350
    ///
351
    /// As with Savile Row, we semantically distinguish this from `Eq`.
352
    #[compatible(Minion)]
353
    AuxDeclaration(Metadata, Name, Box<Expression>),
354
}
355

            
356
545
fn expr_vec_to_domain_i32(
357
545
    exprs: &[Expression],
358
545
    op: fn(i32, i32) -> Option<i32>,
359
545
    vars: &SymbolTable,
360
545
) -> Option<Domain> {
361
1196
    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
362
545
    domains
363
545
        .into_iter()
364
653
        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
365
545
        .flatten()
366
545
}
367
540
fn expr_vec_lit_to_domain_i32(
368
540
    e: &Expression,
369
540
    op: fn(i32, i32) -> Option<i32>,
370
540
    vars: &SymbolTable,
371
540
) -> Option<Domain> {
372
540
    let exprs = e.clone().unwrap_list()?;
373
162
    expr_vec_to_domain_i32(&exprs, op, vars)
374
540
}
375

            
376
// Returns none if unbounded
377
2017
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
378
2017
    let mut min = i32::MAX;
379
2017
    let mut max = i32::MIN;
380
88439
    for r in ranges {
381
86422
        match r {
382
86422
            Range::Single(i) => {
383
86422
                if *i < min {
384
3619
                    min = *i;
385
82803
                }
386
86422
                if *i > max {
387
11865
                    max = *i;
388
74559
                }
389
            }
390
            Range::Bounded(i, j) => {
391
                if *i < min {
392
                    min = *i;
393
                }
394
                if *j > max {
395
                    max = *j;
396
                }
397
            }
398
            Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
399
        }
400
    }
401
2017
    Some((min, max))
402
2017
}
403

            
404
impl Expression {
405
    /// Returns the possible values of the expression, recursing to leaf expressions
406
8493
    pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
407
8456
        let ret = match self {
408
            //todo
409
90
            Expression::AbstractLiteral(_, _) => None,
410
            Expression::DominanceRelation(_, _) => Some(Domain::BoolDomain),
411
            Expression::FromSolution(_, expr) => expr.domain_of(syms),
412
396
            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
413
396
                let Domain::DomainMatrix(elem_domain, _) = matrix.domain_of(syms)? else {
414
                    bug!("subject of an index operation should be a matrix");
415
                };
416

            
417
396
                Some(*elem_domain)
418
            }
419
            Expression::UnsafeSlice(_, matrix, indices)
420
            | Expression::SafeSlice(_, matrix, indices) => {
421
                let sliced_dimension = indices.iter().position(Option::is_none);
422

            
423
                let Domain::DomainMatrix(elem_domain, index_domains) = matrix.domain_of(syms)?
424
                else {
425
                    bug!("subject of an index operation should be a matrix");
426
                };
427

            
428
                match sliced_dimension {
429
                    Some(dimension) => Some(Domain::DomainMatrix(
430
                        elem_domain,
431
                        vec![index_domains[dimension].clone()],
432
                    )),
433

            
434
                    // same as index
435
                    None => Some(*elem_domain),
436
                }
437
            }
438
            Expression::InDomain(_, _, _) => Some(Domain::BoolDomain),
439
4344
            Expression::Atomic(_, Atom::Reference(name)) => Some(syms.resolve_domain(name)?),
440
597
            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
441
597
                Some(Domain::IntDomain(vec![Range::Single(*n)]))
442
            }
443
1
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
444
            Expression::Atomic(_, Atom::Literal(Literal::AbstractLiteral(_))) => None,
445
            Expression::Scope(_, _) => Some(Domain::BoolDomain),
446
22452
            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), syms),
447
90
            Expression::Product(_, exprs) => {
448
2700
                expr_vec_to_domain_i32(exprs, |x, y| Some(x * y), syms)
449
            }
450
270
            Expression::Min(_, e) => {
451
1782
                expr_vec_lit_to_domain_i32(e, |x, y| Some(if x < y { x } else { y }), syms)
452
            }
453
270
            Expression::Max(_, e) => {
454
648
                expr_vec_lit_to_domain_i32(e, |x, y| Some(if x > y { x } else { y }), syms)
455
            }
456
            Expression::UnsafeDiv(_, a, b) => a.domain_of(syms)?.apply_i32(
457
                // rust integer division is truncating; however, we want to always round down,
458
                // including for negative numbers.
459
                |x, y| {
460
                    if y != 0 {
461
                        Some((x as f32 / y as f32).floor() as i32)
462
                    } else {
463
                        None
464
                    }
465
                },
466
                &b.domain_of(syms)?,
467
            ),
468
846
            Expression::SafeDiv(_, a, b) => {
469
                // rust integer division is truncating; however, we want to always round down
470
                // including for negative numbers.
471
846
                let domain = a.domain_of(syms)?.apply_i32(
472
25740
                    |x, y| {
473
25740
                        if y != 0 {
474
22716
                            Some((x as f32 / y as f32).floor() as i32)
475
                        } else {
476
3024
                            None
477
                        }
478
25740
                    },
479
846
                    &b.domain_of(syms)?,
480
                );
481

            
482
846
                match domain {
483
846
                    Some(Domain::IntDomain(ranges)) => {
484
846
                        let mut ranges = ranges;
485
846
                        ranges.push(Range::Single(0));
486
846
                        Some(Domain::IntDomain(ranges))
487
                    }
488
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
489
                    _ => None,
490
                }
491
            }
492
            Expression::UnsafeMod(_, a, b) => a.domain_of(syms)?.apply_i32(
493
                |x, y| if y != 0 { Some(x % y) } else { None },
494
                &b.domain_of(syms)?,
495
            ),
496

            
497
486
            Expression::SafeMod(_, a, b) => {
498
486
                let domain = a.domain_of(syms)?.apply_i32(
499
11718
                    |x, y| if y != 0 { Some(x % y) } else { None },
500
486
                    &b.domain_of(syms)?,
501
                );
502

            
503
486
                match domain {
504
486
                    Some(Domain::IntDomain(ranges)) => {
505
486
                        let mut ranges = ranges;
506
486
                        ranges.push(Range::Single(0));
507
486
                        Some(Domain::IntDomain(ranges))
508
                    }
509
                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
510
                    _ => None,
511
                }
512
            }
513

            
514
108
            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => {
515
108
                a.domain_of(syms)?.apply_i32(
516
9144
                    |x, y| {
517
9144
                        if (x != 0 || y != 0) && y >= 0 {
518
8784
                            Some(x ^ y)
519
                        } else {
520
360
                            None
521
                        }
522
9144
                    },
523
108
                    &b.domain_of(syms)?,
524
                )
525
            }
526

            
527
            Expression::Root(_, _) => None,
528
            Expression::Bubble(_, _, _) => None,
529
            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
530
72
            Expression::And(_, _) => Some(Domain::BoolDomain),
531
18
            Expression::Not(_, _) => Some(Domain::BoolDomain),
532
            Expression::Or(_, _) => Some(Domain::BoolDomain),
533
            Expression::Imply(_, _, _) => Some(Domain::BoolDomain),
534
162
            Expression::Eq(_, _, _) => Some(Domain::BoolDomain),
535
            Expression::Neq(_, _, _) => Some(Domain::BoolDomain),
536
            Expression::Geq(_, _, _) => Some(Domain::BoolDomain),
537
144
            Expression::Leq(_, _, _) => Some(Domain::BoolDomain),
538
            Expression::Gt(_, _, _) => Some(Domain::BoolDomain),
539
            Expression::Lt(_, _, _) => Some(Domain::BoolDomain),
540
            Expression::FlatAbsEq(_, _, _) => Some(Domain::BoolDomain),
541
            Expression::FlatSumGeq(_, _, _) => Some(Domain::BoolDomain),
542
            Expression::FlatSumLeq(_, _, _) => Some(Domain::BoolDomain),
543
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
544
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
545
            Expression::FlatIneq(_, _, _, _) => Some(Domain::BoolDomain),
546
            Expression::AllDiff(_, _) => Some(Domain::BoolDomain),
547
            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::BoolDomain),
548
            Expression::MinionReify(_, _, _) => Some(Domain::BoolDomain),
549
            Expression::MinionReifyImply(_, _, _) => Some(Domain::BoolDomain),
550
144
            Expression::Neg(_, x) => {
551
144
                let Some(Domain::IntDomain(mut ranges)) = x.domain_of(syms) else {
552
                    return None;
553
                };
554

            
555
144
                for range in ranges.iter_mut() {
556
144
                    *range = match range {
557
                        Range::Single(x) => Range::Single(-*x),
558
144
                        Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
559
                        Range::UnboundedR(i) => Range::UnboundedL(-*i),
560
                        Range::UnboundedL(i) => Range::UnboundedR(-*i),
561
                    };
562
                }
563

            
564
144
                Some(Domain::IntDomain(ranges))
565
            }
566
            Expression::Minus(_, a, b) => a
567
                .domain_of(syms)?
568
                .apply_i32(|x, y| Some(x - y), &b.domain_of(syms)?),
569

            
570
            Expression::FlatAllDiff(_, _) => Some(Domain::BoolDomain),
571
            Expression::FlatMinusEq(_, _, _) => Some(Domain::BoolDomain),
572
            Expression::FlatProductEq(_, _, _, _) => Some(Domain::BoolDomain),
573
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::BoolDomain),
574
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::BoolDomain),
575
162
            Expression::Abs(_, a) => a
576
162
                .domain_of(syms)?
577
18810
                .apply_i32(|a, _| Some(a.abs()), &a.domain_of(syms)?),
578
            Expression::MinionPow(_, _, _, _) => Some(Domain::BoolDomain),
579
        };
580
6221
        match ret {
581
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
582
            // Once they support a full domain as we define it, we can remove this conversion
583
6221
            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
584
2017
                let (min, max) = range_vec_bounds_i32(&ranges)?;
585
2017
                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
586
            }
587
6439
            _ => ret,
588
        }
589
8493
    }
590

            
591
    pub fn get_meta(&self) -> Metadata {
592
        let metas: VecDeque<Metadata> = self.children_bi();
593
        metas[0].clone()
594
    }
595

            
596
    pub fn set_meta(&self, meta: Metadata) {
597
        self.transform_bi(Arc::new(move |_| meta.clone()));
598
    }
599

            
600
    /// Checks whether this expression is safe.
601
    ///
602
    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
603
    ///
604
    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
605
    /// safe through the use of bubble rules.
606
5724
    pub fn is_safe(&self) -> bool {
607
        // TODO: memoise in Metadata
608
13770
        for expr in self.universe() {
609
13770
            match expr {
610
                Expression::UnsafeDiv(_, _, _)
611
                | Expression::UnsafeMod(_, _, _)
612
                | Expression::UnsafePow(_, _, _)
613
                | Expression::UnsafeIndex(_, _, _)
614
                | Expression::UnsafeSlice(_, _, _) => {
615
342
                    return false;
616
                }
617
13428
                _ => {}
618
            }
619
        }
620
5382
        true
621
5724
    }
622

            
623
23184
    pub fn return_type(&self) -> Option<ReturnType> {
624
5580
        match self {
625
            Expression::AbstractLiteral(_, _) => None,
626
756
            Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
627
756
                Some(subject.return_type()?)
628
            }
629
306
            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
630
306
                Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
631
            }
632
            Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
633
            Expression::Root(_, _) => Some(ReturnType::Bool),
634
            Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
635
            Expression::FromSolution(_, expr) => expr.return_type(),
636
5562
            Expression::Atomic(_, Atom::Literal(Literal::Int(_))) => Some(ReturnType::Int),
637
18
            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
638
            Expression::Atomic(_, Atom::Literal(Literal::AbstractLiteral(_))) => None,
639
3726
            Expression::Atomic(_, Atom::Reference(_)) => None,
640
            Expression::Scope(_, scope) => scope.return_type(),
641
180
            Expression::Abs(_, _) => Some(ReturnType::Int),
642
198
            Expression::Sum(_, _) => Some(ReturnType::Int),
643
108
            Expression::Product(_, _) => Some(ReturnType::Int),
644
            Expression::Min(_, _) => Some(ReturnType::Int),
645
            Expression::Max(_, _) => Some(ReturnType::Int),
646
54
            Expression::Not(_, _) => Some(ReturnType::Bool),
647
            Expression::Or(_, _) => Some(ReturnType::Bool),
648
            Expression::Imply(_, _, _) => Some(ReturnType::Bool),
649
72
            Expression::And(_, _) => Some(ReturnType::Bool),
650
2142
            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
651
216
            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
652
            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
653
270
            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
654
            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
655
            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
656
4482
            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
657
72
            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
658
            Expression::FlatAllDiff(_, _) => Some(ReturnType::Bool),
659
            Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
660
            Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
661
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
662
            Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
663
612
            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
664
            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
665
            Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
666
            Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
667
            Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
668
            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
669
54
            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
670
3960
            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
671
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
672
180
            Expression::Neg(_, _) => Some(ReturnType::Int),
673
18
            Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
674
198
            Expression::SafePow(_, _, _) => Some(ReturnType::Int),
675
            Expression::Minus(_, _, _) => Some(ReturnType::Int),
676
            Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
677
            Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
678
            Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
679
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
680
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
681
            Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
682
        }
683
23184
    }
684

            
685
    pub fn is_clean(&self) -> bool {
686
        let metadata = self.get_meta();
687
        metadata.clean
688
    }
689

            
690
    pub fn set_clean(&mut self, bool_value: bool) {
691
        let mut metadata = self.get_meta();
692
        metadata.clean = bool_value;
693
        self.set_meta(metadata);
694
    }
695

            
696
    /// True if the expression is an associative and commutative operator
697
511758
    pub fn is_associative_commutative_operator(&self) -> bool {
698
478314
        matches!(
699
511758
            self,
700
            Expression::Sum(_, _)
701
                | Expression::Or(_, _)
702
                | Expression::And(_, _)
703
                | Expression::Product(_, _)
704
        )
705
511758
    }
706

            
707
    /// True iff self and other are both atomic and identical.
708
    ///
709
    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
710
    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
711
    /// entire subtrees of `self` and `other`.
712
28242
    pub fn identical_atom_to(&self, other: &Expression) -> bool {
713
28242
        let atom1: Result<&Atom, _> = self.try_into();
714
28242
        let atom2: Result<&Atom, _> = other.try_into();
715

            
716
28242
        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
717
4410
            atom2 == atom1
718
        } else {
719
23832
            false
720
        }
721
28242
    }
722

            
723
    /// If the expression is a list, returns the inner expressions.
724
    ///
725
    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
726
    /// any explicitly specified domain.
727
178020
    pub fn unwrap_list(self) -> Option<Vec<Expression>> {
728
134406
        match self {
729
134406
            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
730
134406
                matrix.unwrap_list().cloned()
731
            }
732
            Expression::Atomic(
733
                _,
734
                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
735
            ) => matrix.unwrap_list().map(|elems| {
736
                elems
737
                    .clone()
738
                    .into_iter()
739
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
740
                    .collect_vec()
741
            }),
742
43614
            _ => None,
743
        }
744
178020
    }
745

            
746
    /// If the expression is a matrix, gets it elements and index domain.
747
    ///
748
    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
749
    ///
750
    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
751
    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
752
    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
753
18450
    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
754
2052
        match self {
755
2052
            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
756
2052
                Some((elems.clone(), domain))
757
            }
758
            Expression::Atomic(
759
                _,
760
                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
761
            ) => Some((
762
                elems
763
                    .clone()
764
                    .into_iter()
765
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
766
                    .collect_vec(),
767
                domain,
768
            )),
769

            
770
16398
            _ => None,
771
        }
772
18450
    }
773
}
774

            
775
impl From<i32> for Expression {
776
864
    fn from(i: i32) -> Self {
777
864
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
778
864
    }
779
}
780

            
781
impl From<bool> for Expression {
782
144
    fn from(b: bool) -> Self {
783
144
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
784
144
    }
785
}
786

            
787
impl From<Atom> for Expression {
788
288
    fn from(value: Atom) -> Self {
789
288
        Expression::Atomic(Metadata::new(), value)
790
288
    }
791
}
792
impl Display for Expression {
793
280296
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
794
280296
        match &self {
795
20790
            Expression::AbstractLiteral(_, l) => l.fmt(f),
796
2358
            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
797
3222
                write!(f, "{e1}{}", pretty_vec(e2))
798
            }
799
3168
            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
800
3888
                let args = es
801
3888
                    .iter()
802
7776
                    .map(|x| match x {
803
3888
                        Some(x) => format!("{}", x),
804
3888
                        None => "..".into(),
805
7776
                    })
806
3888
                    .join(",");
807
3888

            
808
3888
                write!(f, "{e1}[{args}]")
809
            }
810

            
811
1260
            Expression::InDomain(_, e, domain) => {
812
1260
                write!(f, "__inDomain({e},{domain})")
813
            }
814
504
            Expression::Root(_, exprs) => {
815
504
                write!(f, "{}", pretty_expressions_as_top_level(exprs))
816
            }
817
            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({})", expr),
818
            Expression::FromSolution(_, expr) => write!(f, "FromSolution({})", expr),
819
135270
            Expression::Atomic(_, atom) => atom.fmt(f),
820
            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
821
1620
            Expression::Abs(_, a) => write!(f, "|{}|", a),
822
5832
            Expression::Sum(_, expressions) => {
823
5832
                write!(f, "Sum({})", pretty_vec(expressions))
824
            }
825
3780
            Expression::Product(_, expressions) => {
826
3780
                write!(f, "Product({})", pretty_vec(expressions))
827
            }
828
504
            Expression::Min(_, e) => {
829
504
                write!(f, "min({e})")
830
            }
831
216
            Expression::Max(_, e) => {
832
216
                write!(f, "max({e})")
833
            }
834
2358
            Expression::Not(_, expr_box) => {
835
2358
                write!(f, "Not({})", expr_box.clone())
836
            }
837
3978
            Expression::Or(_, e) => {
838
3978
                write!(f, "or({e})")
839
            }
840
14868
            Expression::And(_, e) => {
841
14868
                write!(f, "and({e})")
842
            }
843
3240
            Expression::Imply(_, box1, box2) => {
844
3240
                write!(f, "({}) -> ({})", box1, box2)
845
            }
846
13014
            Expression::Eq(_, box1, box2) => {
847
13014
                write!(f, "({} = {})", box1.clone(), box2.clone())
848
            }
849
9900
            Expression::Neq(_, box1, box2) => {
850
9900
                write!(f, "({} != {})", box1.clone(), box2.clone())
851
            }
852
1800
            Expression::Geq(_, box1, box2) => {
853
1800
                write!(f, "({} >= {})", box1.clone(), box2.clone())
854
            }
855
3672
            Expression::Leq(_, box1, box2) => {
856
3672
                write!(f, "({} <= {})", box1.clone(), box2.clone())
857
            }
858
108
            Expression::Gt(_, box1, box2) => {
859
108
                write!(f, "({} > {})", box1.clone(), box2.clone())
860
            }
861
1368
            Expression::Lt(_, box1, box2) => {
862
1368
                write!(f, "({} < {})", box1.clone(), box2.clone())
863
            }
864
1116
            Expression::FlatSumGeq(_, box1, box2) => {
865
1116
                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
866
            }
867
1044
            Expression::FlatSumLeq(_, box1, box2) => {
868
1044
                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
869
            }
870
3096
            Expression::FlatIneq(_, box1, box2, box3) => write!(
871
3096
                f,
872
3096
                "Ineq({}, {}, {})",
873
3096
                box1.clone(),
874
3096
                box2.clone(),
875
3096
                box3.clone()
876
3096
            ),
877
3636
            Expression::AllDiff(_, e) => {
878
3636
                write!(f, "allDiff({e})")
879
            }
880
6588
            Expression::Bubble(_, box1, box2) => {
881
6588
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
882
            }
883
8694
            Expression::SafeDiv(_, box1, box2) => {
884
8694
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
885
            }
886
2700
            Expression::UnsafeDiv(_, box1, box2) => {
887
2700
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
888
            }
889
360
            Expression::UnsafePow(_, box1, box2) => {
890
360
                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
891
            }
892
756
            Expression::SafePow(_, box1, box2) => {
893
756
                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
894
            }
895
1188
            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
896
1188
                write!(
897
1188
                    f,
898
1188
                    "DivEq({}, {}, {})",
899
1188
                    box1.clone(),
900
1188
                    box2.clone(),
901
1188
                    box3.clone()
902
1188
                )
903
            }
904
900
            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
905
900
                write!(
906
900
                    f,
907
900
                    "ModEq({}, {}, {})",
908
900
                    box1.clone(),
909
900
                    box2.clone(),
910
900
                    box3.clone()
911
900
                )
912
            }
913

            
914
504
            Expression::FlatWatchedLiteral(_, x, l) => {
915
504
                write!(f, "WatchedLiteral({},{})", x, l)
916
            }
917
864
            Expression::MinionReify(_, box1, box2) => {
918
864
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
919
            }
920
1080
            Expression::MinionReifyImply(_, box1, box2) => {
921
1080
                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
922
            }
923
3780
            Expression::AuxDeclaration(_, n, e) => {
924
3780
                write!(f, "{} =aux {}", n, e.clone())
925
            }
926
1728
            Expression::UnsafeMod(_, a, b) => {
927
1728
                write!(f, "{} % {}", a.clone(), b.clone())
928
            }
929
5364
            Expression::SafeMod(_, a, b) => {
930
5364
                write!(f, "SafeMod({},{})", a.clone(), b.clone())
931
            }
932
4212
            Expression::Neg(_, a) => {
933
4212
                write!(f, "-({})", a.clone())
934
            }
935
180
            Expression::Minus(_, a, b) => {
936
180
                write!(f, "({} - {})", a.clone(), b.clone())
937
            }
938
108
            Expression::FlatAllDiff(_, es) => {
939
108
                write!(f, "__flat_alldiff({})", pretty_vec(es))
940
            }
941
288
            Expression::FlatAbsEq(_, a, b) => {
942
288
                write!(f, "AbsEq({},{})", a.clone(), b.clone())
943
            }
944
144
            Expression::FlatMinusEq(_, a, b) => {
945
144
                write!(f, "MinusEq({},{})", a.clone(), b.clone())
946
            }
947
162
            Expression::FlatProductEq(_, a, b, c) => {
948
162
                write!(
949
162
                    f,
950
162
                    "FlatProductEq({},{},{})",
951
162
                    a.clone(),
952
162
                    b.clone(),
953
162
                    c.clone()
954
162
                )
955
            }
956
252
            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
957
252
                write!(
958
252
                    f,
959
252
                    "FlatWeightedSumLeq({},{},{})",
960
252
                    pretty_vec(cs),
961
252
                    pretty_vec(vs),
962
252
                    total.clone()
963
252
                )
964
            }
965

            
966
180
            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
967
180
                write!(
968
180
                    f,
969
180
                    "FlatWeightedSumGeq({},{},{})",
970
180
                    pretty_vec(cs),
971
180
                    pretty_vec(vs),
972
180
                    total.clone()
973
180
                )
974
            }
975
180
            Expression::MinionPow(_, atom, atom1, atom2) => {
976
180
                write!(f, "MinionPow({},{},{})", atom, atom1, atom2)
977
            }
978
        }
979
280296
    }
980
}
981

            
982
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::ast::declaration::Declaration;

    use super::*;

    /// A sum of two integer constants has a singleton domain containing their total.
    #[test]
    fn test_domain_of_constant_sum() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2)));
        let sum = Expression::Sum(Metadata::new(), vec![c1, c2]);
        assert_eq!(
            sum.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(3)]))
        );
    }

    /// Mixing an integer and a boolean in a sum yields no domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), vec![c1, c2]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    /// An empty sum has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    /// A reference takes the domain of its declaration in the symbol table.
    #[test]
    fn test_domain_of_reference() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Single(1)]),
        )))
        .unwrap();
        assert_eq!(
            reference.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(1)]))
        );
    }

    /// A reference to an undeclared name has no domain.
    #[test]
    fn test_domain_of_reference_not_found() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        assert_eq!(reference.domain_of(&SymbolTable::new()), None);
    }

    /// Summing a singleton-domain reference with itself doubles the single value.
    #[test]
    fn test_domain_of_reference_sum_single() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Single(1)]),
        )))
        .unwrap();
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(2)]))
        );
    }

    /// Summing a bounded-domain reference with itself adds the bounds pointwise.
    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        // Unwrap like the sibling tests. If the declaration failed to insert, the assertion
        // below would silently be exercising a different code path.
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Bounded(1, 2)]),
        )))
        .unwrap();
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]))
        );
    }
}