conjure_cp_core/ast/
expressions.rs

1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use tracing::trace;
4
5use crate::ast::ReturnType;
6use crate::ast::SetAttr;
7use crate::ast::literals::AbstractLiteral;
8use crate::ast::literals::Literal;
9use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
10use crate::ast::{Atom, DomainPtr};
11use crate::ast::{GroundDomain, Metadata, UnresolvedDomain};
12use crate::ast::{IntVal, Moo};
13use crate::ast::{Name, matrix};
14use crate::bug;
15use conjure_cp_enum_compatibility_macro::document_compatibility;
16use itertools::Itertools;
17use serde::{Deserialize, Serialize};
18use ustr::Ustr;
19
20use polyquine::Quine;
21use uniplate::{Biplate, Uniplate};
22
23use super::ac_operators::ACOperatorKind;
24use super::categories::{Category, CategoryOf};
25use super::comprehension::Comprehension;
26use super::domains::HasDomain as _;
27use super::records::RecordValue;
28use super::{DeclarationPtr, Domain, Range, Reference, SubModel, Typeable};
29
30// Ensure that this type doesn't get too big
31//
32// If you triggered this assertion, you either made a variant of this enum that is too big, or you
33// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
34// stuff in boxes.
35//
36// Enums take the size of their largest variant, so an enum with mostly small variants and a few
37// large ones wastes memory... A larger Expression type also slows down Oxide.
38//
39// For more information, and more details on type sizes and how to measure them, see the commit
40// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
41//
42// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
43//
44// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e
45
46// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
47// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
48// lot bigger still when we start using it for memoisation, so it should really be
49// boxed ~niklasdewally
50
51// expect size of Expression to be 112 bytes
52static_assertions::assert_eq_size!([u8; 104], Expression);
53
54/// Represents different types of expressions used to define rules and constraints in the model.
55///
56/// The `Expression` enum includes operations, constants, and variable references
57/// used to build rules and conditions for the model.
58#[document_compatibility]
59#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
60#[biplate(to=Metadata)]
61#[biplate(to=Atom)]
62#[biplate(to=DeclarationPtr)]
63#[biplate(to=Name)]
64#[biplate(to=Reference)]
65#[biplate(to=Vec<Expression>)]
66#[biplate(to=Option<Expression>)]
67#[biplate(to=SubModel)]
68#[biplate(to=Comprehension)]
69#[biplate(to=AbstractLiteral<Expression>)]
70#[biplate(to=AbstractLiteral<Literal>)]
71#[biplate(to=RecordValue<Expression>)]
72#[biplate(to=RecordValue<Literal>)]
73#[biplate(to=Literal)]
74#[biplate(to=DomainPtr)]
75#[path_prefix(conjure_cp::ast)]
76pub enum Expression {
    /// An abstract literal whose elements are expressions.
77    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
78    /// The top of the model
79    Root(Metadata, Vec<Expression>),
80
81    /// An expression representing "A is valid as long as B is true"
82    /// Turns into a conjunction when it reaches a boolean context
83    Bubble(Metadata, Moo<Expression>, Moo<Expression>),
84
85    /// A comprehension.
86    ///
87    /// The inside of the comprehension opens a new scope.
88    // todo (gskorokhod): Comprehension contains a SubModel which contains a bunch of Rc pointers.
89    // This makes implementing Quine tricky (it doesnt support Rc, by design). Skip it for now.
90    #[polyquine_skip]
91    Comprehension(Metadata, Moo<Comprehension>),
92
93    /// Defines dominance ("Solution A is preferred over Solution B")
94    DominanceRelation(Metadata, Moo<Expression>),
95    /// `fromSolution(name)` - Used in dominance relation definitions
96    FromSolution(Metadata, Moo<Atom>),
97
    /// A metavariable, referred to by name.
98    #[polyquine_with(arm = (_, name) => {
99        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
100        quote::quote! { #ident.clone().into() }
101    })]
102    Metavar(Metadata, Ustr),
103
    /// An atomic expression: wraps an [`Atom`].
104    Atomic(Metadata, Atom),
105
106    /// A matrix index.
107    ///
108    /// Defined iff the indices are within their respective index domains.
109    #[compatible(JsonInput)]
110    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
111
112    /// A safe matrix index.
113    ///
114    /// See [`Expression::UnsafeIndex`]
115    #[compatible(SMT)]
116    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
117
118    /// A matrix slice: `a[indices]`.
119    ///
120    /// One of the indices may be `None`, representing the dimension of the matrix we want to take
121    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
122    /// `Some(1),None,Some(2)`.
123    ///
124    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
125    ///
126    /// Defined iff the defined indices are within their respective index domains.
127    #[compatible(JsonInput)]
128    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
129
130    /// A safe matrix slice: `a[indices]`.
131    ///
132    /// See [`Expression::UnsafeSlice`].
133    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
134
135    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
136    ///
137    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
138    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
139    /// `SafeSlice` respectively.
140    InDomain(Metadata, Moo<Expression>, DomainPtr),
141
142    /// `toInt(b)` casts boolean expression b to an integer.
143    ///
144    /// - If b is false, then `toInt(b) == 0`
145    ///
146    /// - If b is true, then `toInt(b) == 1`
147    #[compatible(SMT)]
148    ToInt(Metadata, Moo<Expression>),
149
    /// A nested scope, holding a [`SubModel`].
150    // todo (gskorokhod): Same reason as for Comprehension
151    #[polyquine_skip]
152    Scope(Metadata, Moo<SubModel>),
153
154    /// `|x|` - absolute value of `x`
155    #[compatible(JsonInput, SMT)]
156    Abs(Metadata, Moo<Expression>),
157
158    /// `sum(<vec_expr>)`
159    #[compatible(JsonInput, SMT)]
160    Sum(Metadata, Moo<Expression>),
161
162    /// `a * b * c * ...`
163    #[compatible(JsonInput, SMT)]
164    Product(Metadata, Moo<Expression>),
165
166    /// `min(<vec_expr>)`
167    #[compatible(JsonInput, SMT)]
168    Min(Metadata, Moo<Expression>),
169
170    /// `max(<vec_expr>)`
171    #[compatible(JsonInput, SMT)]
172    Max(Metadata, Moo<Expression>),
173
174    /// `not(a)`
175    #[compatible(JsonInput, SAT, SMT)]
176    Not(Metadata, Moo<Expression>),
177
178    /// `or(<vec_expr>)`
179    #[compatible(JsonInput, SAT, SMT)]
180    Or(Metadata, Moo<Expression>),
181
182    /// `and(<vec_expr>)`
183    #[compatible(JsonInput, SAT, SMT)]
184    And(Metadata, Moo<Expression>),
185
186    /// Ensures that `a->b` (material implication).
187    #[compatible(JsonInput, SMT)]
188    Imply(Metadata, Moo<Expression>, Moo<Expression>),
189
190    /// `iff(a, b)` a <-> b
191    #[compatible(JsonInput, SMT)]
192    Iff(Metadata, Moo<Expression>, Moo<Expression>),
193
    /// Set union of two expressions.
