conjure_core/ast/
expressions.rs

1use std::collections::VecDeque;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7
8use crate::ast::literals::AbstractLiteral;
9use crate::ast::literals::Literal;
10use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
11use crate::ast::symbol_table::SymbolTable;
12use crate::ast::Atom;
13use crate::ast::Name;
14use crate::ast::ReturnType;
15use crate::ast::SetAttr;
16use crate::bug;
17use crate::metadata::Metadata;
18use enum_compatability_macro::document_compatibility;
19use uniplate::derive::Uniplate;
20use uniplate::{Biplate, Uniplate as _};
21
22use super::comprehension::Comprehension;
23use super::records::RecordValue;
24use super::{Domain, Range, SubModel, Typeable};
25
26/// Represents different types of expressions used to define rules and constraints in the model.
27///
28/// The `Expression` enum includes operations, constants, and variable references
29/// used to build rules and conditions for the model.
30#[document_compatibility]
31#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
32#[uniplate(walk_into=[Atom,SubModel,AbstractLiteral<Expression>,Comprehension])]
33#[biplate(to=Metadata)]
34#[biplate(to=Atom)]
35#[biplate(to=Name,walk_into=[Atom])]
36#[biplate(to=Vec<Expression>)]
37#[biplate(to=Option<Expression>)]
38#[biplate(to=SubModel,walk_into=[Comprehension])]
39#[biplate(to=Comprehension)]
40#[biplate(to=AbstractLiteral<Expression>)]
41#[biplate(to=AbstractLiteral<Literal>,walk_into=[Atom])]
42#[biplate(to=RecordValue<Expression>,walk_into=[AbstractLiteral<Expression>])]
43#[biplate(to=RecordValue<Literal>,walk_into=[Atom,Literal,AbstractLiteral<Literal>,AbstractLiteral<Expression>])]
44#[biplate(to=Literal,walk_into=[Atom])]
45pub enum Expression {
46    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
47    /// The top of the model
48    Root(Metadata, Vec<Expression>),
49
50    /// An expression representing "A is valid as long as B is true"
51    /// Turns into a conjunction when it reaches a boolean context
52    Bubble(Metadata, Box<Expression>, Box<Expression>),
53
54    /// A comprehension.
55    ///
56    /// The inside of the comprehension opens a new scope.
57    Comprehension(Metadata, Box<Comprehension>),
58
59    /// Defines dominance ("Solution A is preferred over Solution B")
60    DominanceRelation(Metadata, Box<Expression>),
61    /// `fromSolution(name)` - Used in dominance relation definitions
62    FromSolution(Metadata, Box<Expression>),
63
64    Atomic(Metadata, Atom),
65
66    /// A matrix index.
67    ///
68    /// Defined iff the indices are within their respective index domains.
69    #[compatible(JsonInput)]
70    UnsafeIndex(Metadata, Box<Expression>, Vec<Expression>),
71
72    /// A safe matrix index.
73    ///
74    /// See [`Expression::UnsafeIndex`]
75    SafeIndex(Metadata, Box<Expression>, Vec<Expression>),
76
77    /// A matrix slice: `a[indices]`.
78    ///
79    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
80    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
81    /// `Some(1),None,Some(2)`.
82    ///
83    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
84    ///
85    /// Defined iff the defined indices are within their respective index domains.
86    #[compatible(JsonInput)]
87    UnsafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
88
89    /// A safe matrix slice: `a[indices]`.
90    ///
91    /// See [`Expression::UnsafeSlice`].
92    SafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
93
94    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
95    ///
96    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
97    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
98    /// `SafeSlice` respectively.
99    InDomain(Metadata, Box<Expression>, Domain),
100
101    Scope(Metadata, Box<SubModel>),
102
103    /// `|x|` - absolute value of `x`
104    #[compatible(JsonInput)]
105    Abs(Metadata, Box<Expression>),
106
107    /// `sum(<vec_expr>)`
108    #[compatible(JsonInput)]
109    Sum(Metadata, Box<Expression>),
110
111    /// `a * b * c * ...`
112    #[compatible(JsonInput)]
113    Product(Metadata, Vec<Expression>),
114
115    /// `min(<vec_expr>)`
116    #[compatible(JsonInput)]
117    Min(Metadata, Box<Expression>),
118
119    /// `max(<vec_expr>)`
120    #[compatible(JsonInput)]
121    Max(Metadata, Box<Expression>),
122
123    /// `not(a)`
124    #[compatible(JsonInput, SAT)]
125    Not(Metadata, Box<Expression>),
126
127    /// `or(<vec_expr>)`
128    #[compatible(JsonInput, SAT)]
129    Or(Metadata, Box<Expression>),
130
131    /// `and(<vec_expr>)`
132    #[compatible(JsonInput, SAT)]
133    And(Metadata, Box<Expression>),
134
135    /// Ensures that `a->b` (material implication).
136    #[compatible(JsonInput)]
137    Imply(Metadata, Box<Expression>, Box<Expression>),
138
139    /// `iff(a, b)` a <-> b
140    #[compatible(JsonInput)]
141    Iff(Metadata, Box<Expression>, Box<Expression>),
142
143    #[compatible(JsonInput)]
144    Union(Metadata, Box<Expression>, Box<Expression>),
145
146    #[compatible(JsonInput)]
147    Intersect(Metadata, Box<Expression>, Box<Expression>),
148
149    #[compatible(JsonInput)]
150    Supset(Metadata, Box<Expression>, Box<Expression>),
151
152    #[compatible(JsonInput)]
153    SupsetEq(Metadata, Box<Expression>, Box<Expression>),
154
155    #[compatible(JsonInput)]
156    Subset(Metadata, Box<Expression>, Box<Expression>),
157
158    #[compatible(JsonInput)]
159    SubsetEq(Metadata, Box<Expression>, Box<Expression>),
160
161    #[compatible(JsonInput)]
162    Eq(Metadata, Box<Expression>, Box<Expression>),
163
164    #[compatible(JsonInput)]
165    Neq(Metadata, Box<Expression>, Box<Expression>),
166
167    #[compatible(JsonInput)]
168    Geq(Metadata, Box<Expression>, Box<Expression>),
169
170    #[compatible(JsonInput)]
171    Leq(Metadata, Box<Expression>, Box<Expression>),
172
173    #[compatible(JsonInput)]
174    Gt(Metadata, Box<Expression>, Box<Expression>),
175
176    #[compatible(JsonInput)]
177    Lt(Metadata, Box<Expression>, Box<Expression>),
178
179    /// Division after preventing division by zero, usually with a bubble
180    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
181
182    /// Division with a possibly undefined value (division by 0)
183    #[compatible(JsonInput)]
184    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
185
186    /// Modulo after preventing mod 0, usually with a bubble
187    SafeMod(Metadata, Box<Expression>, Box<Expression>),
188
189    /// Modulo with a possibly undefined value (mod 0)
190    #[compatible(JsonInput)]
191    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),
192
193    /// Negation: `-x`
194    #[compatible(JsonInput)]
195    Neg(Metadata, Box<Expression>),
196
197    /// Unsafe power`x**y` (possibly undefined)
198    ///
199    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
200    #[compatible(JsonInput)]
201    UnsafePow(Metadata, Box<Expression>, Box<Expression>),
202
203    /// `UnsafePow` after preventing undefinedness
204    SafePow(Metadata, Box<Expression>, Box<Expression>),
205
206    /// `allDiff(<vec_expr>)`
207    #[compatible(JsonInput)]
208    AllDiff(Metadata, Box<Expression>),
209
210    /// Binary subtraction operator
211    ///
212    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
213    /// TODO: make this compatible with Set Difference calculations - need to change return type and domain for this expression and write a set comprehension rule.
214    /// have already edited minus_to_sum to prevent this from applying to sets
215    #[compatible(JsonInput)]
216    Minus(Metadata, Box<Expression>, Box<Expression>),
217
218    /// Ensures that x=|y| i.e. x is the absolute value of y.
219    ///
220    /// Low-level Minion constraint.
221    ///
222    /// # See also
223    ///
224    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
225    #[compatible(Minion)]
226    FlatAbsEq(Metadata, Atom, Atom),
227
228    /// Ensures that `alldiff([a,b,...])`.
229    ///
230    /// Low-level Minion constraint.
231    ///
232    /// # See also
233    ///
234    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
235    #[compatible(Minion)]
236    FlatAllDiff(Metadata, Vec<Atom>),
237
238    /// Ensures that sum(vec) >= x.
239    ///
240    /// Low-level Minion constraint.
241    ///
242    /// # See also
243    ///
244    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
245    #[compatible(Minion)]
246    FlatSumGeq(Metadata, Vec<Atom>, Atom),
247
248    /// Ensures that sum(vec) <= x.
249    ///
250    /// Low-level Minion constraint.
251    ///
252    /// # See also
253    ///
254    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
255    #[compatible(Minion)]
256    FlatSumLeq(Metadata, Vec<Atom>, Atom),
257
258    /// `ineq(x,y,k)` ensures that x <= y + k.
259    ///
260    /// Low-level Minion constraint.
261    ///
262    /// # See also
263    ///
264    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
265    #[compatible(Minion)]
266    FlatIneq(Metadata, Atom, Atom, Literal),
267
268    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
269    ///
270    /// Low-level Minion constraint.
271    ///
272    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
273    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
274    /// watched-and and watched-or.
275    ///
276    /// # See also
277    ///
278    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
279    /// + `rules::minion::boolean_literal_to_wliteral`.
280    #[compatible(Minion)]
281    FlatWatchedLiteral(Metadata, Name, Literal),
282
283    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
284    /// product of cs and xs.
285    ///
286    /// Low-level Minion constraint.
287    ///
288    /// Represents a weighted sum of the form `ax + by + cz + ...`
289    ///
290    /// # See also
291    ///
292    /// + [Minion
293    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
294    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
295
296    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
297    /// product of cs and xs.
298    ///
299    /// Low-level Minion constraint.
300    ///
301    /// Represents a weighted sum of the form `ax + by + cz + ...`
302    ///
303    /// # See also
304    ///
305    /// + [Minion
306    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
307    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
308
309    /// Ensures that x =-y, where x and y are atoms.
310    ///
311    /// Low-level Minion constraint.
312    ///
313    /// # See also
314    ///
315    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
316    #[compatible(Minion)]
317    FlatMinusEq(Metadata, Atom, Atom),
318
319    /// Ensures that x*y=z.
320    ///
321    /// Low-level Minion constraint.
322    ///
323    /// # See also
324    ///
325    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
326    #[compatible(Minion)]
327    FlatProductEq(Metadata, Atom, Atom, Atom),
328
329    /// Ensures that floor(x/y)=z. Always true when y=0.
330    ///
331    /// Low-level Minion constraint.
332    ///
333    /// # See also
334    ///
335    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
336    #[compatible(Minion)]
337    MinionDivEqUndefZero(Metadata, Atom, Atom, Atom),
338
339    /// Ensures that x%y=z. Always true when y=0.
340    ///
341    /// Low-level Minion constraint.
342    ///
343    /// # See also
344    ///
345    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
346    #[compatible(Minion)]
347    MinionModuloEqUndefZero(Metadata, Atom, Atom, Atom),
348
349    /// Ensures that `x**y = z`.
350    ///
351    /// Low-level Minion constraint.
352    ///
353    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
354    /// is odd and z is -1 if y is even).
355    ///
356    /// # See also
357    ///
358    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
359    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
360    MinionPow(Metadata, Atom, Atom, Atom),
361
362    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
363    /// variable.
364    ///
365    /// Low-level Minion constraint.
366    ///
367    /// # See also
368    ///
369    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
370    #[compatible(Minion)]
371    MinionReify(Metadata, Box<Expression>, Atom),
372
373    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
374    /// variable.
375    ///
376    /// Low-level Minion constraint.
377    ///
378    /// # See also
379    ///
380    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
381    #[compatible(Minion)]
382    MinionReifyImply(Metadata, Box<Expression>, Atom),
383
384    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
385    /// intervals {a1,…,a2}, {b1,…,b2} etc.
386    ///
387    /// The list of intervals must be given in numerical order.
388    ///
389    /// Low-level Minion constraint.
390    ///
391    /// # See also
392    ///
393    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
394    #[compatible(Minion)]
395    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
396
397    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
398    /// in the range `[1..len(vec)]`.
399    ///
400    /// Low-level Minion constraint.
401    ///
402    /// # See also
403    ///
404    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
405    #[compatible(Minion)]
406    MinionElementOne(Metadata, Vec<Atom>, Atom, Atom),
407
408    /// Declaration of an auxiliary variable.
409    ///
410    /// As with Savile Row, we semantically distinguish this from `Eq`.
411    #[compatible(Minion)]
412    AuxDeclaration(Metadata, Name, Box<Expression>),
413}
414
415fn expr_vec_to_domain_i32(
416    exprs: &[Expression],
417    op: fn(i32, i32) -> Option<i32>,
418    vars: &SymbolTable,
419) -> Option<Domain> {
420    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
421    domains
422        .into_iter()
423        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
424        .flatten()
425}
426fn expr_vec_lit_to_domain_i32(
427    e: &Expression,
428    op: fn(i32, i32) -> Option<i32>,
429    vars: &SymbolTable,
430) -> Option<Domain> {
431    let exprs = e.clone().unwrap_list()?;
432    expr_vec_to_domain_i32(&exprs, op, vars)
433}
434
435// Returns none if unbounded
436fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
437    let mut min = i32::MAX;
438    let mut max = i32::MIN;
439    for r in ranges {
440        match r {
441            Range::Single(i) => {
442                if *i < min {
443                    min = *i;
444                }
445                if *i > max {
446                    max = *i;
447                }
448            }
449            Range::Bounded(i, j) => {
450                if *i < min {
451                    min = *i;
452                }
453                if *j > max {
454                    max = *j;
455                }
456            }
457            Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
458        }
459    }
460    Some((min, max))
461}
462
463impl Expression {
464    /// Returns the possible values of the expression, recursing to leaf expressions
465    pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
466        let ret = match self {
467            Expression::Union(_, a, b) => Some(Domain::DomainSet(
468                SetAttr::None,
469                Box::new(a.domain_of(syms)?.union(&b.domain_of(syms)?)?),
470            )),
471            Expression::Intersect(_, a, b) => Some(Domain::DomainSet(
472                SetAttr::None,
473                Box::new(a.domain_of(syms)?.intersect(&b.domain_of(syms)?)?),
474            )),
475            Expression::Supset(_, _, _) => Some(Domain::BoolDomain),
476            Expression::SupsetEq(_, _, _) => Some(Domain::BoolDomain),
477            Expression::Subset(_, _, _) => Some(Domain::BoolDomain),
478            Expression::SubsetEq(_, _, _) => Some(Domain::BoolDomain),
479
480            //todo
481            Expression::AbstractLiteral(_, _) => None,
482            Expression::DominanceRelation(_, _) => Some(Domain::BoolDomain),
483            Expression::FromSolution(_, expr) => expr.domain_of(syms),
484            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
485            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
486                match matrix.domain_of(syms)? {
487                    Domain::DomainMatrix(elem_domain, _) => Some(*elem_domain),
488                    Domain::DomainTuple(_) => None,
489                    Domain::DomainRecord(_) => None,
490                    _ => {
491                        bug!("subject of an index operation should support indexing")
492                    }
493                }
494            }
495            Expression::UnsafeSlice(_, matrix, indices)
496            | Expression::SafeSlice(_, matrix, indices) => {
497                let sliced_dimension = indices.iter().position(Option::is_none);
498
499                let Domain::DomainMatrix(elem_domain, index_domains) = matrix.domain_of(syms)?
500                else {
501                    bug!("subject of an index operation should be a matrix");
502                };
503
504                match sliced_dimension {
505                    Some(dimension) => Some(Domain::DomainMatrix(
506                        elem_domain,
507                        vec![index_domains[dimension].clone()],
508                    )),
509
510                    // same as index
511                    None => Some(*elem_domain),
512                }
513            }
514            Expression::InDomain(_, _, _) => Some(Domain::BoolDomain),
515            Expression::Atomic(_, Atom::Reference(name)) => Some(syms.resolve_domain(name)?),
516            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
517                Some(Domain::IntDomain(vec![Range::Single(*n)]))
518            }
519            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
520            Expression::Atomic(_, Atom::Literal(Literal::AbstractLiteral(_))) => None,
521            Expression::Scope(_, _) => Some(Domain::BoolDomain),
522            Expression::Sum(_, e) => expr_vec_lit_to_domain_i32(e, |x, y| Some(x + y), syms),
523            Expression::Product(_, exprs) => {
524                expr_vec_to_domain_i32(exprs, |x, y| Some(x * y), syms)
525            }
526            Expression::Min(_, e) => {
527                expr_vec_lit_to_domain_i32(e, |x, y| Some(if x < y { x } else { y }), syms)
528            }
529            Expression::Max(_, e) => {
530                expr_vec_lit_to_domain_i32(e, |x, y| Some(if x > y { x } else { y }), syms)
531            }
532            Expression::UnsafeDiv(_, a, b) => a.domain_of(syms)?.apply_i32(
533                // rust integer division is truncating; however, we want to always round down,
534                // including for negative numbers.
535                |x, y| {
536                    if y != 0 {
537                        Some((x as f32 / y as f32).floor() as i32)
538                    } else {
539                        None
540                    }
541                },
542                &b.domain_of(syms)?,
543            ),
544            Expression::SafeDiv(_, a, b) => {
545                // rust integer division is truncating; however, we want to always round down
546                // including for negative numbers.
547                let domain = a.domain_of(syms)?.apply_i32(
548                    |x, y| {
549                        if y != 0 {
550                            Some((x as f32 / y as f32).floor() as i32)
551                        } else {
552                            None
553                        }
554                    },
555                    &b.domain_of(syms)?,
556                );
557
558                match domain {
559                    Some(Domain::IntDomain(ranges)) => {
560                        let mut ranges = ranges;
561                        ranges.push(Range::Single(0));
562                        Some(Domain::IntDomain(ranges))
563                    }
564                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
565                    _ => None,
566                }
567            }
568            Expression::UnsafeMod(_, a, b) => a.domain_of(syms)?.apply_i32(
569                |x, y| if y != 0 { Some(x % y) } else { None },
570                &b.domain_of(syms)?,
571            ),
572
573            Expression::SafeMod(_, a, b) => {
574                let domain = a.domain_of(syms)?.apply_i32(
575                    |x, y| if y != 0 { Some(x % y) } else { None },
576                    &b.domain_of(syms)?,
577                );
578
579                match domain {
580                    Some(Domain::IntDomain(ranges)) => {
581                        let mut ranges = ranges;
582                        ranges.push(Range::Single(0));
583                        Some(Domain::IntDomain(ranges))
584                    }
585                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
586                    _ => None,
587                }
588            }
589
590            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => {
591                a.domain_of(syms)?.apply_i32(
592                    |x, y| {
593                        if (x != 0 || y != 0) && y >= 0 {
594                            Some(x.pow(y as u32))
595                        } else {
596                            None
597                        }
598                    },
599                    &b.domain_of(syms)?,
600                )
601            }
602
603            Expression::Root(_, _) => None,
604            Expression::Bubble(_, _, _) => None,
605            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
606            Expression::And(_, _) => Some(Domain::BoolDomain),
607            Expression::Not(_, _) => Some(Domain::BoolDomain),
608            Expression::Or(_, _) => Some(Domain::BoolDomain),
609            Expression::Imply(_, _, _) => Some(Domain::BoolDomain),
610            Expression::Iff(_, _, _) => Some(Domain::BoolDomain),
611            Expression::Eq(_, _, _) => Some(Domain::BoolDomain),
612            Expression::Neq(_, _, _) => Some(Domain::BoolDomain),
613            Expression::Geq(_, _, _) => Some(Domain::BoolDomain),
614            Expression::Leq(_, _, _) => Some(Domain::BoolDomain),
615            Expression::Gt(_, _, _) => Some(Domain::BoolDomain),
616            Expression::Lt(_, _, _) => Some(Domain::BoolDomain),
617            Expression::FlatAbsEq(_, _, _) => Some(Domain::BoolDomain),
618            Expression::FlatSumGeq(_, _, _) => Some(Domain::BoolDomain),
619            Expression::FlatSumLeq(_, _, _) => Some(Domain::BoolDomain),
620            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
621            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
622            Expression::FlatIneq(_, _, _, _) => Some(Domain::BoolDomain),
623            Expression::AllDiff(_, _) => Some(Domain::BoolDomain),
624            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::BoolDomain),
625            Expression::MinionReify(_, _, _) => Some(Domain::BoolDomain),
626            Expression::MinionReifyImply(_, _, _) => Some(Domain::BoolDomain),
627            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::BoolDomain),
628            Expression::MinionElementOne(_, _, _, _) => Some(Domain::BoolDomain),
629            Expression::Neg(_, x) => {
630                let Some(Domain::IntDomain(mut ranges)) = x.domain_of(syms) else {
631                    return None;
632                };
633
634                for range in ranges.iter_mut() {
635                    *range = match range {
636                        Range::Single(x) => Range::Single(-*x),
637                        Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
638                        Range::UnboundedR(i) => Range::UnboundedL(-*i),
639                        Range::UnboundedL(i) => Range::UnboundedR(-*i),
640                    };
641                }
642
643                Some(Domain::IntDomain(ranges))
644            }
645            Expression::Minus(_, a, b) => a
646                .domain_of(syms)?
647                .apply_i32(|x, y| Some(x - y), &b.domain_of(syms)?),
648
649            Expression::FlatAllDiff(_, _) => Some(Domain::BoolDomain),
650            Expression::FlatMinusEq(_, _, _) => Some(Domain::BoolDomain),
651            Expression::FlatProductEq(_, _, _, _) => Some(Domain::BoolDomain),
652            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::BoolDomain),
653            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::BoolDomain),
654            Expression::Abs(_, a) => a
655                .domain_of(syms)?
656                .apply_i32(|a, _| Some(a.abs()), &a.domain_of(syms)?),
657            Expression::MinionPow(_, _, _, _) => Some(Domain::BoolDomain),
658        };
659        match ret {
660            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
661            // Once they support a full domain as we define it, we can remove this conversion
662            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
663                let (min, max) = range_vec_bounds_i32(&ranges)?;
664                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
665            }
666            _ => ret,
667        }
668    }
669
670    pub fn get_meta(&self) -> Metadata {
671        let metas: VecDeque<Metadata> = self.children_bi();
672        metas[0].clone()
673    }
674
675    pub fn set_meta(&self, meta: Metadata) {
676        self.transform_bi(Arc::new(move |_| meta.clone()));
677    }
678
679    /// Checks whether this expression is safe.
680    ///
681    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
682    ///
683    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
684    /// safe through the use of bubble rules.
685    pub fn is_safe(&self) -> bool {
686        // TODO: memoise in Metadata
687        for expr in self.universe() {
688            match expr {
689                Expression::UnsafeDiv(_, _, _)
690                | Expression::UnsafeMod(_, _, _)
691                | Expression::UnsafePow(_, _, _)
692                | Expression::UnsafeIndex(_, _, _)
693                | Expression::UnsafeSlice(_, _, _) => {
694                    return false;
695                }
696                _ => {}
697            }
698        }
699        true
700    }
701
702    pub fn is_clean(&self) -> bool {
703        let metadata = self.get_meta();
704        metadata.clean
705    }
706
707    pub fn set_clean(&mut self, bool_value: bool) {
708        let mut metadata = self.get_meta();
709        metadata.clean = bool_value;
710        self.set_meta(metadata);
711    }
712
713    /// True if the expression is an associative and commutative operator
714    pub fn is_associative_commutative_operator(&self) -> bool {
715        matches!(
716            self,
717            Expression::Sum(_, _)
718                | Expression::Or(_, _)
719                | Expression::And(_, _)
720                | Expression::Product(_, _)
721        )
722    }
723
724    /// True iff self and other are both atomic and identical.
725    ///
726    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
727    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
728    /// entire subtrees of `self` and `other`.
729    pub fn identical_atom_to(&self, other: &Expression) -> bool {
730        let atom1: Result<&Atom, _> = self.try_into();
731        let atom2: Result<&Atom, _> = other.try_into();
732
733        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
734            atom2 == atom1
735        } else {
736            false
737        }
738    }
739
740    /// If the expression is a list, returns the inner expressions.
741    ///
742    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
743    /// any explicitly specified domain.
744    pub fn unwrap_list(self) -> Option<Vec<Expression>> {
745        match self {
746            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
747                matrix.unwrap_list().cloned()
748            }
749            Expression::Atomic(
750                _,
751                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
752            ) => matrix.unwrap_list().map(|elems| {
753                elems
754                    .clone()
755                    .into_iter()
756                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
757                    .collect_vec()
758            }),
759            _ => None,
760        }
761    }
762
763    /// If the expression is a matrix, gets it elements and index domain.
764    ///
765    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
766    ///
767    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
768    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
769    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
770    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
771        match self {
772            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
773                Some((elems.clone(), domain))
774            }
775            Expression::Atomic(
776                _,
777                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
778            ) => Some((
779                elems
780                    .clone()
781                    .into_iter()
782                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
783                    .collect_vec(),
784                domain,
785            )),
786
787            _ => None,
788        }
789    }
790
791    /// For a Root expression, extends the inner vec with the given vec.
792    ///
793    /// # Panics
794    /// Panics if the expression is not Root.
795    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
796        match self {
797            Expression::Root(meta, mut children) => {
798                children.extend(exprs);
799                Expression::Root(meta, children)
800            }
801            _ => panic!("extend_root called on a non-Root expression"),
802        }
803    }
804
805    /// Converts the expression to a literal, if possible.
806    pub fn to_literal(self) -> Option<Literal> {
807        match self {
808            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
809            Expression::AbstractLiteral(_, abslit) => {
810                Some(Literal::AbstractLiteral(abslit.clone().as_literals()?))
811            }
812            Expression::Neg(_, e) => {
813                let Literal::Int(i) = e.to_literal()? else {
814                    bug!("negated literal should be an int");
815                };
816
817                Some(Literal::Int(-i))
818            }
819
820            _ => None,
821        }
822    }
823}
824
825impl From<i32> for Expression {
826    fn from(i: i32) -> Self {
827        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
828    }
829}
830
831impl From<bool> for Expression {
832    fn from(b: bool) -> Self {
833        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
834    }
835}
836
837impl From<Atom> for Expression {
838    fn from(value: Atom) -> Self {
839        Expression::Atomic(Metadata::new(), value)
840    }
841}
842
843impl From<Name> for Expression {
844    fn from(name: Name) -> Self {
845        Expression::Atomic(Metadata::new(), Atom::Reference(name))
846    }
847}
848
849impl From<Box<Expression>> for Expression {
850    fn from(val: Box<Expression>) -> Self {
851        val.as_ref().clone()
852    }
853}
854
855impl Display for Expression {
856    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
857        match &self {
858            Expression::Union(_, box1, box2) => {
859                write!(f, "({} union {})", box1.clone(), box2.clone())
860            }
861            Expression::Intersect(_, box1, box2) => {
862                write!(f, "({} intersect {})", box1.clone(), box2.clone())
863            }
864            Expression::Supset(_, box1, box2) => {
865                write!(f, "({} supset {})", box1.clone(), box2.clone())
866            }
867            Expression::SupsetEq(_, box1, box2) => {
868                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
869            }
870            Expression::Subset(_, box1, box2) => {
871                write!(f, "({} subset {})", box1.clone(), box2.clone())
872            }
873            Expression::SubsetEq(_, box1, box2) => {
874                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
875            }
876
877            Expression::AbstractLiteral(_, l) => l.fmt(f),
878            Expression::Comprehension(_, c) => c.fmt(f),
879            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
880                write!(f, "{e1}{}", pretty_vec(e2))
881            }
882            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
883                let args = es
884                    .iter()
885                    .map(|x| match x {
886                        Some(x) => format!("{}", x),
887                        None => "..".into(),
888                    })
889                    .join(",");
890
891                write!(f, "{e1}[{args}]")
892            }
893            Expression::InDomain(_, e, domain) => {
894                write!(f, "__inDomain({e},{domain})")
895            }
896            Expression::Root(_, exprs) => {
897                write!(f, "{}", pretty_expressions_as_top_level(exprs))
898            }
899            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({})", expr),
900            Expression::FromSolution(_, expr) => write!(f, "FromSolution({})", expr),
901            Expression::Atomic(_, atom) => atom.fmt(f),
902            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
903            Expression::Abs(_, a) => write!(f, "|{}|", a),
904            Expression::Sum(_, e) => {
905                write!(f, "Sum({e})")
906            }
907            Expression::Product(_, expressions) => {
908                write!(f, "Product({})", pretty_vec(expressions))
909            }
910            Expression::Min(_, e) => {
911                write!(f, "min({e})")
912            }
913            Expression::Max(_, e) => {
914                write!(f, "max({e})")
915            }
916            Expression::Not(_, expr_box) => {
917                write!(f, "Not({})", expr_box.clone())
918            }
919            Expression::Or(_, e) => {
920                write!(f, "or({e})")
921            }
922            Expression::And(_, e) => {
923                write!(f, "and({e})")
924            }
925            Expression::Imply(_, box1, box2) => {
926                write!(f, "({}) -> ({})", box1, box2)
927            }
928            Expression::Iff(_, box1, box2) => {
929                write!(f, "({}) <-> ({})", box1, box2)
930            }
931            Expression::Eq(_, box1, box2) => {
932                write!(f, "({} = {})", box1.clone(), box2.clone())
933            }
934            Expression::Neq(_, box1, box2) => {
935                write!(f, "({} != {})", box1.clone(), box2.clone())
936            }
937            Expression::Geq(_, box1, box2) => {
938                write!(f, "({} >= {})", box1.clone(), box2.clone())
939            }
940            Expression::Leq(_, box1, box2) => {
941                write!(f, "({} <= {})", box1.clone(), box2.clone())
942            }
943            Expression::Gt(_, box1, box2) => {
944                write!(f, "({} > {})", box1.clone(), box2.clone())
945            }
946            Expression::Lt(_, box1, box2) => {
947                write!(f, "({} < {})", box1.clone(), box2.clone())
948            }
949            Expression::FlatSumGeq(_, box1, box2) => {
950                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
951            }
952            Expression::FlatSumLeq(_, box1, box2) => {
953                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
954            }
955            Expression::FlatIneq(_, box1, box2, box3) => write!(
956                f,
957                "Ineq({}, {}, {})",
958                box1.clone(),
959                box2.clone(),
960                box3.clone()
961            ),
962            Expression::AllDiff(_, e) => {
963                write!(f, "allDiff({e})")
964            }
965            Expression::Bubble(_, box1, box2) => {
966                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
967            }
968            Expression::SafeDiv(_, box1, box2) => {
969                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
970            }
971            Expression::UnsafeDiv(_, box1, box2) => {
972                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
973            }
974            Expression::UnsafePow(_, box1, box2) => {
975                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
976            }
977            Expression::SafePow(_, box1, box2) => {
978                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
979            }
980            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
981                write!(
982                    f,
983                    "DivEq({}, {}, {})",
984                    box1.clone(),
985                    box2.clone(),
986                    box3.clone()
987                )
988            }
989            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
990                write!(
991                    f,
992                    "ModEq({}, {}, {})",
993                    box1.clone(),
994                    box2.clone(),
995                    box3.clone()
996                )
997            }
998            Expression::FlatWatchedLiteral(_, x, l) => {
999                write!(f, "WatchedLiteral({},{})", x, l)
1000            }
1001            Expression::MinionReify(_, box1, box2) => {
1002                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1003            }
1004            Expression::MinionReifyImply(_, box1, box2) => {
1005                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1006            }
1007            Expression::MinionWInIntervalSet(_, atom, intervals) => {
1008                let intervals = intervals.iter().join(",");
1009                write!(f, "__minion_w_inintervalset({atom},{intervals})")
1010            }
1011            Expression::AuxDeclaration(_, n, e) => {
1012                write!(f, "{} =aux {}", n, e.clone())
1013            }
1014            Expression::UnsafeMod(_, a, b) => {
1015                write!(f, "{} % {}", a.clone(), b.clone())
1016            }
1017            Expression::SafeMod(_, a, b) => {
1018                write!(f, "SafeMod({},{})", a.clone(), b.clone())
1019            }
1020            Expression::Neg(_, a) => {
1021                write!(f, "-({})", a.clone())
1022            }
1023            Expression::Minus(_, a, b) => {
1024                write!(f, "({} - {})", a.clone(), b.clone())
1025            }
1026            Expression::FlatAllDiff(_, es) => {
1027                write!(f, "__flat_alldiff({})", pretty_vec(es))
1028            }
1029            Expression::FlatAbsEq(_, a, b) => {
1030                write!(f, "AbsEq({},{})", a.clone(), b.clone())
1031            }
1032            Expression::FlatMinusEq(_, a, b) => {
1033                write!(f, "MinusEq({},{})", a.clone(), b.clone())
1034            }
1035            Expression::FlatProductEq(_, a, b, c) => {
1036                write!(
1037                    f,
1038                    "FlatProductEq({},{},{})",
1039                    a.clone(),
1040                    b.clone(),
1041                    c.clone()
1042                )
1043            }
1044            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1045                write!(
1046                    f,
1047                    "FlatWeightedSumLeq({},{},{})",
1048                    pretty_vec(cs),
1049                    pretty_vec(vs),
1050                    total.clone()
1051                )
1052            }
1053            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1054                write!(
1055                    f,
1056                    "FlatWeightedSumGeq({},{},{})",
1057                    pretty_vec(cs),
1058                    pretty_vec(vs),
1059                    total.clone()
1060                )
1061            }
1062            Expression::MinionPow(_, atom, atom1, atom2) => {
1063                write!(f, "MinionPow({},{},{})", atom, atom1, atom2)
1064            }
1065            Expression::MinionElementOne(_, atoms, atom, atom1) => {
1066                let atoms = atoms.iter().join(",");
1067                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1068            }
1069        }
1070    }
1071}
1072
1073impl Typeable for Expression {
1074    fn return_type(&self) -> Option<ReturnType> {
1075        match self {
1076            Expression::Union(_, subject, _) => {
1077                Some(ReturnType::Set(Box::new(subject.return_type()?)))
1078            }
1079            Expression::Intersect(_, subject, _) => {
1080                Some(ReturnType::Set(Box::new(subject.return_type()?)))
1081            }
1082            Expression::Supset(_, _, _) => Some(ReturnType::Bool),
1083            Expression::SupsetEq(_, _, _) => Some(ReturnType::Bool),
1084            Expression::Subset(_, _, _) => Some(ReturnType::Bool),
1085            Expression::SubsetEq(_, _, _) => Some(ReturnType::Bool),
1086
1087            // handles sets and matrices, since typeable defined for abstract literals
1088            Expression::AbstractLiteral(_, lit) => lit.return_type(),
1089            Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
1090                Some(subject.return_type()?)
1091            }
1092            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1093                Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
1094            }
1095            Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
1096            Expression::Comprehension(_, _) => None,
1097            Expression::Root(_, _) => Some(ReturnType::Bool),
1098            Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
1099            Expression::FromSolution(_, expr) => expr.return_type(),
1100            // handles integers, booleans and abstract literals all at once, since typeable defined for literal
1101            Expression::Atomic(_, Atom::Literal(lit)) => lit.return_type(),
1102            // TODO - access symbol table to get return type of references - define inside Atom instead
1103            Expression::Atomic(_, Atom::Reference(_)) => None,
1104            Expression::Scope(_, scope) => scope.return_type(),
1105            Expression::Abs(_, _) => Some(ReturnType::Int),
1106            Expression::Sum(_, _) => Some(ReturnType::Int),
1107            Expression::Product(_, _) => Some(ReturnType::Int),
1108            Expression::Min(_, _) => Some(ReturnType::Int),
1109            Expression::Max(_, _) => Some(ReturnType::Int),
1110            Expression::Not(_, _) => Some(ReturnType::Bool),
1111            Expression::Or(_, _) => Some(ReturnType::Bool),
1112            Expression::Imply(_, _, _) => Some(ReturnType::Bool),
1113            Expression::Iff(_, _, _) => Some(ReturnType::Bool),
1114            Expression::And(_, _) => Some(ReturnType::Bool),
1115            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
1116            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
1117            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
1118            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
1119            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
1120            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
1121            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
1122            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
1123            Expression::FlatAllDiff(_, _) => Some(ReturnType::Bool),
1124            Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
1125            Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
1126            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1127            Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
1128            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
1129            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
1130            Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
1131            Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
1132            Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
1133            Expression::MinionWInIntervalSet(_, _, _) => Some(ReturnType::Bool),
1134            Expression::MinionElementOne(_, _, _, _) => Some(ReturnType::Bool),
1135            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
1136            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
1137            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
1138            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1139            Expression::Neg(_, _) => Some(ReturnType::Int),
1140            Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
1141            Expression::SafePow(_, _, _) => Some(ReturnType::Int),
1142            Expression::Minus(_, _, _) => Some(ReturnType::Int),
1143            Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
1144            Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
1145            Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
1146            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
1147            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
1148            Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
1149        }
1150    }
1151}
1152
#[cfg(test)]
mod tests {
    // Unit tests for `Expression::domain_of` on constants, references and sums.
    use std::rc::Rc;

    use crate::{ast::declaration::Declaration, matrix_expr};

    use super::*;

    #[test]
    fn test_domain_of_constant_sum() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2)));
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![c1.clone(), c2.clone()]),
        );
        assert_eq!(
            sum.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(3)]))
        );
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![c1.clone(), c2.clone()]),
        );
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Box::new(matrix_expr![]));
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Single(1)]),
        )))
        .unwrap();
        assert_eq!(
            reference.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        assert_eq!(reference.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Single(1)]),
        )))
        .unwrap();
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![reference.clone(), reference.clone()]),
        );
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(2)]))
        );
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let reference = Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)));
        let mut vars = SymbolTable::new();
        // Unwrap like the sibling tests so a failed insert is reported here, not as a confusing
        // domain mismatch below.
        vars.insert(Rc::new(Declaration::new_var(
            Name::MachineName(0),
            Domain::IntDomain(vec![Range::Bounded(1, 2)]),
        )))
        .unwrap();
        let sum = Expression::Sum(
            Metadata::new(),
            Box::new(matrix_expr![reference.clone(), reference.clone()]),
        );
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]))
        );
    }
}