Skip to main content

conjure_cp_core/ast/
expressions.rs

1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use tracing::trace;
4
5use conjure_cp_enum_compatibility_macro::document_compatibility;
6use itertools::Itertools;
7use serde::{Deserialize, Serialize};
8use ustr::Ustr;
9
10use polyquine::Quine;
11use uniplate::{Biplate, Uniplate};
12
13use crate::bug;
14
15use super::abstract_comprehension::AbstractComprehension;
16use super::ac_operators::ACOperatorKind;
17use super::categories::{Category, CategoryOf};
18use super::comprehension::Comprehension;
19use super::domains::HasDomain as _;
20use super::pretty::{pretty_expressions_as_top_level, pretty_vec};
21use super::records::RecordValue;
22use super::sat_encoding::SATIntEncoding;
23use super::{
24    AbstractLiteral, Atom, DeclarationPtr, Domain, DomainPtr, GroundDomain, IntVal, Literal,
25    Metadata, Model, Moo, Name, Range, Reference, ReturnType, SetAttr, SymbolTable, SymbolTablePtr,
26    Typeable, UnresolvedDomain, matrix,
27};
28
29// Ensure that this type doesn't get too big
30//
31// If you triggered this assertion, you either made a variant of this enum that is too big, or you
32// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
33// stuff in boxes.
34//
35// Enums take the size of their largest variant, so an enum with mostly small variants and a few
36// large ones wastes memory... A larger Expression type also slows down Oxide.
37//
38// For more information, and more details on type sizes and how to measure them, see the commit
39// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
40//
41// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
42//
43// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e
44
45// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
46// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
47// lot bigger still when we start using it for memoisation, so it should really be
48// boxed ~niklasdewally
49
50// expect size of Expression to be 112 bytes
51static_assertions::assert_eq_size!([u8; 112], Expression);
52
53/// Represents different types of expressions used to define rules and constraints in the model.
54///
55/// The `Expression` enum includes operations, constants, and variable references
56/// used to build rules and conditions for the model.
57#[document_compatibility]
58#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
59#[biplate(to=AbstractComprehension)]
60#[biplate(to=AbstractLiteral<Expression>)]
61#[biplate(to=AbstractLiteral<Literal>)]
62#[biplate(to=Atom)]
63#[biplate(to=Comprehension)]
64#[biplate(to=DeclarationPtr)]
65#[biplate(to=DomainPtr)]
66#[biplate(to=Literal)]
67#[biplate(to=Metadata)]
68#[biplate(to=Name)]
69#[biplate(to=Option<Expression>)]
70#[biplate(to=RecordValue<Expression>)]
71#[biplate(to=RecordValue<Literal>)]
72#[biplate(to=Reference)]
73#[biplate(to=Model)]
74#[biplate(to=SymbolTable)]
75#[biplate(to=SymbolTablePtr)]
76#[biplate(to=Vec<Expression>)]
77#[path_prefix(conjure_cp::ast)]
78pub enum Expression {
79    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
80    /// The top of the model
81    Root(Metadata, Vec<Expression>),
82
83    /// An expression representing "A is valid as long as B is true"
84    /// Turns into a conjunction when it reaches a boolean context
85    Bubble(Metadata, Moo<Expression>, Moo<Expression>),
86
87    /// A comprehension.
88    ///
89    /// The inside of the comprehension opens a new scope.
90    // todo (gskorokhod): Comprehension contains a symbol table which contains a bunch of pointers.
91    // This makes implementing Quine tricky (it doesnt support Rc, by design). Skip it for now.
92    #[polyquine_skip]
93    Comprehension(Metadata, Moo<Comprehension>),
94
95    /// Higher-level abstract comprehension
96    #[polyquine_skip] // no idea what this is lol but it stops rustc screaming at me
97    AbstractComprehension(Metadata, Moo<AbstractComprehension>),
98
99    /// Defines dominance ("Solution A is preferred over Solution B")
100    DominanceRelation(Metadata, Moo<Expression>),
101    /// `fromSolution(name)` - Used in dominance relation definitions
102    FromSolution(Metadata, Moo<Atom>),
103
104    #[polyquine_with(arm = (_, name) => {
105        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
106        quote::quote! { #ident.clone().into() }
107    })]
108    Metavar(Metadata, Ustr),
109
110    Atomic(Metadata, Atom),
111
112    /// A matrix index.
113    ///
114    /// Defined iff the indices are within their respective index domains.
115    #[compatible(JsonInput)]
116    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
117
118    /// A safe matrix index.
119    ///
120    /// See [`Expression::UnsafeIndex`]
121    #[compatible(SMT)]
122    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
123
124    /// A matrix slice: `a[indices]`.
125    ///
126    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
127    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
128    /// `Some(1),None,Some(2)`.
129    ///
130    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
131    ///
132    /// Defined iff the defined indices are within their respective index domains.
133    #[compatible(JsonInput)]
134    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
135
136    /// A safe matrix slice: `a[indices]`.
137    ///
138    /// See [`Expression::UnsafeSlice`].
139    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
140
141    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
142    ///
143    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
144    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
145    /// `SafeSlice` respectively.
146    InDomain(Metadata, Moo<Expression>, DomainPtr),
147
148    /// `toInt(b)` casts boolean expression b to an integer.
149    ///
150    /// - If b is false, then `toInt(b) == 0`
151    ///
152    /// - If b is true, then `toInt(b) == 1`
153    #[compatible(SMT)]
154    ToInt(Metadata, Moo<Expression>),
155
156    /// `|x|` - absolute value of `x`
157    #[compatible(JsonInput, SMT)]
158    Abs(Metadata, Moo<Expression>),
159
160    /// `sum(<vec_expr>)`
161    #[compatible(JsonInput, SMT)]
162    Sum(Metadata, Moo<Expression>),
163
164    /// `a * b * c * ...`
165    #[compatible(JsonInput, SMT)]
166    Product(Metadata, Moo<Expression>),
167
168    /// `min(<vec_expr>)`
169    #[compatible(JsonInput, SMT)]
170    Min(Metadata, Moo<Expression>),
171
172    /// `max(<vec_expr>)`
173    #[compatible(JsonInput, SMT)]
174    Max(Metadata, Moo<Expression>),
175
176    /// `not(a)`
177    #[compatible(JsonInput, SAT, SMT)]
178    Not(Metadata, Moo<Expression>),
179
180    /// `or(<vec_expr>)`
181    #[compatible(JsonInput, SAT, SMT)]
182    Or(Metadata, Moo<Expression>),
183
184    /// `and(<vec_expr>)`
185    #[compatible(JsonInput, SAT, SMT)]
186    And(Metadata, Moo<Expression>),
187
188    /// Ensures that `a->b` (material implication).
189    #[compatible(JsonInput, SMT)]
190    Imply(Metadata, Moo<Expression>, Moo<Expression>),
191
192    /// `iff(a, b)` a <-> b
193    #[compatible(JsonInput, SMT)]
194    Iff(Metadata, Moo<Expression>, Moo<Expression>),
195
196    #[compatible(JsonInput)]
197    Union(Metadata, Moo<Expression>, Moo<Expression>),
198
199    #[compatible(JsonInput)]
200    In(Metadata, Moo<Expression>, Moo<Expression>),
201
202    #[compatible(JsonInput)]
203    Intersect(Metadata, Moo<Expression>, Moo<Expression>),
204
205    #[compatible(JsonInput)]
206    Supset(Metadata, Moo<Expression>, Moo<Expression>),
207
208    #[compatible(JsonInput)]
209    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),
210
211    #[compatible(JsonInput)]
212    Subset(Metadata, Moo<Expression>, Moo<Expression>),
213
214    #[compatible(JsonInput)]
215    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),
216
217    #[compatible(JsonInput, SMT)]
218    Eq(Metadata, Moo<Expression>, Moo<Expression>),
219
220    #[compatible(JsonInput, SMT)]
221    Neq(Metadata, Moo<Expression>, Moo<Expression>),
222
223    #[compatible(JsonInput, SMT)]
224    Geq(Metadata, Moo<Expression>, Moo<Expression>),
225
226    #[compatible(JsonInput, SMT)]
227    Leq(Metadata, Moo<Expression>, Moo<Expression>),
228
229    #[compatible(JsonInput, SMT)]
230    Gt(Metadata, Moo<Expression>, Moo<Expression>),
231
232    #[compatible(JsonInput, SMT)]
233    Lt(Metadata, Moo<Expression>, Moo<Expression>),
234
235    /// Division after preventing division by zero, usually with a bubble
236    #[compatible(SMT)]
237    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
238
239    /// Division with a possibly undefined value (division by 0)
240    #[compatible(JsonInput)]
241    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
242
243    /// Modulo after preventing mod 0, usually with a bubble
244    #[compatible(SMT)]
245    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),
246
247    /// Modulo with a possibly undefined value (mod 0)
248    #[compatible(JsonInput)]
249    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),
250
251    /// Negation: `-x`
252    #[compatible(JsonInput, SMT)]
253    Neg(Metadata, Moo<Expression>),
254
255    /// Set of domain values function is defined for
256    #[compatible(JsonInput)]
257    Defined(Metadata, Moo<Expression>),
258
259    /// Set of codomain values function is defined for
260    #[compatible(JsonInput)]
261    Range(Metadata, Moo<Expression>),
262
263    /// Unsafe power`x**y` (possibly undefined)
264    ///
265    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
266    #[compatible(JsonInput)]
267    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),
268
269    /// `UnsafePow` after preventing undefinedness
270    SafePow(Metadata, Moo<Expression>, Moo<Expression>),
271
272    /// Flatten matrix operator
273    /// `flatten(M)` or `flatten(n, M)`
274    /// where M is a matrix and n is an optional integer argument indicating depth of flattening
275    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),
276
277    /// `allDiff(<vec_expr>)`
278    #[compatible(JsonInput)]
279    AllDiff(Metadata, Moo<Expression>),
280
281    /// `table([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
282    ///
283    /// Represents a positive table constraint: the tuple `[x1, x2, ...]` must match one of the
284    /// allowed rows.
285    #[compatible(JsonInput)]
286    Table(Metadata, Moo<Expression>, Moo<Expression>),
287
288    /// `negativeTable([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
289    ///
290    /// Represents a negative table constraint: the tuple `[x1, x2, ...]` must NOT match any of the
291    /// forbidden rows.
292    #[compatible(JsonInput)]
293    NegativeTable(Metadata, Moo<Expression>, Moo<Expression>),
294    /// Binary subtraction operator
295    ///
296    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
297    /// TODO: make this compatible with Set Difference calculations - need to change return type and domain for this expression and write a set comprehension rule.
298    /// have already edited minus_to_sum to prevent this from applying to sets
299    #[compatible(JsonInput)]
300    Minus(Metadata, Moo<Expression>, Moo<Expression>),
301
302    /// Ensures that x=|y| i.e. x is the absolute value of y.
303    ///
304    /// Low-level Minion constraint.
305    ///
306    /// # See also
307    ///
308    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
309    #[compatible(Minion)]
310    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),
311
312    /// Ensures that `alldiff([a,b,...])`.
313    ///
314    /// Low-level Minion constraint.
315    ///
316    /// # See also
317    ///
318    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
319    #[compatible(Minion)]
320    FlatAllDiff(Metadata, Vec<Atom>),
321
322    /// Ensures that sum(vec) >= x.
323    ///
324    /// Low-level Minion constraint.
325    ///
326    /// # See also
327    ///
328    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
329    #[compatible(Minion)]
330    FlatSumGeq(Metadata, Vec<Atom>, Atom),
331
332    /// Ensures that sum(vec) <= x.
333    ///
334    /// Low-level Minion constraint.
335    ///
336    /// # See also
337    ///
338    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
339    #[compatible(Minion)]
340    FlatSumLeq(Metadata, Vec<Atom>, Atom),
341
342    /// `ineq(x,y,k)` ensures that x <= y + k.
343    ///
344    /// Low-level Minion constraint.
345    ///
346    /// # See also
347    ///
348    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
349    #[compatible(Minion)]
350    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),
351
352    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
353    ///
354    /// Low-level Minion constraint.
355    ///
356    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
357    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
358    /// watched-and and watched-or.
359    ///
360    /// # See also
361    ///
362    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
363    /// + `rules::minion::boolean_literal_to_wliteral`.
364    #[compatible(Minion)]
365    #[polyquine_skip]
366    FlatWatchedLiteral(Metadata, Reference, Literal),
367
368    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
369    /// product of cs and xs.
370    ///
371    /// Low-level Minion constraint.
372    ///
373    /// Represents a weighted sum of the form `ax + by + cz + ...`
374    ///
375    /// # See also
376    ///
377    /// + [Minion
378    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
379    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
380
381    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
382    /// product of cs and xs.
383    ///
384    /// Low-level Minion constraint.
385    ///
386    /// Represents a weighted sum of the form `ax + by + cz + ...`
387    ///
388    /// # See also
389    ///
390    /// + [Minion
391    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
392    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
393
394    /// Ensures that x =-y, where x and y are atoms.
395    ///
396    /// Low-level Minion constraint.
397    ///
398    /// # See also
399    ///
400    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
401    #[compatible(Minion)]
402    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),
403
404    /// Ensures that x*y=z.
405    ///
406    /// Low-level Minion constraint.
407    ///
408    /// # See also
409    ///
410    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
411    #[compatible(Minion)]
412    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
413
414    /// Ensures that floor(x/y)=z. Always true when y=0.
415    ///
416    /// Low-level Minion constraint.
417    ///
418    /// # See also
419    ///
420    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
421    #[compatible(Minion)]
422    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
423
424    /// Ensures that x%y=z. Always true when y=0.
425    ///
426    /// Low-level Minion constraint.
427    ///
428    /// # See also
429    ///
430    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
431    #[compatible(Minion)]
432    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
433
434    /// Ensures that `x**y = z`.
435    ///
436    /// Low-level Minion constraint.
437    ///
438    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
439    /// is odd and z is -1 if y is even).
440    ///
441    /// # See also
442    ///
443    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
444    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
445    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
446
447    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
448    /// variable.
449    ///
450    /// Low-level Minion constraint.
451    ///
452    /// # See also
453    ///
454    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
455    #[compatible(Minion)]
456    MinionReify(Metadata, Moo<Expression>, Atom),
457
458    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
459    /// variable.
460    ///
461    /// Low-level Minion constraint.
462    ///
463    /// # See also
464    ///
465    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
466    #[compatible(Minion)]
467    MinionReifyImply(Metadata, Moo<Expression>, Atom),
468
469    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
470    /// intervals {a1,…,a2}, {b1,…,b2} etc.
471    ///
472    /// The list of intervals must be given in numerical order.
473    ///
474    /// Low-level Minion constraint.
475    ///
476    /// # See also
477    ///>
478    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
479    #[compatible(Minion)]
480    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
481
482    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
483    ///
484    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
485    ///
486    /// The list of values must be given in numerical order.
487    ///
488    /// Low-level Minion constraint.
489    ///
490    /// # See also
491    ///
492    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
493    #[compatible(Minion)]
494    MinionWInSet(Metadata, Atom, Vec<i32>),
495
496    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
497    /// in the range `[1..len(vec)]`.
498    ///
499    /// Low-level Minion constraint.
500    ///
501    /// # See also
502    ///
503    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
504    #[compatible(Minion)]
505    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),
506
507    /// Declaration of an auxiliary variable.
508    ///
509    /// As with Savile Row, we semantically distinguish this from `Eq`.
510    #[compatible(Minion)]
511    #[polyquine_skip]
512    AuxDeclaration(Metadata, Reference, Moo<Expression>),
513
514    /// This expression is for encoding ints for the SAT solver, it stores the encoding type, the vector of booleans and the min/max for the int.
515    #[compatible(SAT)]
516    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),
517
518    /// Addition over a pair of expressions (i.e. a + b) rather than a vec-expr like Expression::Sum.
519    /// This is for compatibility with backends that do not support addition over vectors.
520    #[compatible(SMT)]
521    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),
522
523    /// Multiplication over a pair of expressions (i.e. a * b) rather than a vec-expr like Expression::Product.
524    /// This is for compatibility with backends that do not support multiplication over vectors.
525    #[compatible(SMT)]
526    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),
527
528    #[compatible(JsonInput)]
529    Image(Metadata, Moo<Expression>, Moo<Expression>),
530
531    #[compatible(JsonInput)]
532    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),
533
534    #[compatible(JsonInput)]
535    PreImage(Metadata, Moo<Expression>, Moo<Expression>),
536
537    #[compatible(JsonInput)]
538    Inverse(Metadata, Moo<Expression>, Moo<Expression>),
539
540    #[compatible(JsonInput)]
541    Restrict(Metadata, Moo<Expression>, Moo<Expression>),
542
543    /// Lexicographical < between two matrices.
544    ///
545    /// A <lex B iff: A[i] < B[i] for some i /\ (A[j] > B[j] for some j -> i < j)
546    /// I.e. A must be less than B at some index i, and if it is greater than B at another index j,
547    /// then j comes after i.
548    /// I.e. A must be greater than B at the first index where they differ.
549    ///
550    /// E.g. [1, 1] <lex [2, 1] and [1, 1] <lex [1, 2]
551    LexLt(Metadata, Moo<Expression>, Moo<Expression>),
552
553    /// Lexicographical <= between two matrices
554    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),
555
556    /// Lexicographical > between two matrices
557    /// This is a parser-level construct, and is immediately normalised to LexLt(b, a)
558    LexGt(Metadata, Moo<Expression>, Moo<Expression>),
559
560    /// Lexicographical >= between two matrices
561    /// This is a parser-level construct, and is immediately normalised to LexLeq(b, a)
562    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),
563
564    /// Low-level minion constraint. See Expression::LexLt
565    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),
566
567    /// Low-level minion constraint. See Expression::LexLeq
568    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
569}
570
571// for the given matrix literal, return a bounded domain from the min to max of applying op to each
572// child expression.
573//
574// Op must be monotonic.
575//
576// Returns none if unbounded
577fn bounded_i32_domain_for_matrix_literal_monotonic(
578    e: &Expression,
579    op: fn(i32, i32) -> Option<i32>,
580) -> Option<DomainPtr> {
581    // only care about the elements, not the indices
582    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
583
584    // fold each element's domain into one using op.
585    //
586    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
587    // the ranges [a1,a2], [b1,b2] will be
588    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
589    //
590    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
591    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
592    // memory (on the hakank_eprime_xkcd test)...
593    //Int
594    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
595    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
596    //
597    // +,-,/,* are all monotone, so this assumption should be fine for now...
598
599    let expr = exprs.pop()?;
600    let dom = expr.domain_of()?;
601    let resolved = dom.resolve()?;
602    let GroundDomain::Int(ranges) = resolved.as_ref() else {
603        return None;
604    };
605
606    let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
607
608    for expr in exprs {
609        let dom = expr.domain_of()?;
610        let resolved = dom.resolve()?;
611        let GroundDomain::Int(ranges) = resolved.as_ref() else {
612            return None;
613        };
614
615        let (min, max) = range_vec_bounds_i32(ranges)?;
616
617        // all the possible new values for current_min / current_max
618        let minmax = op(min, current_max)?;
619        let minmin = op(min, current_min)?;
620        let maxmin = op(max, current_min)?;
621        let maxmax = op(max, current_max)?;
622        let vals = [minmax, minmin, maxmin, maxmax];
623
624        current_min = *vals
625            .iter()
626            .min()
627            .expect("vals iterator should not be empty, and should have a minimum.");
628        current_max = *vals
629            .iter()
630            .max()
631            .expect("vals iterator should not be empty, and should have a maximum.");
632    }
633
634    if current_min == current_max {
635        Some(Domain::int(vec![Range::Single(current_min)]))
636    } else {
637        Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
638    }
639}
640
641// Returns none if unbounded
642fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
643    let mut min = i32::MAX;
644    let mut max = i32::MIN;
645    for r in ranges {
646        match r {
647            Range::Single(i) => {
648                if *i < min {
649                    min = *i;
650                }
651                if *i > max {
652                    max = *i;
653                }
654            }
655            Range::Bounded(i, j) => {
656                if *i < min {
657                    min = *i;
658                }
659                if *j > max {
660                    max = *j;
661                }
662            }
663            Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
664        }
665    }
666    Some((min, max))
667}
668
669impl Expression {
670    /// Returns the possible values of the expression, recursing to leaf expressions
671    pub fn domain_of(&self) -> Option<DomainPtr> {
672        match self {
673            Expression::Union(_, a, b) => Some(Domain::set(
674                SetAttr::<IntVal>::default(),
675                a.domain_of()?.union(&b.domain_of()?).ok()?,
676            )),
677            Expression::Intersect(_, a, b) => Some(Domain::set(
678                SetAttr::<IntVal>::default(),
679                a.domain_of()?.intersect(&b.domain_of()?).ok()?,
680            )),
681            Expression::In(_, _, _) => Some(Domain::bool()),
682            Expression::Supset(_, _, _) => Some(Domain::bool()),
683            Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
684            Expression::Subset(_, _, _) => Some(Domain::bool()),
685            Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
686            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
687            Expression::DominanceRelation(_, _) => Some(Domain::bool()),
688            Expression::FromSolution(_, expr) => Some(expr.domain_of()),
689            Expression::Metavar(_, _) => None,
690            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
691            Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
692            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
693                let dom = matrix.domain_of()?;
694                if let Some((elem_domain, _)) = dom.as_matrix() {
695                    return Some(elem_domain);
696                }
697
698                // may actually use the value in the future
699                #[allow(clippy::redundant_pattern_matching)]
700                if let Some(_) = dom.as_tuple() {
701                    // TODO: We can implement proper indexing for tuples
702                    return None;
703                }
704
705                // may actually use the value in the future
706                #[allow(clippy::redundant_pattern_matching)]
707                if let Some(_) = dom.as_record() {
708                    // TODO: We can implement proper indexing for records
709                    return None;
710                }
711
712                bug!("subject of an index operation should support indexing")
713            }
714            Expression::UnsafeSlice(_, matrix, indices)
715            | Expression::SafeSlice(_, matrix, indices) => {
716                let sliced_dimension = indices.iter().position(Option::is_none);
717
718                let dom = matrix.domain_of()?;
719                let Some((elem_domain, index_domains)) = dom.as_matrix() else {
720                    bug!("subject of an index operation should be a matrix");
721                };
722
723                match sliced_dimension {
724                    Some(dimension) => Some(Domain::matrix(
725                        elem_domain,
726                        vec![index_domains[dimension].clone()],
727                    )),
728
729                    // same as index
730                    None => Some(elem_domain),
731                }
732            }
733            Expression::InDomain(_, _, _) => Some(Domain::bool()),
734            Expression::Atomic(_, atom) => Some(atom.domain_of()),
735            Expression::Sum(_, e) => {
736                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
737            }
738            Expression::Product(_, e) => {
739                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
740            }
741            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
742                Some(if x < y { x } else { y })
743            }),
744            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
745                Some(if x > y { x } else { y })
746            }),
747            Expression::UnsafeDiv(_, a, b) => a
748                .domain_of()?
749                .resolve()?
750                .apply_i32(
751                    // rust integer division is truncating; however, we want to always round down,
752                    // including for negative numbers.
753                    |x, y| {
754                        if y != 0 {
755                            Some((x as f32 / y as f32).floor() as i32)
756                        } else {
757                            None
758                        }
759                    },
760                    b.domain_of()?.resolve()?.as_ref(),
761                )
762                .map(DomainPtr::from)
763                .ok(),
764            Expression::SafeDiv(_, a, b) => {
765                // rust integer division is truncating; however, we want to always round down
766                // including for negative numbers.
767                let domain = a
768                    .domain_of()?
769                    .resolve()?
770                    .apply_i32(
771                        |x, y| {
772                            if y != 0 {
773                                Some((x as f32 / y as f32).floor() as i32)
774                            } else {
775                                None
776                            }
777                        },
778                        b.domain_of()?.resolve()?.as_ref(),
779                    )
780                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
781
782                if let GroundDomain::Int(ranges) = domain {
783                    let mut ranges = ranges;
784                    ranges.push(Range::Single(0));
785                    Some(Domain::int(ranges))
786                } else {
787                    bug!("Domain of {self} was not integer")
788                }
789            }
790            Expression::UnsafeMod(_, a, b) => a
791                .domain_of()?
792                .resolve()?
793                .apply_i32(
794                    |x, y| if y != 0 { Some(x % y) } else { None },
795                    b.domain_of()?.resolve()?.as_ref(),
796                )
797                .map(DomainPtr::from)
798                .ok(),
799            Expression::SafeMod(_, a, b) => {
800                let domain = a
801                    .domain_of()?
802                    .resolve()?
803                    .apply_i32(
804                        |x, y| if y != 0 { Some(x % y) } else { None },
805                        b.domain_of()?.resolve()?.as_ref(),
806                    )
807                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
808
809                if let GroundDomain::Int(ranges) = domain {
810                    let mut ranges = ranges;
811                    ranges.push(Range::Single(0));
812                    Some(Domain::int(ranges))
813                } else {
814                    bug!("Domain of {self} was not integer")
815                }
816            }
817            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
818                .domain_of()?
819                .resolve()?
820                .apply_i32(
821                    |x, y| {
822                        if (x != 0 || y != 0) && y >= 0 {
823                            Some(x.pow(y as u32))
824                        } else {
825                            None
826                        }
827                    },
828                    b.domain_of()?.resolve()?.as_ref(),
829                )
830                .map(DomainPtr::from)
831                .ok(),
832            Expression::Root(_, _) => None,
833            Expression::Bubble(_, inner, _) => inner.domain_of(),
834            Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
835            Expression::And(_, _) => Some(Domain::bool()),
836            Expression::Not(_, _) => Some(Domain::bool()),
837            Expression::Or(_, _) => Some(Domain::bool()),
838            Expression::Imply(_, _, _) => Some(Domain::bool()),
839            Expression::Iff(_, _, _) => Some(Domain::bool()),
840            Expression::Eq(_, _, _) => Some(Domain::bool()),
841            Expression::Neq(_, _, _) => Some(Domain::bool()),
842            Expression::Geq(_, _, _) => Some(Domain::bool()),
843            Expression::Leq(_, _, _) => Some(Domain::bool()),
844            Expression::Gt(_, _, _) => Some(Domain::bool()),
845            Expression::Lt(_, _, _) => Some(Domain::bool()),
846            Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
847            Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
848            Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
849            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
850            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
851            Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
852            Expression::Flatten(_, n, m) => {
853                if let Some(expr) = n {
854                    if expr.return_type() == ReturnType::Int {
855                        // TODO: handle flatten with depth argument
856                        return None;
857                    }
858                } else {
859                    // TODO: currently only works for matrices
860                    let dom = m.domain_of()?.resolve()?;
861                    let (val_dom, idx_doms) = match dom.as_ref() {
862                        GroundDomain::Matrix(val, idx) => (val, idx),
863                        _ => return None,
864                    };
865                    let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
866
867                    let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
868                    return Some(Domain::matrix(
869                        val_dom.clone().into(),
870                        vec![new_index_domain],
871                    ));
872                }
873                None
874            }
875            Expression::AllDiff(_, _) => Some(Domain::bool()),
876            Expression::Table(_, _, _) => Some(Domain::bool()),
877            Expression::NegativeTable(_, _, _) => Some(Domain::bool()),
878            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
879            Expression::MinionReify(_, _, _) => Some(Domain::bool()),
880            Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
881            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
882            Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
883            Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
884            Expression::Neg(_, x) => {
885                let dom = x.domain_of()?;
886                let mut ranges = dom.as_int()?;
887
888                ranges = ranges
889                    .into_iter()
890                    .map(|r| match r {
891                        Range::Single(x) => Range::Single(-x),
892                        Range::Bounded(x, y) => Range::Bounded(-y, -x),
893                        Range::UnboundedR(i) => Range::UnboundedL(-i),
894                        Range::UnboundedL(i) => Range::UnboundedR(-i),
895                        Range::Unbounded => Range::Unbounded,
896                    })
897                    .collect();
898
899                Some(Domain::int(ranges))
900            }
901            Expression::Minus(_, a, b) => a
902                .domain_of()?
903                .resolve()?
904                .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
905                .map(DomainPtr::from)
906                .ok(),
907            Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
908            Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
909            Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
910            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
911            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
912            Expression::Abs(_, a) => a
913                .domain_of()?
914                .resolve()?
915                .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
916                .map(DomainPtr::from)
917                .ok(),
918            Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
919            Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
920            Expression::SATInt(_, _, _, (low, high)) => {
921                Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
922            }
923            Expression::PairwiseSum(_, a, b) => a
924                .domain_of()?
925                .resolve()?
926                .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
927                .map(DomainPtr::from)
928                .ok(),
929            Expression::PairwiseProduct(_, a, b) => a
930                .domain_of()?
931                .resolve()?
932                .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
933                .map(DomainPtr::from)
934                .ok(),
935            Expression::Defined(_, function) => get_function_domain(function),
936            Expression::Range(_, function) => get_function_codomain(function),
937            Expression::Image(_, function, _) => get_function_codomain(function),
938            Expression::ImageSet(_, function, _) => get_function_codomain(function),
939            Expression::PreImage(_, function, _) => get_function_domain(function),
940            Expression::Restrict(_, function, new_domain) => {
941                let (attrs, _, codom) = function.domain_of()?.as_function()?;
942                let new_dom = new_domain.domain_of()?;
943                Some(Domain::function(attrs, new_dom, codom))
944            }
945            Expression::Inverse(..) => Some(Domain::bool()),
946            Expression::LexLt(..) => Some(Domain::bool()),
947            Expression::LexLeq(..) => Some(Domain::bool()),
948            Expression::LexGt(..) => Some(Domain::bool()),
949            Expression::LexGeq(..) => Some(Domain::bool()),
950            Expression::FlatLexLt(..) => Some(Domain::bool()),
951            Expression::FlatLexLeq(..) => Some(Domain::bool()),
952        }
953    }
954
955    pub fn get_meta(&self) -> Metadata {
956        let metas: VecDeque<Metadata> = self.children_bi();
957        metas[0].clone()
958    }
959
960    pub fn set_meta(&self, meta: Metadata) {
961        self.transform_bi(&|_| meta.clone());
962    }
963
964    /// Checks whether this expression is safe.
965    ///
966    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
967    ///
968    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
969    /// safe through the use of bubble rules.
970    pub fn is_safe(&self) -> bool {
971        // TODO: memoise in Metadata
972        for expr in self.universe() {
973            match expr {
974                Expression::UnsafeDiv(_, _, _)
975                | Expression::UnsafeMod(_, _, _)
976                | Expression::UnsafePow(_, _, _)
977                | Expression::UnsafeIndex(_, _, _)
978                | Expression::Bubble(_, _, _)
979                | Expression::UnsafeSlice(_, _, _) => {
980                    return false;
981                }
982                _ => {}
983            }
984        }
985        true
986    }
987
988    pub fn is_clean(&self) -> bool {
989        let metadata = self.get_meta();
990        metadata.clean
991    }
992
993    pub fn set_clean(&mut self, bool_value: bool) {
994        let mut metadata = self.get_meta();
995        metadata.clean = bool_value;
996        self.set_meta(metadata);
997    }
998
999    /// True if the expression is an associative and commutative operator
1000    pub fn is_associative_commutative_operator(&self) -> bool {
1001        TryInto::<ACOperatorKind>::try_into(self).is_ok()
1002    }
1003
1004    /// True if the expression is a matrix literal.
1005    ///
1006    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
1007    /// [`Expression`].
1008    pub fn is_matrix_literal(&self) -> bool {
1009        matches!(
1010            self,
1011            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1012                | Expression::Atomic(
1013                    _,
1014                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1015                )
1016        )
1017    }
1018
1019    /// True iff self and other are both atomic and identical.
1020    ///
1021    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
1022    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
1023    /// entire subtrees of `self` and `other`.
1024    pub fn identical_atom_to(&self, other: &Expression) -> bool {
1025        let atom1: Result<&Atom, _> = self.try_into();
1026        let atom2: Result<&Atom, _> = other.try_into();
1027
1028        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1029            atom2 == atom1
1030        } else {
1031            false
1032        }
1033    }
1034
1035    /// If the expression is a list, returns a *copied* vector of the inner expressions.
1036    ///
1037    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
1038    /// any explicitly specified domain.
1039    pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
1040        match self {
1041            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1042                matrix.unwrap_list().cloned()
1043            }
1044            Expression::Atomic(
1045                _,
1046                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1047            ) => matrix.unwrap_list().map(|elems| {
1048                elems
1049                    .clone()
1050                    .into_iter()
1051                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1052                    .collect_vec()
1053            }),
1054            _ => None,
1055        }
1056    }
1057
1058    /// If the expression is a matrix, gets it elements and index domain.
1059    ///
1060    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
1061    ///
1062    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
1063    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
1064    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
1065    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1066        match self {
1067            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1068                Some((elems, domain))
1069            }
1070            Expression::Atomic(
1071                _,
1072                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1073            ) => Some((
1074                elems
1075                    .into_iter()
1076                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1077                    .collect_vec(),
1078                domain.into(),
1079            )),
1080
1081            _ => None,
1082        }
1083    }
1084
1085    /// For a Root expression, extends the inner vec with the given vec.
1086    ///
1087    /// # Panics
1088    /// Panics if the expression is not Root.
1089    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1090        match self {
1091            Expression::Root(meta, mut children) => {
1092                children.extend(exprs);
1093                Expression::Root(meta, children)
1094            }
1095            _ => panic!("extend_root called on a non-Root expression"),
1096        }
1097    }
1098
1099    /// Converts the expression to a literal, if possible.
1100    pub fn into_literal(self) -> Option<Literal> {
1101        match self {
1102            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1103            Expression::AbstractLiteral(_, abslit) => {
1104                Some(Literal::AbstractLiteral(abslit.into_literals()?))
1105            }
1106            Expression::Neg(_, e) => {
1107                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1108                    bug!("negated literal should be an int");
1109                };
1110
1111                Some(Literal::Int(-i))
1112            }
1113
1114            _ => None,
1115        }
1116    }
1117
1118    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
1119    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1120        TryFrom::try_from(self).ok()
1121    }
1122
1123    /// Returns the categories of all sub-expressions of self.
1124    pub fn universe_categories(&self) -> HashSet<Category> {
1125        self.universe()
1126            .into_iter()
1127            .map(|x| x.category_of())
1128            .collect()
1129    }
1130}
1131
1132pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1133    let function_domain = function.domain_of()?;
1134    match function_domain.resolve().as_ref() {
1135        Some(d) => {
1136            match d.as_ref() {
1137                GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1138                // Not defined for anything other than a function
1139                _ => None,
1140            }
1141        }
1142        None => {
1143            match function_domain.as_unresolved()? {
1144                UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1145                // Not defined for anything other than a function
1146                _ => None,
1147            }
1148        }
1149    }
1150}
1151
1152pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1153    let function_domain = function.domain_of()?;
1154    match function_domain.resolve().as_ref() {
1155        Some(d) => {
1156            match d.as_ref() {
1157                GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1158                // Not defined for anything other than a function
1159                _ => None,
1160            }
1161        }
1162        None => {
1163            match function_domain.as_unresolved()? {
1164                UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1165                // Not defined for anything other than a function
1166                _ => None,
1167            }
1168        }
1169    }
1170}
1171
1172impl TryFrom<&Expression> for i32 {
1173    type Error = ();
1174
1175    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1176        let Expression::Atomic(_, atom) = value else {
1177            return Err(());
1178        };
1179
1180        let Atom::Literal(lit) = atom else {
1181            return Err(());
1182        };
1183
1184        let Literal::Int(i) = lit else {
1185            return Err(());
1186        };
1187
1188        Ok(*i)
1189    }
1190}
1191
1192impl TryFrom<Expression> for i32 {
1193    type Error = ();
1194
1195    fn try_from(value: Expression) -> Result<Self, Self::Error> {
1196        TryFrom::<&Expression>::try_from(&value)
1197    }
1198}
1199impl From<i32> for Expression {
1200    fn from(i: i32) -> Self {
1201        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1202    }
1203}
1204
1205impl From<bool> for Expression {
1206    fn from(b: bool) -> Self {
1207        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1208    }
1209}
1210
1211impl From<Atom> for Expression {
1212    fn from(value: Atom) -> Self {
1213        Expression::Atomic(Metadata::new(), value)
1214    }
1215}
1216
1217impl From<Literal> for Expression {
1218    fn from(value: Literal) -> Self {
1219        Expression::Atomic(Metadata::new(), value.into())
1220    }
1221}
1222
1223impl From<Moo<Expression>> for Expression {
1224    fn from(val: Moo<Expression>) -> Self {
1225        val.as_ref().clone()
1226    }
1227}
1228
1229impl CategoryOf for Expression {
1230    fn category_of(&self) -> Category {
1231        // take highest category of all the expressions children
1232        let category = self.cata(&move |x,children| {
1233
1234            if let Some(max_category) = children.iter().max() {
1235                // if this expression contains subexpressions, return the maximum category of the
1236                // subexpressions
1237                *max_category
1238            } else {
1239                // this expression has no children
1240                let mut max_category = Category::Bottom;
1241
1242                // calculate the category by looking at all atoms, submodels, comprehensions, and
1243                // declarationptrs inside this expression
1244
1245                // this should generically cover all leaf types we currently have in oxide.
1246
1247                // if x contains submodels (including comprehensions)
1248                if !Biplate::<Model>::universe_bi(&x).is_empty() {
1249                    // assume that the category is decision
1250                    return Category::Decision;
1251                }
1252
1253                // if x contains atoms
1254                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1255                // and those atoms have a higher category than we already know about
1256                && max_atom_category > max_category{
1257                    // update category
1258                    max_category = max_atom_category;
1259                }
1260
1261                // if x contains declarationPtrs
1262                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1263                // and those pointers have a higher category than we already know about
1264                && max_declaration_category > max_category{
1265                    // update category
1266                    max_category = max_declaration_category;
1267                }
1268                max_category
1269
1270            }
1271        });
1272
1273        if cfg!(debug_assertions) {
1274            trace!(
1275                category= %category,
1276                expression= %self,
1277                "Called Expression::category_of()"
1278            );
1279        };
1280        category
1281    }
1282}
1283
1284impl Display for Expression {
1285    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1286        match &self {
1287            Expression::Union(_, box1, box2) => {
1288                write!(f, "({} union {})", box1.clone(), box2.clone())
1289            }
1290            Expression::In(_, e1, e2) => {
1291                write!(f, "{e1} in {e2}")
1292            }
1293            Expression::Intersect(_, box1, box2) => {
1294                write!(f, "({} intersect {})", box1.clone(), box2.clone())
1295            }
1296            Expression::Supset(_, box1, box2) => {
1297                write!(f, "({} supset {})", box1.clone(), box2.clone())
1298            }
1299            Expression::SupsetEq(_, box1, box2) => {
1300                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1301            }
1302            Expression::Subset(_, box1, box2) => {
1303                write!(f, "({} subset {})", box1.clone(), box2.clone())
1304            }
1305            Expression::SubsetEq(_, box1, box2) => {
1306                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1307            }
1308
1309            Expression::AbstractLiteral(_, l) => l.fmt(f),
1310            Expression::Comprehension(_, c) => c.fmt(f),
1311            Expression::AbstractComprehension(_, c) => c.fmt(f),
1312            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1313                write!(f, "{e1}{}", pretty_vec(e2))
1314            }
1315            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1316                let args = es
1317                    .iter()
1318                    .map(|x| match x {
1319                        Some(x) => format!("{x}"),
1320                        None => "..".into(),
1321                    })
1322                    .join(",");
1323
1324                write!(f, "{e1}[{args}]")
1325            }
1326            Expression::InDomain(_, e, domain) => {
1327                write!(f, "__inDomain({e},{domain})")
1328            }
1329            Expression::Root(_, exprs) => {
1330                write!(f, "{}", pretty_expressions_as_top_level(exprs))
1331            }
1332            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1333            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1334            Expression::Metavar(_, name) => write!(f, "&{name}"),
1335            Expression::Atomic(_, atom) => atom.fmt(f),
1336            Expression::Abs(_, a) => write!(f, "|{a}|"),
1337            Expression::Sum(_, e) => {
1338                write!(f, "sum({e})")
1339            }
1340            Expression::Product(_, e) => {
1341                write!(f, "product({e})")
1342            }
1343            Expression::Min(_, e) => {
1344                write!(f, "min({e})")
1345            }
1346            Expression::Max(_, e) => {
1347                write!(f, "max({e})")
1348            }
1349            Expression::Not(_, expr_box) => {
1350                write!(f, "!({})", expr_box.clone())
1351            }
1352            Expression::Or(_, e) => {
1353                write!(f, "or({e})")
1354            }
1355            Expression::And(_, e) => {
1356                write!(f, "and({e})")
1357            }
1358            Expression::Imply(_, box1, box2) => {
1359                write!(f, "({box1}) -> ({box2})")
1360            }
1361            Expression::Iff(_, box1, box2) => {
1362                write!(f, "({box1}) <-> ({box2})")
1363            }
1364            Expression::Eq(_, box1, box2) => {
1365                write!(f, "({} = {})", box1.clone(), box2.clone())
1366            }
1367            Expression::Neq(_, box1, box2) => {
1368                write!(f, "({} != {})", box1.clone(), box2.clone())
1369            }
1370            Expression::Geq(_, box1, box2) => {
1371                write!(f, "({} >= {})", box1.clone(), box2.clone())
1372            }
1373            Expression::Leq(_, box1, box2) => {
1374                write!(f, "({} <= {})", box1.clone(), box2.clone())
1375            }
1376            Expression::Gt(_, box1, box2) => {
1377                write!(f, "({} > {})", box1.clone(), box2.clone())
1378            }
1379            Expression::Lt(_, box1, box2) => {
1380                write!(f, "({} < {})", box1.clone(), box2.clone())
1381            }
1382            Expression::FlatSumGeq(_, box1, box2) => {
1383                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1384            }
1385            Expression::FlatSumLeq(_, box1, box2) => {
1386                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1387            }
1388            Expression::FlatIneq(_, box1, box2, box3) => write!(
1389                f,
1390                "Ineq({}, {}, {})",
1391                box1.clone(),
1392                box2.clone(),
1393                box3.clone()
1394            ),
1395            Expression::Flatten(_, n, m) => {
1396                if let Some(n) = n {
1397                    write!(f, "flatten({n}, {m})")
1398                } else {
1399                    write!(f, "flatten({m})")
1400                }
1401            }
1402            Expression::AllDiff(_, e) => {
1403                write!(f, "allDiff({e})")
1404            }
1405            Expression::Table(_, tuple_expr, rows_expr) => {
1406                write!(f, "table({tuple_expr}, {rows_expr})")
1407            }
1408            Expression::NegativeTable(_, tuple_expr, rows_expr) => {
1409                write!(f, "negativeTable({tuple_expr}, {rows_expr})")
1410            }
1411            Expression::Bubble(_, box1, box2) => {
1412                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1413            }
1414            Expression::SafeDiv(_, box1, box2) => {
1415                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1416            }
1417            Expression::UnsafeDiv(_, box1, box2) => {
1418                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1419            }
1420            Expression::UnsafePow(_, box1, box2) => {
1421                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1422            }
1423            Expression::SafePow(_, box1, box2) => {
1424                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1425            }
1426            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1427                write!(
1428                    f,
1429                    "DivEq({}, {}, {})",
1430                    box1.clone(),
1431                    box2.clone(),
1432                    box3.clone()
1433                )
1434            }
1435            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1436                write!(
1437                    f,
1438                    "ModEq({}, {}, {})",
1439                    box1.clone(),
1440                    box2.clone(),
1441                    box3.clone()
1442                )
1443            }
1444            Expression::FlatWatchedLiteral(_, x, l) => {
1445                write!(f, "WatchedLiteral({x},{l})")
1446            }
1447            Expression::MinionReify(_, box1, box2) => {
1448                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1449            }
1450            Expression::MinionReifyImply(_, box1, box2) => {
1451                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1452            }
1453            Expression::MinionWInIntervalSet(_, atom, intervals) => {
1454                let intervals = intervals.iter().join(",");
1455                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1456            }
1457            Expression::MinionWInSet(_, atom, values) => {
1458                let values = values.iter().join(",");
1459                write!(f, "__minion_w_inset({atom},[{values}])")
1460            }
1461            Expression::AuxDeclaration(_, reference, e) => {
1462                write!(f, "{} =aux {}", reference, e.clone())
1463            }
1464            Expression::UnsafeMod(_, a, b) => {
1465                write!(f, "{} % {}", a.clone(), b.clone())
1466            }
1467            Expression::SafeMod(_, a, b) => {
1468                write!(f, "SafeMod({},{})", a.clone(), b.clone())
1469            }
1470            Expression::Neg(_, a) => {
1471                write!(f, "-({})", a.clone())
1472            }
1473            Expression::Minus(_, a, b) => {
1474                write!(f, "({} - {})", a.clone(), b.clone())
1475            }
1476            Expression::FlatAllDiff(_, es) => {
1477                write!(f, "__flat_alldiff({})", pretty_vec(es))
1478            }
1479            Expression::FlatAbsEq(_, a, b) => {
1480                write!(f, "AbsEq({},{})", a.clone(), b.clone())
1481            }
1482            Expression::FlatMinusEq(_, a, b) => {
1483                write!(f, "MinusEq({},{})", a.clone(), b.clone())
1484            }
1485            Expression::FlatProductEq(_, a, b, c) => {
1486                write!(
1487                    f,
1488                    "FlatProductEq({},{},{})",
1489                    a.clone(),
1490                    b.clone(),
1491                    c.clone()
1492                )
1493            }
1494            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1495                write!(
1496                    f,
1497                    "FlatWeightedSumLeq({},{},{})",
1498                    pretty_vec(cs),
1499                    pretty_vec(vs),
1500                    total.clone()
1501                )
1502            }
1503            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1504                write!(
1505                    f,
1506                    "FlatWeightedSumGeq({},{},{})",
1507                    pretty_vec(cs),
1508                    pretty_vec(vs),
1509                    total.clone()
1510                )
1511            }
1512            Expression::MinionPow(_, atom, atom1, atom2) => {
1513                write!(f, "MinionPow({atom},{atom1},{atom2})")
1514            }
1515            Expression::MinionElementOne(_, atoms, atom, atom1) => {
1516                let atoms = atoms.iter().join(",");
1517                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1518            }
1519
1520            Expression::ToInt(_, expr) => {
1521                write!(f, "toInt({expr})")
1522            }
1523
1524            Expression::SATInt(_, encoding, bits, (min, max)) => {
1525                write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
1526            }
1527
1528            Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1529            Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1530
1531            Expression::Defined(_, function) => write!(f, "defined({function})"),
1532            Expression::Range(_, function) => write!(f, "range({function})"),
1533            Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1534            Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1535            Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1536            Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1537            Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1538
1539            Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1540            Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1541            Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1542            Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1543            Expression::FlatLexLt(_, a, b) => {
1544                write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1545            }
1546            Expression::FlatLexLeq(_, a, b) => {
1547                write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1548            }
1549        }
1550    }
1551}
1552
1553impl Typeable for Expression {
1554    fn return_type(&self) -> ReturnType {
1555        match self {
1556            Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1557            Expression::Intersect(_, subject, _) => {
1558                ReturnType::Set(Box::new(subject.return_type()))
1559            }
1560            Expression::In(_, _, _) => ReturnType::Bool,
1561            Expression::Supset(_, _, _) => ReturnType::Bool,
1562            Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1563            Expression::Subset(_, _, _) => ReturnType::Bool,
1564            Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1565            Expression::AbstractLiteral(_, lit) => lit.return_type(),
1566            Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1567                let subject_ty = subject.return_type();
1568                match subject_ty {
1569                    ReturnType::Matrix(_) => {
1570                        // For n-dimensional matrices, unwrap the element type until
1571                        // we either get to the innermost element type or the last index
1572                        let mut elem_typ = subject_ty;
1573                        let mut idx_len = idx.len();
1574                        while idx_len > 0
1575                            && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1576                        {
1577                            elem_typ = *new_elem_typ.clone();
1578                            idx_len -= 1;
1579                        }
1580                        elem_typ
1581                    }
1582                    // TODO: We can implement indexing for these eventually
1583                    ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1584                    _ => bug!(
1585                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1586                    ),
1587                }
1588            }
1589            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1590                ReturnType::Matrix(Box::new(subject.return_type()))
1591            }
1592            Expression::InDomain(_, _, _) => ReturnType::Bool,
1593            Expression::Comprehension(_, comp) => comp.return_type(),
1594            Expression::AbstractComprehension(_, comp) => comp.return_type(),
1595            Expression::Root(_, _) => ReturnType::Bool,
1596            Expression::DominanceRelation(_, _) => ReturnType::Bool,
1597            Expression::FromSolution(_, expr) => expr.return_type(),
1598            Expression::Metavar(_, _) => ReturnType::Unknown,
1599            Expression::Atomic(_, atom) => atom.return_type(),
1600            Expression::Abs(_, _) => ReturnType::Int,
1601            Expression::Sum(_, _) => ReturnType::Int,
1602            Expression::Product(_, _) => ReturnType::Int,
1603            Expression::Min(_, _) => ReturnType::Int,
1604            Expression::Max(_, _) => ReturnType::Int,
1605            Expression::Not(_, _) => ReturnType::Bool,
1606            Expression::Or(_, _) => ReturnType::Bool,
1607            Expression::Imply(_, _, _) => ReturnType::Bool,
1608            Expression::Iff(_, _, _) => ReturnType::Bool,
1609            Expression::And(_, _) => ReturnType::Bool,
1610            Expression::Eq(_, _, _) => ReturnType::Bool,
1611            Expression::Neq(_, _, _) => ReturnType::Bool,
1612            Expression::Geq(_, _, _) => ReturnType::Bool,
1613            Expression::Leq(_, _, _) => ReturnType::Bool,
1614            Expression::Gt(_, _, _) => ReturnType::Bool,
1615            Expression::Lt(_, _, _) => ReturnType::Bool,
1616            Expression::SafeDiv(_, _, _) => ReturnType::Int,
1617            Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1618            Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1619            Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1620            Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1621            Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1622            Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1623            Expression::Flatten(_, _, matrix) => {
1624                let matrix_type = matrix.return_type();
1625                match matrix_type {
1626                    ReturnType::Matrix(_) => {
1627                        // unwrap until we get to innermost element
1628                        let mut elem_type = matrix_type;
1629                        while let ReturnType::Matrix(new_elem_type) = &elem_type {
1630                            elem_type = *new_elem_type.clone();
1631                        }
1632                        ReturnType::Matrix(Box::new(elem_type))
1633                    }
1634                    _ => bug!(
1635                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1636                    ),
1637                }
1638            }
1639            Expression::AllDiff(_, _) => ReturnType::Bool,
1640            Expression::Table(_, _, _) => ReturnType::Bool,
1641            Expression::NegativeTable(_, _, _) => ReturnType::Bool,
1642            Expression::Bubble(_, inner, _) => inner.return_type(),
1643            Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1644            Expression::MinionReify(_, _, _) => ReturnType::Bool,
1645            Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1646            Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1647            Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1648            Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1649            Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1650            Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1651            Expression::SafeMod(_, _, _) => ReturnType::Int,
1652            Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1653            Expression::Neg(_, _) => ReturnType::Int,
1654            Expression::UnsafePow(_, _, _) => ReturnType::Int,
1655            Expression::SafePow(_, _, _) => ReturnType::Int,
1656            Expression::Minus(_, _, _) => ReturnType::Int,
1657            Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1658            Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1659            Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1660            Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1661            Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1662            Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1663            Expression::ToInt(_, _) => ReturnType::Int,
1664            Expression::SATInt(..) => ReturnType::Int,
1665            Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1666            Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1667            Expression::Defined(_, function) => {
1668                let subject = function.return_type();
1669                match subject {
1670                    ReturnType::Function(domain, _) => *domain,
1671                    _ => bug!(
1672                        "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1673                    ),
1674                }
1675            }
1676            Expression::Range(_, function) => {
1677                let subject = function.return_type();
1678                match subject {
1679                    ReturnType::Function(_, codomain) => *codomain,
1680                    _ => bug!(
1681                        "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1682                    ),
1683                }
1684            }
1685            Expression::Image(_, function, _) => {
1686                let subject = function.return_type();
1687                match subject {
1688                    ReturnType::Function(_, codomain) => *codomain,
1689                    _ => bug!(
1690                        "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1691                    ),
1692                }
1693            }
1694            Expression::ImageSet(_, function, _) => {
1695                let subject = function.return_type();
1696                match subject {
1697                    ReturnType::Function(_, codomain) => *codomain,
1698                    _ => bug!(
1699                        "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1700                    ),
1701                }
1702            }
1703            Expression::PreImage(_, function, _) => {
1704                let subject = function.return_type();
1705                match subject {
1706                    ReturnType::Function(domain, _) => *domain,
1707                    _ => bug!(
1708                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1709                    ),
1710                }
1711            }
1712            Expression::Restrict(_, function, new_domain) => {
1713                let subject = function.return_type();
1714                match subject {
1715                    ReturnType::Function(_, codomain) => {
1716                        ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1717                    }
1718                    _ => bug!(
1719                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1720                    ),
1721                }
1722            }
1723            Expression::Inverse(..) => ReturnType::Bool,
1724            Expression::LexLt(..) => ReturnType::Bool,
1725            Expression::LexGt(..) => ReturnType::Bool,
1726            Expression::LexLeq(..) => ReturnType::Bool,
1727            Expression::LexGeq(..) => ReturnType::Bool,
1728            Expression::FlatLexLt(..) => ReturnType::Bool,
1729            Expression::FlatLexLeq(..) => ReturnType::Bool,
1730        }
1731    }
1732}
1733
#[cfg(test)]
mod tests {
    use crate::matrix_expr;

    use super::*;

    /// Wraps `value` as an atomic literal expression with fresh metadata.
    fn literal(value: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(value))
    }

    /// Builds a `sum` over the given (matrix) operand expression.
    fn sum_over(operands: Expression) -> Expression {
        Expression::Sum(Metadata::new(), Moo::new(operands))
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal(Literal::Int(1));
        let two = literal(Literal::Int(2));
        let total = sum_over(matrix_expr![one, two]);

        let expected = Domain::int(vec![Range::Single(3)]);
        assert_eq!(total.domain_of(), Some(expected));
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_term = literal(Literal::Int(1));
        let bool_term = literal(Literal::Bool(true));
        let mixed = sum_over(matrix_expr![int_term, bool_term]);

        // Summing an int with a bool has no integer domain.
        assert!(mixed.domain_of().is_none());
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let empty = sum_over(matrix_expr![]);

        // An empty sum yields no domain.
        assert!(empty.domain_of().is_none());
    }
}