194    #[compatible(JsonInput)]
195    Union(Metadata, Moo<Expression>, Moo<Expression>),
196
    /// Set membership test.
197    #[compatible(JsonInput)]
198    In(Metadata, Moo<Expression>, Moo<Expression>),
199
    /// Set intersection of two expressions.
200    #[compatible(JsonInput)]
201    Intersect(Metadata, Moo<Expression>, Moo<Expression>),
202
    /// Strict superset test.
203    #[compatible(JsonInput)]
204    Supset(Metadata, Moo<Expression>, Moo<Expression>),
205
    /// Superset-or-equal test.
206    #[compatible(JsonInput)]
207    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),
208
    /// Strict subset test.
209    #[compatible(JsonInput)]
210    Subset(Metadata, Moo<Expression>, Moo<Expression>),
211
    /// Subset-or-equal test.
212    #[compatible(JsonInput)]
213    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),
214
    /// Equality: `a = b`
215    #[compatible(JsonInput, SMT)]
216    Eq(Metadata, Moo<Expression>, Moo<Expression>),
217
    /// Disequality: `a != b`
218    #[compatible(JsonInput, SMT)]
219    Neq(Metadata, Moo<Expression>, Moo<Expression>),
220
    /// `a >= b`
221    #[compatible(JsonInput, SMT)]
222    Geq(Metadata, Moo<Expression>, Moo<Expression>),
223
    /// `a <= b`
224    #[compatible(JsonInput, SMT)]
225    Leq(Metadata, Moo<Expression>, Moo<Expression>),
226
    /// `a > b`
227    #[compatible(JsonInput, SMT)]
228    Gt(Metadata, Moo<Expression>, Moo<Expression>),
229
    /// `a < b`
230    #[compatible(JsonInput, SMT)]
231    Lt(Metadata, Moo<Expression>, Moo<Expression>),
232
233    /// Division after preventing division by zero, usually with a bubble
234    #[compatible(SMT)]
235    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
236
237    /// Division with a possibly undefined value (division by 0)
238    #[compatible(JsonInput)]
239    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
240
241    /// Modulo after preventing mod 0, usually with a bubble
242    #[compatible(SMT)]
243    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),
244
245    /// Modulo with a possibly undefined value (mod 0)
246    #[compatible(JsonInput)]
247    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),
248
249    /// Negation: `-x`
250    #[compatible(JsonInput, SMT)]
251    Neg(Metadata, Moo<Expression>),
252
253    /// Set of domain values function is defined for
254    #[compatible(JsonInput)]
255    Defined(Metadata, Moo<Expression>),
256
257    /// Set of codomain values function is defined for
258    #[compatible(JsonInput)]
259    Range(Metadata, Moo<Expression>),
260
261    /// Unsafe power `x**y` (possibly undefined)
262    ///
263    /// Defined when `(x != 0 \/ y != 0) /\ y >= 0`.
264    #[compatible(JsonInput)]
265    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),
266
267    /// `UnsafePow` after preventing undefinedness
268    SafePow(Metadata, Moo<Expression>, Moo<Expression>),
269
270    /// Flatten matrix operator
271    /// `flatten(M)` or `flatten(n, M)`
272    /// where M is a matrix and n is an optional integer argument indicating depth of flattening
273    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),
274
275    /// `allDiff(<vec_expr>)`
276    #[compatible(JsonInput)]
277    AllDiff(Metadata, Moo<Expression>),
278
279    /// Binary subtraction operator
280    ///
281    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
282    /// TODO: make this compatible with Set Difference calculations - need to change return type and domain for this expression and write a set comprehension rule.
283    /// have already edited minus_to_sum to prevent this from applying to sets
284    #[compatible(JsonInput)]
285    Minus(Metadata, Moo<Expression>, Moo<Expression>),
286
287    /// Ensures that x=|y| i.e. x is the absolute value of y.
288    ///
289    /// Low-level Minion constraint.
290    ///
291    /// # See also
292    ///
293    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
294    #[compatible(Minion)]
295    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),
296
297    /// Ensures that `alldiff([a,b,...])`.
298    ///
299    /// Low-level Minion constraint.
300    ///
301    /// # See also
302    ///
303    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
304    #[compatible(Minion)]
305    FlatAllDiff(Metadata, Vec<Atom>),
306
307    /// Ensures that sum(vec) >= x.
308    ///
309    /// Low-level Minion constraint.
310    ///
311    /// # See also
312    ///
313    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
314    #[compatible(Minion)]
315    FlatSumGeq(Metadata, Vec<Atom>, Atom),
316
317    /// Ensures that sum(vec) <= x.
318    ///
319    /// Low-level Minion constraint.
320    ///
321    /// # See also
322    ///
323    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
324    #[compatible(Minion)]
325    FlatSumLeq(Metadata, Vec<Atom>, Atom),
326
327    /// `ineq(x,y,k)` ensures that x <= y + k.
328    ///
329    /// Low-level Minion constraint.
330    ///
331    /// # See also
332    ///
333    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
334    #[compatible(Minion)]
335    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),
336
337    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
338    ///
339    /// Low-level Minion constraint.
340    ///
341    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
342    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
343    /// watched-and and watched-or.
344    ///
345    /// # See also
346    ///
347    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-literal)
348    /// + `rules::minion::boolean_literal_to_wliteral`.
349    #[compatible(Minion)]
350    #[polyquine_skip]
351    FlatWatchedLiteral(Metadata, Reference, Literal),
352
353    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
354    /// product of cs and xs.
355    ///
356    /// Low-level Minion constraint.
357    ///
358    /// Represents a weighted sum of the form `ax + by + cz + ...`
359    ///
360    /// # See also
361    ///
362    /// + [Minion
363    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
364    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
365
366    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
367    /// product of cs and xs.
368    ///
369    /// Low-level Minion constraint.
370    ///
371    /// Represents a weighted sum of the form `ax + by + cz + ...`
372    ///
373    /// # See also
374    ///
375    /// + [Minion
376    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumgeq)
377    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
378
379    /// Ensures that x = -y, where x and y are atoms.
380    ///
381    /// Low-level Minion constraint.
382    ///
383    /// # See also
384    ///
385    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
386    #[compatible(Minion)]
387    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),
388
389    /// Ensures that x*y=z.
390    ///
391    /// Low-level Minion constraint.
392    ///
393    /// # See also
394    ///
395    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
396    #[compatible(Minion)]
397    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
398
399    /// Ensures that floor(x/y)=z. Always true when y=0.
400    ///
401    /// Low-level Minion constraint.
402    ///
403    /// # See also
404    ///
405    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
406    #[compatible(Minion)]
407    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
408
409    /// Ensures that x%y=z. Always true when y=0.
410    ///
411    /// Low-level Minion constraint.
412    ///
413    /// # See also
414    ///
415    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
416    #[compatible(Minion)]
417    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
418
419    /// Ensures that `x**y = z`.
420    ///
421    /// Low-level Minion constraint.
422    ///
423    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
424    /// is even and z is -1 if y is odd).
425    ///
426    /// # See also
427    ///
428    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
429    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
430    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
431
432    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
433    /// variable.
434    ///
435    /// Low-level Minion constraint.
436    ///
437    /// # See also
438    ///
439    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
440    #[compatible(Minion)]
441    MinionReify(Metadata, Moo<Expression>, Atom),
442
443    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1
444    /// variable.
445    ///
446    /// Low-level Minion constraint.
447    ///
448    /// # See also
449    ///
450    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
451    #[compatible(Minion)]
452    MinionReifyImply(Metadata, Moo<Expression>, Atom),
453
454    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
455    /// intervals {a1,…,a2}, {b1,…,b2} etc.
456    ///
457    /// The list of intervals must be given in numerical order.
458    ///
459    /// Low-level Minion constraint.
460    ///
461    /// # See also
462    ///
463    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
464    #[compatible(Minion)]
465    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
466
467    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
468    ///
469    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
470    ///
471    /// The list of values must be given in numerical order.
472    ///
473    /// Low-level Minion constraint.
474    ///
475    /// # See also
476    ///
477    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
478    #[compatible(Minion)]
479    MinionWInSet(Metadata, Atom, Vec<i32>),
480
481    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
482    /// in the range `[1..len(vec)]`.
483    ///
484    /// Low-level Minion constraint.
485    ///
486    /// # See also
487    ///
488    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
489    #[compatible(Minion)]
490    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),
491
492    /// Declaration of an auxiliary variable.
493    ///
494    /// As with Savile Row, we semantically distinguish this from `Eq`.
495    #[compatible(Minion)]
496    #[polyquine_skip]
497    AuxDeclaration(Metadata, Reference, Moo<Expression>),
498
499    /// This expression is for encoding i32 ints as a vector of boolean expressions for cnf - using 2s complement
500    #[compatible(SAT)]
501    SATInt(Metadata, Moo<Expression>),
502
503    /// Addition over a pair of expressions (i.e. a + b) rather than a vec-expr like Expression::Sum.
504    /// This is for compatibility with backends that do not support addition over vectors.
505    #[compatible(SMT)]
506    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),
507
508    /// Multiplication over a pair of expressions (i.e. a * b) rather than a vec-expr like Expression::Product.
509    /// This is for compatibility with backends that do not support multiplication over vectors.
510    #[compatible(SMT)]
511    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),
512
513    #[compatible(JsonInput)]
514    Image(Metadata, Moo<Expression>, Moo<Expression>),
515
516    #[compatible(JsonInput)]
517    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),
518
519    #[compatible(JsonInput)]
520    PreImage(Metadata, Moo<Expression>, Moo<Expression>),
521
522    #[compatible(JsonInput)]
523    Inverse(Metadata, Moo<Expression>, Moo<Expression>),
524
525    #[compatible(JsonInput)]
526    Restrict(Metadata, Moo<Expression>, Moo<Expression>),
527
528    /// Lexicographical < between two matrices.
529    ///
530    /// A <lex B iff: A[i] < B[i] for some i /\ (A[j] > B[j] for some j -> i < j)
531    /// I.e. A must be less than B at some index i, and if it is greater than B at another index j,
532    /// then j comes after i.
533    /// I.e. A must be less than B at the first index where they differ.
534    ///
535    /// E.g. [1, 1] <lex [2, 1] and [1, 1] <lex [1, 2]
536    LexLt(Metadata, Moo<Expression>, Moo<Expression>),
537
538    /// Lexicographical <= between two matrices
539    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),
540
541    /// Lexicographical > between two matrices
542    /// This is a parser-level construct, and is immediately normalised to LexLt(b, a)
543    LexGt(Metadata, Moo<Expression>, Moo<Expression>),
544
545    /// Lexicographical >= between two matrices
546    /// This is a parser-level construct, and is immediately normalised to LexLeq(b, a)
547    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),
548
549    /// Low-level minion constraint. See Expression::LexLt
550    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),
551
552    /// Low-level minion constraint. See Expression::LexLeq
553    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
554}
555
556// for the given matrix literal, return a bounded domain from the min to max of applying op to each
557// child expression.
558//
559// Op must be monotonic.
560//
561// Returns none if unbounded
562fn bounded_i32_domain_for_matrix_literal_monotonic(
563    e: &Expression,
564    op: fn(i32, i32) -> Option<i32>,
565) -> Option<DomainPtr> {
566    // only care about the elements, not the indices
567    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
568
569    // fold each element's domain into one using op.
570    //
571    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
572    // the ranges [a1,a2], [b1,b2] will be
573    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
574    //
575    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
576    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
577    // memory (on the hakank_eprime_xkcd test)...
578    //Int
579    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
580    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
581    //
582    // +,-,/,* are all monotone, so this assumption should be fine for now...
583
584    let expr = exprs.pop()?;
585    let dom = expr.domain_of()?;
586    let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
587        return None;
588    };
589
590    let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
591
592    for expr in exprs {
593        let dom = expr.domain_of()?;
594        let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
595            return None;
596        };
597
598        let (min, max) = range_vec_bounds_i32(ranges)?;
599
600        // all the possible new values for current_min / current_max
601        let minmax = op(min, current_max)?;
602        let minmin = op(min, current_min)?;
603        let maxmin = op(max, current_min)?;
604        let maxmax = op(max, current_max)?;
605        let vals = [minmax, minmin, maxmin, maxmax];
606
607        current_min = *vals
608            .iter()
609            .min()
610            .expect("vals iterator should not be empty, and should have a minimum.");
611        current_max = *vals
612            .iter()
613            .max()
614            .expect("vals iterator should not be empty, and should have a maximum.");
615    }
616
617    if current_min == current_max {
618        Some(Domain::int(vec![Range::Single(current_min)]))
619    } else {
620        Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
621    }
622}
623
624// Returns none if unbounded
625fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
626    let mut min = i32::MAX;
627    let mut max = i32::MIN;
628    for r in ranges {
629        match r {
630            Range::Single(i) => {
631                if *i < min {
632                    min = *i;
633                }
634                if *i > max {
635                    max = *i;
636                }
637            }
638            Range::Bounded(i, j) => {
639                if *i < min {
640                    min = *i;
641                }
642                if *j > max {
643                    max = *j;
644                }
645            }
646            Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
647        }
648    }
649    Some((min, max))
650}
651
652impl Expression {
653    /// Returns the possible values of the expression, recursing to leaf expressions
654    pub fn domain_of(&self) -> Option<DomainPtr> {
655        //println!("domain_of {self}");
656        let ret = match self {
657            Expression::Union(_, a, b) => Some(Domain::set(
658                SetAttr::<IntVal>::default(),
659                a.domain_of()?.union(&b.domain_of()?).ok()?,
660            )),
661            Expression::Intersect(_, a, b) => Some(Domain::set(
662                SetAttr::<IntVal>::default(),
663                a.domain_of()?.intersect(&b.domain_of()?).ok()?,
664            )),
665            Expression::In(_, _, _) => Some(Domain::bool()),
666            Expression::Supset(_, _, _) => Some(Domain::bool()),
667            Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
668            Expression::Subset(_, _, _) => Some(Domain::bool()),
669            Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
670            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
671            Expression::DominanceRelation(_, _) => Some(Domain::bool()),
672            Expression::FromSolution(_, expr) => Some(expr.domain_of()),
673            Expression::Metavar(_, _) => None,
674            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
675            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
676                let dom = matrix.domain_of()?;
677                if let Some((elem_domain, _)) = dom.as_matrix() {
678                    return Some(elem_domain);
679                }
680
681                // may actually use the value in the future
682                #[allow(clippy::redundant_pattern_matching)]
683                if let Some(_) = dom.as_tuple() {
684                    // TODO: We can implement proper indexing for tuples
685                    return None;
686                }
687
688                // may actually use the value in the future
689                #[allow(clippy::redundant_pattern_matching)]
690                if let Some(_) = dom.as_record() {
691                    // TODO: We can implement proper indexing for records
692                    return None;
693                }
694
695                bug!("subject of an index operation should support indexing")
696            }
697            Expression::UnsafeSlice(_, matrix, indices)
698            | Expression::SafeSlice(_, matrix, indices) => {
699                let sliced_dimension = indices.iter().position(Option::is_none);
700
701                let dom = matrix.domain_of()?;
702                let Some((elem_domain, index_domains)) = dom.as_matrix() else {
703                    bug!("subject of an index operation should be a matrix");
704                };
705
706                match sliced_dimension {
707                    Some(dimension) => Some(Domain::matrix(
708                        elem_domain,
709                        vec![index_domains[dimension].clone()],
710                    )),
711
712                    // same as index
713                    None => Some(elem_domain),
714                }
715            }
716            Expression::InDomain(_, _, _) => Some(Domain::bool()),
717            Expression::Atomic(_, atom) => Some(atom.domain_of()),
718            Expression::Scope(_, _) => Some(Domain::bool()),
719            Expression::Sum(_, e) => {
720                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
721            }
722            Expression::Product(_, e) => {
723                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
724            }
725            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
726                Some(if x < y { x } else { y })
727            }),
728            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
729                Some(if x > y { x } else { y })
730            }),
731            Expression::UnsafeDiv(_, a, b) => a
732                .domain_of()?
733                .resolve()?
734                .apply_i32(
735                    // rust integer division is truncating; however, we want to always round down,
736                    // including for negative numbers.
737                    |x, y| {
738                        if y != 0 {
739                            Some((x as f32 / y as f32).floor() as i32)
740                        } else {
741                            None
742                        }
743                    },
744                    b.domain_of()?.resolve()?.as_ref(),
745                )
746                .map(DomainPtr::from)
747                .ok(),
748            Expression::SafeDiv(_, a, b) => {
749                // rust integer division is truncating; however, we want to always round down
750                // including for negative numbers.
751                let domain = a
752                    .domain_of()?
753                    .resolve()?
754                    .apply_i32(
755                        |x, y| {
756                            if y != 0 {
757                                Some((x as f32 / y as f32).floor() as i32)
758                            } else {
759                                None
760                            }
761                        },
762                        b.domain_of()?.resolve()?.as_ref(),
763                    )
764                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
765
766                if let GroundDomain::Int(ranges) = domain {
767                    let mut ranges = ranges;
768                    ranges.push(Range::Single(0));
769                    return Some(Domain::int(ranges));
770                } else {
771                    bug!("Domain of {self} was not integer")
772                }
773            }
774            Expression::UnsafeMod(_, a, b) => a
775                .domain_of()?
776                .resolve()?
777                .apply_i32(
778                    |x, y| if y != 0 { Some(x % y) } else { None },
779                    b.domain_of()?.resolve()?.as_ref(),
780                )
781                .map(DomainPtr::from)
782                .ok(),
783            Expression::SafeMod(_, a, b) => {
784                let domain = a
785                    .domain_of()?
786                    .resolve()?
787                    .apply_i32(
788                        |x, y| if y != 0 { Some(x % y) } else { None },
789                        b.domain_of()?.resolve()?.as_ref(),
790                    )
791                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
792
793                if let GroundDomain::Int(ranges) = domain {
794                    let mut ranges = ranges;
795                    ranges.push(Range::Single(0));
796                    return Some(Domain::int(ranges));
797                } else {
798                    bug!("Domain of {self} was not integer")
799                }
800            }
801            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
802                .domain_of()?
803                .resolve()?
804                .apply_i32(
805                    |x, y| {
806                        if (x != 0 || y != 0) && y >= 0 {
807                            Some(x.pow(y as u32))
808                        } else {
809                            None
810                        }
811                    },
812                    b.domain_of()?.resolve()?.as_ref(),
813                )
814                .map(DomainPtr::from)
815                .ok(),
816            Expression::Root(_, _) => None,
817            Expression::Bubble(_, inner, _) => inner.domain_of(),
818            Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
819            Expression::And(_, _) => Some(Domain::bool()),
820            Expression::Not(_, _) => Some(Domain::bool()),
821            Expression::Or(_, _) => Some(Domain::bool()),
822            Expression::Imply(_, _, _) => Some(Domain::bool()),
823            Expression::Iff(_, _, _) => Some(Domain::bool()),
824            Expression::Eq(_, _, _) => Some(Domain::bool()),
825            Expression::Neq(_, _, _) => Some(Domain::bool()),
826            Expression::Geq(_, _, _) => Some(Domain::bool()),
827            Expression::Leq(_, _, _) => Some(Domain::bool()),
828            Expression::Gt(_, _, _) => Some(Domain::bool()),
829            Expression::Lt(_, _, _) => Some(Domain::bool()),
830            Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
831            Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
832            Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
833            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
834            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
835            Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
836            Expression::Flatten(_, n, m) => {
837                if let Some(expr) = n {
838                    if expr.return_type() == ReturnType::Int {
839                        // TODO: handle flatten with depth argument
840                        return None;
841                    }
842                } else {
843                    // TODO: currently only works for matrices
844                    let dom = m.domain_of()?.resolve()?;
845                    let (val_dom, idx_doms) = match dom.as_ref() {
846                        GroundDomain::Matrix(val, idx) => (val, idx),
847                        _ => return None,
848                    };
849                    let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
850
851                    let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
852                    return Some(Domain::matrix(
853                        val_dom.clone().into(),
854                        vec![new_index_domain],
855                    ));
856                }
857                None
858            }
859            Expression::AllDiff(_, _) => Some(Domain::bool()),
860            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
861            Expression::MinionReify(_, _, _) => Some(Domain::bool()),
862            Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
863            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
864            Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
865            Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
866            Expression::Neg(_, x) => {
867                let dom = x.domain_of()?;
868                let mut ranges = dom.as_int()?;
869
870                ranges = ranges
871                    .into_iter()
872                    .map(|r| match r {
873                        Range::Single(x) => Range::Single(-x),
874                        Range::Bounded(x, y) => Range::Bounded(-y, -x),
875                        Range::UnboundedR(i) => Range::UnboundedL(-i),
876                        Range::UnboundedL(i) => Range::UnboundedR(-i),
877                        Range::Unbounded => Range::Unbounded,
878                    })
879                    .collect();
880
881                Some(Domain::int(ranges))
882            }
883            Expression::Minus(_, a, b) => a
884                .domain_of()?
885                .resolve()?
886                .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
887                .map(DomainPtr::from)
888                .ok(),
889            Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
890            Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
891            Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
892            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
893            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
894            Expression::Abs(_, a) => a
895                .domain_of()?
896                .resolve()?
897                .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
898                .map(DomainPtr::from)
899                .ok(),
900            Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
901            Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
902            Expression::SATInt(_, _) => {
903                Some(Domain::int_ground(vec![Range::Bounded(
904                    i8::MIN.into(),
905                    i8::MAX.into(),
906                )])) // BITS
907            } // A CnfInt can represent any i8 integer at the moment
908            // A CnfInt contains multiple boolean expressions and represents the integer
909            // formed when these booleans are treated as the bits in an integer encoding.
910            // So the 'domain of' should be an integer
911            Expression::PairwiseSum(_, a, b) => a
912                .domain_of()?
913                .resolve()?
914                .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
915                .map(DomainPtr::from)
916                .ok(),
917            Expression::PairwiseProduct(_, a, b) => a
918                .domain_of()?
919                .resolve()?
920                .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
921                .map(DomainPtr::from)
922                .ok(),
923            Expression::Defined(_, function) => get_function_domain(function),
924            Expression::Range(_, function) => get_function_codomain(function),
925            Expression::Image(_, function, _) => get_function_codomain(function),
926            Expression::ImageSet(_, function, _) => get_function_codomain(function),
927            Expression::PreImage(_, function, _) => get_function_domain(function),
928            Expression::Restrict(_, function, new_domain) => {
929                let function_domain = function.domain_of()?;
930                match function_domain.resolve().as_ref() {
931                    Some(d) => {
932                        match d.as_ref() {
933                            GroundDomain::Function(attrs, _, codomain) => Some(Domain::function(
934                                attrs.clone(),
935                                new_domain.domain_of()?,
936                                codomain.clone().into(),
937                            )),
938                            // Not defined for anything other than a function
939                            _ => None,
940                        }
941                    }
942                    None => {
943                        match function_domain.as_unresolved()? {
944                            UnresolvedDomain::Function(attrs, _, codomain) => {
945                                Some(Domain::function(
946                                    attrs.clone(),
947                                    new_domain.domain_of()?,
948                                    codomain.clone(),
949                                ))
950                            }
951                            // Not defined for anything other than a function
952                            _ => None,
953                        }
954                    }
955                }
956            }
957            Expression::Inverse(..) => Some(Domain::bool()),
958            Expression::LexLt(..) => Some(Domain::bool()),
959            Expression::LexLeq(..) => Some(Domain::bool()),
960            Expression::LexGt(..) => Some(Domain::bool()),
961            Expression::LexGeq(..) => Some(Domain::bool()),
962            Expression::FlatLexLt(..) => Some(Domain::bool()),
963            Expression::FlatLexLeq(..) => Some(Domain::bool()),
964        };
965        if let Some(dom) = &ret
966            && let Some(ranges) = dom.as_int_ground()
967            && ranges.len() > 1
968        {
969            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
970            // Once they support a full domain as we define it, we can remove this conversion
971            let (min, max) = range_vec_bounds_i32(ranges)?;
972            return Some(Domain::int(vec![Range::Bounded(min, max)]));
973        }
974        ret
975    }
976
977    pub fn get_meta(&self) -> Metadata {
978        let metas: VecDeque<Metadata> = self.children_bi();
979        metas[0].clone()
980    }
981
982    pub fn set_meta(&self, meta: Metadata) {
983        self.transform_bi(&|_| meta.clone());
984    }
985
986    /// Checks whether this expression is safe.
987    ///
988    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
989    ///
990    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
991    /// safe through the use of bubble rules.
992    pub fn is_safe(&self) -> bool {
993        // TODO: memoise in Metadata
994        for expr in self.universe() {
995            match expr {
996                Expression::UnsafeDiv(_, _, _)
997                | Expression::UnsafeMod(_, _, _)
998                | Expression::UnsafePow(_, _, _)
999                | Expression::UnsafeIndex(_, _, _)
1000                | Expression::Bubble(_, _, _)
1001                | Expression::UnsafeSlice(_, _, _) => {
1002                    return false;
1003                }
1004                _ => {}
1005            }
1006        }
1007        true
1008    }
1009
1010    pub fn is_clean(&self) -> bool {
1011        let metadata = self.get_meta();
1012        metadata.clean
1013    }
1014
1015    pub fn set_clean(&mut self, bool_value: bool) {
1016        let mut metadata = self.get_meta();
1017        metadata.clean = bool_value;
1018        self.set_meta(metadata);
1019    }
1020
1021    /// True if the expression is an associative and commutative operator
1022    pub fn is_associative_commutative_operator(&self) -> bool {
1023        TryInto::<ACOperatorKind>::try_into(self).is_ok()
1024    }
1025
1026    /// True if the expression is a matrix literal.
1027    ///
1028    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
1029    /// [`Expression`].
1030    pub fn is_matrix_literal(&self) -> bool {
1031        matches!(
1032            self,
1033            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1034                | Expression::Atomic(
1035                    _,
1036                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1037                )
1038        )
1039    }
1040
1041    /// True iff self and other are both atomic and identical.
1042    ///
1043    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
1044    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
1045    /// entire subtrees of `self` and `other`.
1046    pub fn identical_atom_to(&self, other: &Expression) -> bool {
1047        let atom1: Result<&Atom, _> = self.try_into();
1048        let atom2: Result<&Atom, _> = other.try_into();
1049
1050        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1051            atom2 == atom1
1052        } else {
1053            false
1054        }
1055    }
1056
1057    /// If the expression is a list, returns the inner expressions.
1058    ///
1059    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
1060    /// any explicitly specified domain.
1061    pub fn unwrap_list(self) -> Option<Vec<Expression>> {
1062        match self {
1063            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1064                matrix.unwrap_list().cloned()
1065            }
1066            Expression::Atomic(
1067                _,
1068                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1069            ) => matrix.unwrap_list().map(|elems| {
1070                elems
1071                    .clone()
1072                    .into_iter()
1073                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1074                    .collect_vec()
1075            }),
1076            _ => None,
1077        }
1078    }
1079
1080    /// If the expression is a matrix, gets it elements and index domain.
1081    ///
1082    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
1083    ///
1084    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
1085    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
1086    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
1087    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1088        match self {
1089            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1090                Some((elems, domain))
1091            }
1092            Expression::Atomic(
1093                _,
1094                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1095            ) => Some((
1096                elems
1097                    .into_iter()
1098                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1099                    .collect_vec(),
1100                domain.into(),
1101            )),
1102
1103            _ => None,
1104        }
1105    }
1106
1107    /// For a Root expression, extends the inner vec with the given vec.
1108    ///
1109    /// # Panics
1110    /// Panics if the expression is not Root.
1111    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1112        match self {
1113            Expression::Root(meta, mut children) => {
1114                children.extend(exprs);
1115                Expression::Root(meta, children)
1116            }
1117            _ => panic!("extend_root called on a non-Root expression"),
1118        }
1119    }
1120
1121    /// Converts the expression to a literal, if possible.
1122    pub fn into_literal(self) -> Option<Literal> {
1123        match self {
1124            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1125            Expression::AbstractLiteral(_, abslit) => {
1126                Some(Literal::AbstractLiteral(abslit.into_literals()?))
1127            }
1128            Expression::Neg(_, e) => {
1129                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1130                    bug!("negated literal should be an int");
1131                };
1132
1133                Some(Literal::Int(-i))
1134            }
1135
1136            _ => None,
1137        }
1138    }
1139
1140    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
1141    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1142        TryFrom::try_from(self).ok()
1143    }
1144
1145    /// Returns the categories of all sub-expressions of self.
1146    pub fn universe_categories(&self) -> HashSet<Category> {
1147        self.universe()
1148            .into_iter()
1149            .map(|x| x.category_of())
1150            .collect()
1151    }
1152}
1153
1154pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1155    let function_domain = function.domain_of()?;
1156    match function_domain.resolve().as_ref() {
1157        Some(d) => {
1158            match d.as_ref() {
1159                GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1160                // Not defined for anything other than a function
1161                _ => None,
1162            }
1163        }
1164        None => {
1165            match function_domain.as_unresolved()? {
1166                UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1167                // Not defined for anything other than a function
1168                _ => None,
1169            }
1170        }
1171    }
1172}
1173
1174pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1175    let function_domain = function.domain_of()?;
1176    match function_domain.resolve().as_ref() {
1177        Some(d) => {
1178            match d.as_ref() {
1179                GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1180                // Not defined for anything other than a function
1181                _ => None,
1182            }
1183        }
1184        None => {
1185            match function_domain.as_unresolved()? {
1186                UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1187                // Not defined for anything other than a function
1188                _ => None,
1189            }
1190        }
1191    }
1192}
1193
1194impl TryFrom<&Expression> for i32 {
1195    type Error = ();
1196
1197    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1198        let Expression::Atomic(_, atom) = value else {
1199            return Err(());
1200        };
1201
1202        let Atom::Literal(lit) = atom else {
1203            return Err(());
1204        };
1205
1206        let Literal::Int(i) = lit else {
1207            return Err(());
1208        };
1209
1210        Ok(*i)
1211    }
1212}
1213
1214impl TryFrom<Expression> for i32 {
1215    type Error = ();
1216
1217    fn try_from(value: Expression) -> Result<Self, Self::Error> {
1218        TryFrom::<&Expression>::try_from(&value)
1219    }
1220}
1221impl From<i32> for Expression {
1222    fn from(i: i32) -> Self {
1223        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1224    }
1225}
1226
1227impl From<bool> for Expression {
1228    fn from(b: bool) -> Self {
1229        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1230    }
1231}
1232
1233impl From<Atom> for Expression {
1234    fn from(value: Atom) -> Self {
1235        Expression::Atomic(Metadata::new(), value)
1236    }
1237}
1238
1239impl From<Literal> for Expression {
1240    fn from(value: Literal) -> Self {
1241        Expression::Atomic(Metadata::new(), value.into())
1242    }
1243}
1244
1245impl From<Moo<Expression>> for Expression {
1246    fn from(val: Moo<Expression>) -> Self {
1247        val.as_ref().clone()
1248    }
1249}
1250
1251impl CategoryOf for Expression {
1252    fn category_of(&self) -> Category {
1253        // take highest category of all the expressions children
1254        let category = self.cata(&move |x,children| {
1255
1256            if let Some(max_category) = children.iter().max() {
1257                // if this expression contains subexpressions, return the maximum category of the
1258                // subexpressions
1259                *max_category
1260            } else {
1261                // this expression has no children
1262                let mut max_category = Category::Bottom;
1263
1264                // calculate the category by looking at all atoms, submodels, comprehensions, and
1265                // declarationptrs inside this expression
1266
1267                // this should generically cover all leaf types we currently have in oxide.
1268
1269                // if x contains submodels (including comprehensions)
1270                if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
1271                    // assume that the category is decision
1272                    return Category::Decision;
1273                }
1274
1275                // if x contains atoms
1276                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1277                // and those atoms have a higher category than we already know about
1278                && max_atom_category > max_category{
1279                    // update category 
1280                    max_category = max_atom_category;
1281                }
1282
1283                // if x contains declarationPtrs
1284                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1285                // and those pointers have a higher category than we already know about
1286                && max_declaration_category > max_category{
1287                    // update category 
1288                    max_category = max_declaration_category;
1289                }
1290                max_category
1291
1292            }
1293        });
1294
1295        if cfg!(debug_assertions) {
1296            trace!(
1297                category= %category,
1298                expression= %self,
1299                "Called Expression::category_of()"
1300            );
1301        };
1302        category
1303    }
1304}
1305
1306impl Display for Expression {
1307    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1308        match &self {
1309            Expression::Union(_, box1, box2) => {
1310                write!(f, "({} union {})", box1.clone(), box2.clone())
1311            }
1312            Expression::In(_, e1, e2) => {
1313                write!(f, "{e1} in {e2}")
1314            }
1315            Expression::Intersect(_, box1, box2) => {
1316                write!(f, "({} intersect {})", box1.clone(), box2.clone())
1317            }
1318            Expression::Supset(_, box1, box2) => {
1319                write!(f, "({} supset {})", box1.clone(), box2.clone())
1320            }
1321            Expression::SupsetEq(_, box1, box2) => {
1322                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1323            }
1324            Expression::Subset(_, box1, box2) => {
1325                write!(f, "({} subset {})", box1.clone(), box2.clone())
1326            }
1327            Expression::SubsetEq(_, box1, box2) => {
1328                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1329            }
1330
1331            Expression::AbstractLiteral(_, l) => l.fmt(f),
1332            Expression::Comprehension(_, c) => c.fmt(f),
1333            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1334                write!(f, "{e1}{}", pretty_vec(e2))
1335            }
1336            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1337                let args = es
1338                    .iter()
1339                    .map(|x| match x {
1340                        Some(x) => format!("{x}"),
1341                        None => "..".into(),
1342                    })
1343                    .join(",");
1344
1345                write!(f, "{e1}[{args}]")
1346            }
1347            Expression::InDomain(_, e, domain) => {
1348                write!(f, "__inDomain({e},{domain})")
1349            }
1350            Expression::Root(_, exprs) => {
1351                write!(f, "{}", pretty_expressions_as_top_level(exprs))
1352            }
1353            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1354            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1355            Expression::Metavar(_, name) => write!(f, "&{name}"),
1356            Expression::Atomic(_, atom) => atom.fmt(f),
1357            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1358            Expression::Abs(_, a) => write!(f, "|{a}|"),
1359            Expression::Sum(_, e) => {
1360                write!(f, "sum({e})")
1361            }
1362            Expression::Product(_, e) => {
1363                write!(f, "product({e})")
1364            }
1365            Expression::Min(_, e) => {
1366                write!(f, "min({e})")
1367            }
1368            Expression::Max(_, e) => {
1369                write!(f, "max({e})")
1370            }
1371            Expression::Not(_, expr_box) => {
1372                write!(f, "!({})", expr_box.clone())
1373            }
1374            Expression::Or(_, e) => {
1375                write!(f, "or({e})")
1376            }
1377            Expression::And(_, e) => {
1378                write!(f, "and({e})")
1379            }
1380            Expression::Imply(_, box1, box2) => {
1381                write!(f, "({box1}) -> ({box2})")
1382            }
1383            Expression::Iff(_, box1, box2) => {
1384                write!(f, "({box1}) <-> ({box2})")
1385            }
1386            Expression::Eq(_, box1, box2) => {
1387                write!(f, "({} = {})", box1.clone(), box2.clone())
1388            }
1389            Expression::Neq(_, box1, box2) => {
1390                write!(f, "({} != {})", box1.clone(), box2.clone())
1391            }
1392            Expression::Geq(_, box1, box2) => {
1393                write!(f, "({} >= {})", box1.clone(), box2.clone())
1394            }
1395            Expression::Leq(_, box1, box2) => {
1396                write!(f, "({} <= {})", box1.clone(), box2.clone())
1397            }
1398            Expression::Gt(_, box1, box2) => {
1399                write!(f, "({} > {})", box1.clone(), box2.clone())
1400            }
1401            Expression::Lt(_, box1, box2) => {
1402                write!(f, "({} < {})", box1.clone(), box2.clone())
1403            }
1404            Expression::FlatSumGeq(_, box1, box2) => {
1405                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1406            }
1407            Expression::FlatSumLeq(_, box1, box2) => {
1408                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1409            }
1410            Expression::FlatIneq(_, box1, box2, box3) => write!(
1411                f,
1412                "Ineq({}, {}, {})",
1413                box1.clone(),
1414                box2.clone(),
1415                box3.clone()
1416            ),
1417            Expression::Flatten(_, n, m) => {
1418                if let Some(n) = n {
1419                    write!(f, "flatten({n}, {m})")
1420                } else {
1421                    write!(f, "flatten({m})")
1422                }
1423            }
1424            Expression::AllDiff(_, e) => {
1425                write!(f, "allDiff({e})")
1426            }
1427            Expression::Bubble(_, box1, box2) => {
1428                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1429            }
1430            Expression::SafeDiv(_, box1, box2) => {
1431                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1432            }
1433            Expression::UnsafeDiv(_, box1, box2) => {
1434                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1435            }
1436            Expression::UnsafePow(_, box1, box2) => {
1437                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1438            }
1439            Expression::SafePow(_, box1, box2) => {
1440                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1441            }
1442            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1443                write!(
1444                    f,
1445                    "DivEq({}, {}, {})",
1446                    box1.clone(),
1447                    box2.clone(),
1448                    box3.clone()
1449                )
1450            }
1451            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1452                write!(
1453                    f,
1454                    "ModEq({}, {}, {})",
1455                    box1.clone(),
1456                    box2.clone(),
1457                    box3.clone()
1458                )
1459            }
1460            Expression::FlatWatchedLiteral(_, x, l) => {
1461                write!(f, "WatchedLiteral({x},{l})")
1462            }
1463            Expression::MinionReify(_, box1, box2) => {
1464                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1465            }
1466            Expression::MinionReifyImply(_, box1, box2) => {
1467                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1468            }
1469            Expression::MinionWInIntervalSet(_, atom, intervals) => {
1470                let intervals = intervals.iter().join(",");
1471                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1472            }
1473            Expression::MinionWInSet(_, atom, values) => {
1474                let values = values.iter().join(",");
1475                write!(f, "__minion_w_inset({atom},{values})")
1476            }
1477            Expression::AuxDeclaration(_, reference, e) => {
1478                write!(f, "{} =aux {}", reference, e.clone())
1479            }
1480            Expression::UnsafeMod(_, a, b) => {
1481                write!(f, "{} % {}", a.clone(), b.clone())
1482            }
1483            Expression::SafeMod(_, a, b) => {
1484                write!(f, "SafeMod({},{})", a.clone(), b.clone())
1485            }
1486            Expression::Neg(_, a) => {
1487                write!(f, "-({})", a.clone())
1488            }
1489            Expression::Minus(_, a, b) => {
1490                write!(f, "({} - {})", a.clone(), b.clone())
1491            }
1492            Expression::FlatAllDiff(_, es) => {
1493                write!(f, "__flat_alldiff({})", pretty_vec(es))
1494            }
1495            Expression::FlatAbsEq(_, a, b) => {
1496                write!(f, "AbsEq({},{})", a.clone(), b.clone())
1497            }
1498            Expression::FlatMinusEq(_, a, b) => {
1499                write!(f, "MinusEq({},{})", a.clone(), b.clone())
1500            }
1501            Expression::FlatProductEq(_, a, b, c) => {
1502                write!(
1503                    f,
1504                    "FlatProductEq({},{},{})",
1505                    a.clone(),
1506                    b.clone(),
1507                    c.clone()
1508                )
1509            }
1510            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1511                write!(
1512                    f,
1513                    "FlatWeightedSumLeq({},{},{})",
1514                    pretty_vec(cs),
1515                    pretty_vec(vs),
1516                    total.clone()
1517                )
1518            }
1519            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1520                write!(
1521                    f,
1522                    "FlatWeightedSumGeq({},{},{})",
1523                    pretty_vec(cs),
1524                    pretty_vec(vs),
1525                    total.clone()
1526                )
1527            }
1528            Expression::MinionPow(_, atom, atom1, atom2) => {
1529                write!(f, "MinionPow({atom},{atom1},{atom2})")
1530            }
1531            Expression::MinionElementOne(_, atoms, atom, atom1) => {
1532                let atoms = atoms.iter().join(",");
1533                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1534            }
1535
1536            Expression::ToInt(_, expr) => {
1537                write!(f, "toInt({expr})")
1538            }
1539
1540            Expression::SATInt(_, e) => {
1541                write!(f, "SATInt({e})")
1542            }
1543
1544            Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1545            Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1546
1547            Expression::Defined(_, function) => write!(f, "defined({function})"),
1548            Expression::Range(_, function) => write!(f, "range({function})"),
1549            Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1550            Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1551            Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1552            Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1553            Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1554
1555            Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1556            Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1557            Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1558            Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1559            Expression::FlatLexLt(_, a, b) => {
1560                write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1561            }
1562            Expression::FlatLexLeq(_, a, b) => {
1563                write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1564            }
1565        }
1566    }
1567}
1568
1569impl Typeable for Expression {
1570    fn return_type(&self) -> ReturnType {
1571        match self {
1572            Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1573            Expression::Intersect(_, subject, _) => {
1574                ReturnType::Set(Box::new(subject.return_type()))
1575            }
1576            Expression::In(_, _, _) => ReturnType::Bool,
1577            Expression::Supset(_, _, _) => ReturnType::Bool,
1578            Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1579            Expression::Subset(_, _, _) => ReturnType::Bool,
1580            Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1581            Expression::AbstractLiteral(_, lit) => lit.return_type(),
1582            Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1583                let subject_ty = subject.return_type();
1584                match subject_ty {
1585                    ReturnType::Matrix(_) => {
1586                        // For n-dimensional matrices, unwrap the element type until
1587                        // we either get to the innermost element type or the last index
1588                        let mut elem_typ = subject_ty;
1589                        let mut idx_len = idx.len();
1590                        while idx_len > 0
1591                            && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1592                        {
1593                            elem_typ = *new_elem_typ.clone();
1594                            idx_len -= 1;
1595                        }
1596                        elem_typ
1597                    }
1598                    // TODO: We can implement indexing for these eventually
1599                    ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1600                    _ => bug!(
1601                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1602                    ),
1603                }
1604            }
1605            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1606                ReturnType::Matrix(Box::new(subject.return_type()))
1607            }
1608            Expression::InDomain(_, _, _) => ReturnType::Bool,
1609            Expression::Comprehension(_, comp) => comp.return_type(),
1610            Expression::Root(_, _) => ReturnType::Bool,
1611            Expression::DominanceRelation(_, _) => ReturnType::Bool,
1612            Expression::FromSolution(_, expr) => expr.return_type(),
1613            Expression::Metavar(_, _) => ReturnType::Unknown,
1614            Expression::Atomic(_, atom) => atom.return_type(),
1615            Expression::Scope(_, scope) => scope.return_type(),
1616            Expression::Abs(_, _) => ReturnType::Int,
1617            Expression::Sum(_, _) => ReturnType::Int,
1618            Expression::Product(_, _) => ReturnType::Int,
1619            Expression::Min(_, _) => ReturnType::Int,
1620            Expression::Max(_, _) => ReturnType::Int,
1621            Expression::Not(_, _) => ReturnType::Bool,
1622            Expression::Or(_, _) => ReturnType::Bool,
1623            Expression::Imply(_, _, _) => ReturnType::Bool,
1624            Expression::Iff(_, _, _) => ReturnType::Bool,
1625            Expression::And(_, _) => ReturnType::Bool,
1626            Expression::Eq(_, _, _) => ReturnType::Bool,
1627            Expression::Neq(_, _, _) => ReturnType::Bool,
1628            Expression::Geq(_, _, _) => ReturnType::Bool,
1629            Expression::Leq(_, _, _) => ReturnType::Bool,
1630            Expression::Gt(_, _, _) => ReturnType::Bool,
1631            Expression::Lt(_, _, _) => ReturnType::Bool,
1632            Expression::SafeDiv(_, _, _) => ReturnType::Int,
1633            Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1634            Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1635            Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1636            Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1637            Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1638            Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1639            Expression::Flatten(_, _, matrix) => {
1640                let matrix_type = matrix.return_type();
1641                match matrix_type {
1642                    ReturnType::Matrix(_) => {
1643                        // unwrap until we get to innermost element
1644                        let mut elem_type = matrix_type;
1645                        while let ReturnType::Matrix(new_elem_type) = &elem_type {
1646                            elem_type = *new_elem_type.clone();
1647                        }
1648                        ReturnType::Matrix(Box::new(elem_type))
1649                    }
1650                    _ => bug!(
1651                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1652                    ),
1653                }
1654            }
1655            Expression::AllDiff(_, _) => ReturnType::Bool,
1656            Expression::Bubble(_, inner, _) => inner.return_type(),
1657            Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1658            Expression::MinionReify(_, _, _) => ReturnType::Bool,
1659            Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1660            Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1661            Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1662            Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1663            Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1664            Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1665            Expression::SafeMod(_, _, _) => ReturnType::Int,
1666            Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1667            Expression::Neg(_, _) => ReturnType::Int,
1668            Expression::UnsafePow(_, _, _) => ReturnType::Int,
1669            Expression::SafePow(_, _, _) => ReturnType::Int,
1670            Expression::Minus(_, _, _) => ReturnType::Int,
1671            Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1672            Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1673            Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1674            Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1675            Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1676            Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1677            Expression::ToInt(_, _) => ReturnType::Int,
1678            Expression::SATInt(_, _) => ReturnType::Int,
1679            Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1680            Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1681            Expression::Defined(_, function) => {
1682                let subject = function.return_type();
1683                match subject {
1684                    ReturnType::Function(domain, _) => *domain,
1685                    _ => bug!(
1686                        "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1687                    ),
1688                }
1689            }
1690            Expression::Range(_, function) => {
1691                let subject = function.return_type();
1692                match subject {
1693                    ReturnType::Function(_, codomain) => *codomain,
1694                    _ => bug!(
1695                        "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1696                    ),
1697                }
1698            }
1699            Expression::Image(_, function, _) => {
1700                let subject = function.return_type();
1701                match subject {
1702                    ReturnType::Function(_, codomain) => *codomain,
1703                    _ => bug!(
1704                        "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1705                    ),
1706                }
1707            }
1708            Expression::ImageSet(_, function, _) => {
1709                let subject = function.return_type();
1710                match subject {
1711                    ReturnType::Function(_, codomain) => *codomain,
1712                    _ => bug!(
1713                        "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1714                    ),
1715                }
1716            }
1717            Expression::PreImage(_, function, _) => {
1718                let subject = function.return_type();
1719                match subject {
1720                    ReturnType::Function(domain, _) => *domain,
1721                    _ => bug!(
1722                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1723                    ),
1724                }
1725            }
1726            Expression::Restrict(_, function, new_domain) => {
1727                let subject = function.return_type();
1728                match subject {
1729                    ReturnType::Function(_, codomain) => {
1730                        ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1731                    }
1732                    _ => bug!(
1733                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1734                    ),
1735                }
1736            }
1737            Expression::Inverse(..) => ReturnType::Bool,
1738            Expression::LexLt(..) => ReturnType::Bool,
1739            Expression::LexGt(..) => ReturnType::Bool,
1740            Expression::LexLeq(..) => ReturnType::Bool,
1741            Expression::LexGeq(..) => ReturnType::Bool,
1742            Expression::FlatLexLt(..) => ReturnType::Bool,
1743            Expression::FlatLexLeq(..) => ReturnType::Bool,
1744        }
1745    }
1746}
1747
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    /// Wraps a literal in an atomic expression carrying fresh metadata.
    fn literal_expr(lit: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(lit))
    }

    /// Summing the integer constants 1 and 2 yields the single-value domain `3`.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal_expr(Literal::Int(1));
        let two = literal_expr(Literal::Int(2));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, two]));

        let expected = Some(Domain::int(vec![Range::Single(3)]));
        assert_eq!(total.domain_of(), expected);
    }

    /// A sum mixing an integer and a boolean constant has no domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_operand = literal_expr(Literal::Int(1));
        let bool_operand = literal_expr(Literal::Bool(true));
        let total = Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_operand, bool_operand]),
        );

        assert_eq!(total.domain_of(), None);
    }

    /// A sum over no operands has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));

        assert_eq!(total.domain_of(), None);
    }
}