Skip to main content

conjure_cp_core/ast/
expressions.rs

1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use std::hash::{DefaultHasher, Hash, Hasher};
4use std::sync::atomic::{AtomicU64, Ordering};
5
/// Running count of expression-hash hits, reported by [`print_hash_stats`].
static HASH_HITS: AtomicU64 = AtomicU64::new(0);
/// Running count of expression-hash misses, reported by [`print_hash_stats`].
static HASH_MISSES: AtomicU64 = AtomicU64::new(0);

/// Prints the current expression-hash hit and miss counters to standard output.
///
/// Both counters are read with `Ordering::Relaxed`, one after the other, so if other threads are
/// updating them concurrently the printed pair is not guaranteed to be a consistent snapshot.
pub fn print_hash_stats() {
    let hits = HASH_HITS.load(Ordering::Relaxed);
    let misses = HASH_MISSES.load(Ordering::Relaxed);
    println!("Expression hash stats: hits={}, misses={}", hits, misses);
}
16use tracing::trace;
17
18use conjure_cp_enum_compatibility_macro::{document_compatibility, generate_discriminants};
19use itertools::Itertools;
20use serde::{Deserialize, Serialize};
21use tree_morph::cache::CacheHashable;
22use ustr::Ustr;
23
24use polyquine::Quine;
25use uniplate::{Biplate, Uniplate};
26
27use crate::ast::metadata::NO_HASH;
28use crate::bug;
29
30use super::abstract_comprehension::AbstractComprehension;
31use super::ac_operators::ACOperatorKind;
32use super::categories::{Category, CategoryOf};
33use super::comprehension::Comprehension;
34use super::domains::HasDomain as _;
35use super::pretty::{pretty_expressions_as_top_level, pretty_vec};
36use super::records::RecordValue;
37use super::sat_encoding::SATIntEncoding;
38use super::{
39    AbstractLiteral, Atom, DeclarationPtr, Domain, DomainPtr, GroundDomain, IntVal, Literal,
40    Metadata, Model, Moo, Name, Range, Reference, ReturnType, SetAttr, SymbolTable, SymbolTablePtr,
41    Typeable, UnresolvedDomain, matrix,
42};
43
44// Ensure that this type doesn't get too big
45//
46// If you triggered this assertion, you either made a variant of this enum that is too big, or you
47// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
48// stuff in boxes.
49//
50// Enums take the size of their largest variant, so an enum with mostly small variants and a few
51// large ones wastes memory... A larger Expression type also slows down Oxide.
52//
53// For more information, and more details on type sizes and how to measure them, see the commit
54// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
55//
56// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
57//
58// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e
59
60// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
61// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
62// lot bigger still when we start using it for memoisation, so it should really be
63// boxed ~niklasdewally
64
65// expect size of Expression to be 112 bytes
66static_assertions::assert_eq_size!([u8; 112], Expression);
67
68/// Represents different types of expressions used to define rules and constraints in the model.
69///
70/// The `Expression` enum includes operations, constants, and variable references
71/// used to build rules and conditions for the model.
72#[generate_discriminants]
73#[document_compatibility]
74#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
75#[biplate(to=AbstractComprehension)]
76#[biplate(to=AbstractLiteral<Expression>)]
77#[biplate(to=AbstractLiteral<Literal>)]
78#[biplate(to=Atom)]
79#[biplate(to=Comprehension)]
80#[biplate(to=DeclarationPtr)]
81#[biplate(to=DomainPtr)]
82#[biplate(to=Literal)]
83#[biplate(to=Metadata)]
84#[biplate(to=Name)]
85#[biplate(to=Option<Expression>)]
86#[biplate(to=RecordValue<Expression>)]
87#[biplate(to=RecordValue<Literal>)]
88#[biplate(to=Reference)]
89#[biplate(to=Model)]
90#[biplate(to=SymbolTable)]
91#[biplate(to=SymbolTablePtr)]
92#[biplate(to=Vec<Expression>)]
93#[path_prefix(conjure_cp::ast)]
94pub enum Expression {
95    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
96    /// The top of the model
97    Root(Metadata, Vec<Expression>),
98
99    /// An expression representing "A is valid as long as B is true"
100    /// Turns into a conjunction when it reaches a boolean context
101    Bubble(Metadata, Moo<Expression>, Moo<Expression>),
102
103    /// A comprehension.
104    ///
105    /// The inside of the comprehension opens a new scope.
106    // todo (gskorokhod): Comprehension contains a symbol table which contains a bunch of pointers.
107    // This makes implementing Quine tricky (it doesnt support Rc, by design). Skip it for now.
108    #[polyquine_skip]
109    Comprehension(Metadata, Moo<Comprehension>),
110
111    /// Higher-level abstract comprehension
112    #[polyquine_skip] // no idea what this is lol but it stops rustc screaming at me
113    AbstractComprehension(Metadata, Moo<AbstractComprehension>),
114
115    /// Defines dominance ("Solution A is preferred over Solution B")
116    DominanceRelation(Metadata, Moo<Expression>),
117    /// `fromSolution(name)` - Used in dominance relation definitions
118    FromSolution(Metadata, Moo<Atom>),
119
120    #[polyquine_with(arm = (_, name) => {
121        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
122        quote::quote! { #ident.clone().into() }
123    })]
124    Metavar(Metadata, Ustr),
125
126    Atomic(Metadata, Atom),
127
128    /// A matrix index.
129    ///
130    /// Defined iff the indices are within their respective index domains.
131    #[compatible(JsonInput)]
132    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
133
134    /// A safe matrix index.
135    ///
136    /// See [`Expression::UnsafeIndex`]
137    #[compatible(SMT)]
138    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
139
140    /// A matrix slice: `a[indices]`.
141    ///
142    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
143    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
144    /// `Some(1),None,Some(2)`.
145    ///
146    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
147    ///
148    /// Defined iff the defined indices are within their respective index domains.
149    #[compatible(JsonInput)]
150    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
151
152    /// A safe matrix slice: `a[indices]`.
153    ///
154    /// See [`Expression::UnsafeSlice`].
155    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
156
157    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
158    ///
159    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
160    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
161    /// `SafeSlice` respectively.
162    InDomain(Metadata, Moo<Expression>, DomainPtr),
163
164    /// `toInt(b)` casts boolean expression b to an integer.
165    ///
166    /// - If b is false, then `toInt(b) == 0`
167    ///
168    /// - If b is true, then `toInt(b) == 1`
169    #[compatible(SMT)]
170    ToInt(Metadata, Moo<Expression>),
171
172    /// `|x|` - absolute value of `x`
173    #[compatible(JsonInput, SMT)]
174    Abs(Metadata, Moo<Expression>),
175
176    /// `sum(<vec_expr>)`
177    #[compatible(JsonInput, SMT)]
178    Sum(Metadata, Moo<Expression>),
179
180    /// `a * b * c * ...`
181    #[compatible(JsonInput, SMT)]
182    Product(Metadata, Moo<Expression>),
183
184    /// `min(<vec_expr>)`
185    #[compatible(JsonInput, SMT)]
186    Min(Metadata, Moo<Expression>),
187
188    /// `max(<vec_expr>)`
189    #[compatible(JsonInput, SMT)]
190    Max(Metadata, Moo<Expression>),
191
192    /// `not(a)`
193    #[compatible(JsonInput, SAT, SMT)]
194    Not(Metadata, Moo<Expression>),
195
196    /// `or(<vec_expr>)`
197    #[compatible(JsonInput, SAT, SMT)]
198    Or(Metadata, Moo<Expression>),
199
200    /// `and(<vec_expr>)`
201    #[compatible(JsonInput, SAT, SMT)]
202    And(Metadata, Moo<Expression>),
203
204    /// Ensures that `a->b` (material implication).
205    #[compatible(JsonInput, SMT)]
206    Imply(Metadata, Moo<Expression>, Moo<Expression>),
207
208    /// `iff(a, b)` a <-> b
209    #[compatible(JsonInput, SMT)]
210    Iff(Metadata, Moo<Expression>, Moo<Expression>),
211
212    #[compatible(JsonInput)]
213    Union(Metadata, Moo<Expression>, Moo<Expression>),
214
215    #[compatible(JsonInput)]
216    In(Metadata, Moo<Expression>, Moo<Expression>),
217
218    #[compatible(JsonInput)]
219    Intersect(Metadata, Moo<Expression>, Moo<Expression>),
220
221    #[compatible(JsonInput)]
222    Supset(Metadata, Moo<Expression>, Moo<Expression>),
223
224    #[compatible(JsonInput)]
225    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),
226
227    #[compatible(JsonInput)]
228    Subset(Metadata, Moo<Expression>, Moo<Expression>),
229
230    #[compatible(JsonInput)]
231    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),
232
233    #[compatible(JsonInput, SMT)]
234    Eq(Metadata, Moo<Expression>, Moo<Expression>),
235
236    #[compatible(JsonInput, SMT)]
237    Neq(Metadata, Moo<Expression>, Moo<Expression>),
238
239    #[compatible(JsonInput, SMT)]
240    Geq(Metadata, Moo<Expression>, Moo<Expression>),
241
242    #[compatible(JsonInput, SMT)]
243    Leq(Metadata, Moo<Expression>, Moo<Expression>),
244
245    #[compatible(JsonInput, SMT)]
246    Gt(Metadata, Moo<Expression>, Moo<Expression>),
247
248    #[compatible(JsonInput, SMT)]
249    Lt(Metadata, Moo<Expression>, Moo<Expression>),
250
251    /// Division after preventing division by zero, usually with a bubble
252    #[compatible(SMT)]
253    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
254
255    /// Division with a possibly undefined value (division by 0)
256    #[compatible(JsonInput)]
257    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
258
259    /// Modulo after preventing mod 0, usually with a bubble
260    #[compatible(SMT)]
261    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),
262
263    /// Modulo with a possibly undefined value (mod 0)
264    #[compatible(JsonInput)]
265    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),
266
267    /// Negation: `-x`
268    #[compatible(JsonInput, SMT)]
269    Neg(Metadata, Moo<Expression>),
270
271    /// Factorial: `x!` or 'factorial(x)`
272    #[compatible(JsonInput)]
273    Factorial(Metadata, Moo<Expression>),
274
275    /// Set of domain values function is defined for
276    #[compatible(JsonInput)]
277    Defined(Metadata, Moo<Expression>),
278
279    /// Set of codomain values function is defined for
280    #[compatible(JsonInput)]
281    Range(Metadata, Moo<Expression>),
282
283    /// Unsafe power`x**y` (possibly undefined)
284    ///
285    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
286    #[compatible(JsonInput)]
287    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),
288
289    /// `UnsafePow` after preventing undefinedness
290    SafePow(Metadata, Moo<Expression>, Moo<Expression>),
291
292    /// Flatten matrix operator
293    /// `flatten(M)` or `flatten(n, M)`
294    /// where M is a matrix and n is an optional integer argument indicating depth of flattening
295    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),
296
297    /// `allDiff(<vec_expr>)`
298    #[compatible(JsonInput)]
299    AllDiff(Metadata, Moo<Expression>),
300
301    /// `table([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
302    ///
303    /// Represents a positive table constraint: the tuple `[x1, x2, ...]` must match one of the
304    /// allowed rows.
305    #[compatible(JsonInput)]
306    Table(Metadata, Moo<Expression>, Moo<Expression>),
307
308    /// `negativeTable([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
309    ///
310    /// Represents a negative table constraint: the tuple `[x1, x2, ...]` must NOT match any of the
311    /// forbidden rows.
312    #[compatible(JsonInput)]
313    NegativeTable(Metadata, Moo<Expression>, Moo<Expression>),
314    /// Binary subtraction operator
315    ///
316    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
317    /// TODO: make this compatible with Set Difference calculations - need to change return type and domain for this expression and write a set comprehension rule.
318    /// have already edited minus_to_sum to prevent this from applying to sets
319    #[compatible(JsonInput)]
320    Minus(Metadata, Moo<Expression>, Moo<Expression>),
321
322    /// Ensures that x=|y| i.e. x is the absolute value of y.
323    ///
324    /// Low-level Minion constraint.
325    ///
326    /// # See also
327    ///
328    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
329    #[compatible(Minion)]
330    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),
331
332    /// Ensures that `alldiff([a,b,...])`.
333    ///
334    /// Low-level Minion constraint.
335    ///
336    /// # See also
337    ///
338    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
339    #[compatible(Minion)]
340    FlatAllDiff(Metadata, Vec<Atom>),
341
342    /// Ensures that sum(vec) >= x.
343    ///
344    /// Low-level Minion constraint.
345    ///
346    /// # See also
347    ///
348    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
349    #[compatible(Minion)]
350    FlatSumGeq(Metadata, Vec<Atom>, Atom),
351
352    /// Ensures that sum(vec) <= x.
353    ///
354    /// Low-level Minion constraint.
355    ///
356    /// # See also
357    ///
358    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
359    #[compatible(Minion)]
360    FlatSumLeq(Metadata, Vec<Atom>, Atom),
361
362    /// `ineq(x,y,k)` ensures that x <= y + k.
363    ///
364    /// Low-level Minion constraint.
365    ///
366    /// # See also
367    ///
368    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
369    #[compatible(Minion)]
370    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),
371
372    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
373    ///
374    /// Low-level Minion constraint.
375    ///
376    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
377    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
378    /// watched-and and watched-or.
379    ///
380    /// # See also
381    ///
382    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
383    /// + `rules::minion::boolean_literal_to_wliteral`.
384    #[compatible(Minion)]
385    #[polyquine_skip]
386    FlatWatchedLiteral(Metadata, Reference, Literal),
387
388    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
389    /// product of cs and xs.
390    ///
391    /// Low-level Minion constraint.
392    ///
393    /// Represents a weighted sum of the form `ax + by + cz + ...`
394    ///
395    /// # See also
396    ///
397    /// + [Minion
398    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
399    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
400
401    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
402    /// product of cs and xs.
403    ///
404    /// Low-level Minion constraint.
405    ///
406    /// Represents a weighted sum of the form `ax + by + cz + ...`
407    ///
408    /// # See also
409    ///
410    /// + [Minion
411    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
412    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
413
414    /// Ensures that x =-y, where x and y are atoms.
415    ///
416    /// Low-level Minion constraint.
417    ///
418    /// # See also
419    ///
420    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
421    #[compatible(Minion)]
422    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),
423
424    /// Ensures that x*y=z.
425    ///
426    /// Low-level Minion constraint.
427    ///
428    /// # See also
429    ///
430    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
431    #[compatible(Minion)]
432    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
433
434    /// Ensures that floor(x/y)=z. Always true when y=0.
435    ///
436    /// Low-level Minion constraint.
437    ///
438    /// # See also
439    ///
440    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
441    #[compatible(Minion)]
442    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
443
444    /// Ensures that x%y=z. Always true when y=0.
445    ///
446    /// Low-level Minion constraint.
447    ///
448    /// # See also
449    ///
450    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
451    #[compatible(Minion)]
452    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
453
454    /// Ensures that `x**y = z`.
455    ///
456    /// Low-level Minion constraint.
457    ///
458    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
459    /// is odd and z is -1 if y is even).
460    ///
461    /// # See also
462    ///
463    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
464    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
465    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
466
467    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
468    /// variable.
469    ///
470    /// Low-level Minion constraint.
471    ///
472    /// # See also
473    ///
474    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
475    #[compatible(Minion)]
476    MinionReify(Metadata, Moo<Expression>, Atom),
477
478    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
479    /// variable.
480    ///
481    /// Low-level Minion constraint.
482    ///
483    /// # See also
484    ///
485    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
486    #[compatible(Minion)]
487    MinionReifyImply(Metadata, Moo<Expression>, Atom),
488
489    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
490    /// intervals {a1,…,a2}, {b1,…,b2} etc.
491    ///
492    /// The list of intervals must be given in numerical order.
493    ///
494    /// Low-level Minion constraint.
495    ///
496    /// # See also
497    ///>
498    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
499    #[compatible(Minion)]
500    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
501
502    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
503    ///
504    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
505    ///
506    /// The list of values must be given in numerical order.
507    ///
508    /// Low-level Minion constraint.
509    ///
510    /// # See also
511    ///
512    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
513    #[compatible(Minion)]
514    MinionWInSet(Metadata, Atom, Vec<i32>),
515
516    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
517    /// in the range `[1..len(vec)]`.
518    ///
519    /// Low-level Minion constraint.
520    ///
521    /// # See also
522    ///
523    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
524    #[compatible(Minion)]
525    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),
526
527    /// Declaration of an auxiliary variable.
528    ///
529    /// As with Savile Row, we semantically distinguish this from `Eq`.
530    #[compatible(Minion)]
531    #[polyquine_skip]
532    AuxDeclaration(Metadata, Reference, Moo<Expression>),
533
534    /// This expression is for encoding ints for the SAT solver, it stores the encoding type, the vector of booleans and the min/max for the int.
535    #[compatible(SAT)]
536    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),
537
538    /// Addition over a pair of expressions (i.e. a + b) rather than a vec-expr like Expression::Sum.
539    /// This is for compatibility with backends that do not support addition over vectors.
540    #[compatible(SMT)]
541    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),
542
543    /// Multiplication over a pair of expressions (i.e. a * b) rather than a vec-expr like Expression::Product.
544    /// This is for compatibility with backends that do not support multiplication over vectors.
545    #[compatible(SMT)]
546    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),
547
548    #[compatible(JsonInput)]
549    Image(Metadata, Moo<Expression>, Moo<Expression>),
550
551    #[compatible(JsonInput)]
552    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),
553
554    #[compatible(JsonInput)]
555    PreImage(Metadata, Moo<Expression>, Moo<Expression>),
556
557    #[compatible(JsonInput)]
558    Inverse(Metadata, Moo<Expression>, Moo<Expression>),
559
560    #[compatible(JsonInput)]
561    Restrict(Metadata, Moo<Expression>, Moo<Expression>),
562
563    /// Lexicographical < between two matrices.
564    ///
565    /// A <lex B iff: A[i] < B[i] for some i /\ (A[j] > B[j] for some j -> i < j)
566    /// I.e. A must be less than B at some index i, and if it is greater than B at another index j,
567    /// then j comes after i.
568    /// I.e. A must be greater than B at the first index where they differ.
569    ///
570    /// E.g. [1, 1] <lex [2, 1] and [1, 1] <lex [1, 2]
571    LexLt(Metadata, Moo<Expression>, Moo<Expression>),
572
573    /// Lexicographical <= between two matrices
574    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),
575
576    /// Lexicographical > between two matrices
577    /// This is a parser-level construct, and is immediately normalised to LexLt(b, a)
578    LexGt(Metadata, Moo<Expression>, Moo<Expression>),
579
580    /// Lexicographical >= between two matrices
581    /// This is a parser-level construct, and is immediately normalised to LexLeq(b, a)
582    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),
583
584    /// Low-level minion constraint. See Expression::LexLt
585    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),
586
587    /// Low-level minion constraint. See Expression::LexLeq
588    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
589}
590
591// for the given matrix literal, return a bounded domain from the min to max of applying op to each
592// child expression.
593//
594// Op must be monotonic.
595//
596// Returns none if unbounded
597fn bounded_i32_domain_for_matrix_literal_monotonic(
598    e: &Expression,
599    op: fn(i32, i32) -> Option<i32>,
600) -> Option<DomainPtr> {
601    // only care about the elements, not the indices
602    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
603
604    // fold each element's domain into one using op.
605    //
606    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
607    // the ranges [a1,a2], [b1,b2] will be
608    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
609    //
610    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
611    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
612    // memory (on the hakank_eprime_xkcd test)...
613    //Int
614    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
615    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
616    //
617    // +,-,/,* are all monotone, so this assumption should be fine for now...
618
619    let expr = exprs.pop()?;
620    let dom = expr.domain_of()?;
621    let resolved = dom.resolve()?;
622    let GroundDomain::Int(ranges) = resolved.as_ref() else {
623        return None;
624    };
625
626    let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
627
628    for expr in exprs {
629        let dom = expr.domain_of()?;
630        let resolved = dom.resolve()?;
631        let GroundDomain::Int(ranges) = resolved.as_ref() else {
632            return None;
633        };
634
635        let (min, max) = range_vec_bounds_i32(ranges)?;
636
637        // all the possible new values for current_min / current_max
638        let minmax = op(min, current_max)?;
639        let minmin = op(min, current_min)?;
640        let maxmin = op(max, current_min)?;
641        let maxmax = op(max, current_max)?;
642        let vals = [minmax, minmin, maxmin, maxmax];
643
644        current_min = *vals
645            .iter()
646            .min()
647            .expect("vals iterator should not be empty, and should have a minimum.");
648        current_max = *vals
649            .iter()
650            .max()
651            .expect("vals iterator should not be empty, and should have a maximum.");
652    }
653
654    if current_min == current_max {
655        Some(Domain::int(vec![Range::Single(current_min)]))
656    } else {
657        Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
658    }
659}
660
661fn matrix_element_domain(e: &Expression) -> Option<DomainPtr> {
662    let (elem_domain, _) = e.domain_of()?.as_matrix()?;
663    elem_domain.as_ref().as_int()?;
664    Some(elem_domain)
665}
666
667// Returns none if unbounded
668fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
669    let mut min = i32::MAX;
670    let mut max = i32::MIN;
671    for r in ranges {
672        match r {
673            Range::Single(i) => {
674                if *i < min {
675                    min = *i;
676                }
677                if *i > max {
678                    max = *i;
679                }
680            }
681            Range::Bounded(i, j) => {
682                if *i < min {
683                    min = *i;
684                }
685                if *j > max {
686                    max = *j;
687                }
688            }
689            Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
690        }
691    }
692    Some((min, max))
693}
694
695impl Expression {
696    /// Returns the possible values of the expression, recursing to leaf expressions
697    pub fn domain_of(&self) -> Option<DomainPtr> {
698        match self {
699            Expression::Union(_, a, b) => Some(Domain::set(
700                SetAttr::<IntVal>::default(),
701                a.domain_of()?.union(&b.domain_of()?).ok()?,
702            )),
703            Expression::Intersect(_, a, b) => Some(Domain::set(
704                SetAttr::<IntVal>::default(),
705                a.domain_of()?.intersect(&b.domain_of()?).ok()?,
706            )),
707            Expression::In(_, _, _) => Some(Domain::bool()),
708            Expression::Supset(_, _, _) => Some(Domain::bool()),
709            Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
710            Expression::Subset(_, _, _) => Some(Domain::bool()),
711            Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
712            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
713            Expression::DominanceRelation(_, _) => Some(Domain::bool()),
714            Expression::FromSolution(_, expr) => Some(expr.domain_of()),
715            Expression::Metavar(_, _) => None,
716            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
717            Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
718            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
719                let dom = matrix.domain_of()?;
720                if let Some((elem_domain, _)) = dom.as_matrix() {
721                    return Some(elem_domain);
722                }
723
724                // may actually use the value in the future
725                #[allow(clippy::redundant_pattern_matching)]
726                if let Some(_) = dom.as_tuple() {
727                    // TODO: We can implement proper indexing for tuples
728                    return None;
729                }
730
731                // may actually use the value in the future
732                #[allow(clippy::redundant_pattern_matching)]
733                if let Some(_) = dom.as_record() {
734                    // TODO: We can implement proper indexing for records
735                    return None;
736                }
737
738                bug!("subject of an index operation should support indexing")
739            }
740            Expression::UnsafeSlice(_, matrix, indices)
741            | Expression::SafeSlice(_, matrix, indices) => {
742                let sliced_dimension = indices.iter().position(Option::is_none);
743
744                let dom = matrix.domain_of()?;
745                let Some((elem_domain, index_domains)) = dom.as_matrix() else {
746                    bug!("subject of an index operation should be a matrix");
747                };
748
749                match sliced_dimension {
750                    Some(dimension) => Some(Domain::matrix(
751                        elem_domain,
752                        vec![index_domains[dimension].clone()],
753                    )),
754
755                    // same as index
756                    None => Some(elem_domain),
757                }
758            }
759            Expression::InDomain(_, _, _) => Some(Domain::bool()),
760            Expression::Atomic(_, atom) => Some(atom.domain_of()),
761            Expression::Sum(_, e) => {
762                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
763            }
764            Expression::Product(_, e) => {
765                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
766            }
767            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
768                Some(if x < y { x } else { y })
769            })
770            .or_else(|| matrix_element_domain(e)),
771            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
772                Some(if x > y { x } else { y })
773            })
774            .or_else(|| matrix_element_domain(e)),
775            Expression::UnsafeDiv(_, a, b) => a
776                .domain_of()?
777                .resolve()?
778                .apply_i32(
779                    // rust integer division is truncating; however, we want to always round down,
780                    // including for negative numbers.
781                    |x, y| {
782                        if y != 0 {
783                            Some((x as f32 / y as f32).floor() as i32)
784                        } else {
785                            None
786                        }
787                    },
788                    b.domain_of()?.resolve()?.as_ref(),
789                )
790                .map(DomainPtr::from)
791                .ok(),
792            Expression::SafeDiv(_, a, b) => {
793                // rust integer division is truncating; however, we want to always round down
794                // including for negative numbers.
795                let domain = a
796                    .domain_of()?
797                    .resolve()?
798                    .apply_i32(
799                        |x, y| {
800                            if y != 0 {
801                                Some((x as f32 / y as f32).floor() as i32)
802                            } else {
803                                None
804                            }
805                        },
806                        b.domain_of()?.resolve()?.as_ref(),
807                    )
808                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
809
810                if let GroundDomain::Int(ranges) = domain {
811                    let mut ranges = ranges;
812                    ranges.push(Range::Single(0));
813                    Some(Domain::int(ranges))
814                } else {
815                    bug!("Domain of {self} was not integer")
816                }
817            }
818            Expression::UnsafeMod(_, a, b) => a
819                .domain_of()?
820                .resolve()?
821                .apply_i32(
822                    |x, y| if y != 0 { Some(x % y) } else { None },
823                    b.domain_of()?.resolve()?.as_ref(),
824                )
825                .map(DomainPtr::from)
826                .ok(),
827            Expression::SafeMod(_, a, b) => {
828                let domain = a
829                    .domain_of()?
830                    .resolve()?
831                    .apply_i32(
832                        |x, y| if y != 0 { Some(x % y) } else { None },
833                        b.domain_of()?.resolve()?.as_ref(),
834                    )
835                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
836
837                if let GroundDomain::Int(ranges) = domain {
838                    let mut ranges = ranges;
839                    ranges.push(Range::Single(0));
840                    Some(Domain::int(ranges))
841                } else {
842                    bug!("Domain of {self} was not integer")
843                }
844            }
845            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
846                .domain_of()?
847                .resolve()?
848                .apply_i32(
849                    |x, y| {
850                        if (x != 0 || y != 0) && y >= 0 {
851                            Some(x.pow(y as u32))
852                        } else {
853                            None
854                        }
855                    },
856                    b.domain_of()?.resolve()?.as_ref(),
857                )
858                .map(DomainPtr::from)
859                .ok(),
860            Expression::Root(_, _) => None,
861            Expression::Bubble(_, inner, _) => inner.domain_of(),
862            Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
863            Expression::And(_, _) => Some(Domain::bool()),
864            Expression::Not(_, _) => Some(Domain::bool()),
865            Expression::Or(_, _) => Some(Domain::bool()),
866            Expression::Imply(_, _, _) => Some(Domain::bool()),
867            Expression::Iff(_, _, _) => Some(Domain::bool()),
868            Expression::Eq(_, _, _) => Some(Domain::bool()),
869            Expression::Neq(_, _, _) => Some(Domain::bool()),
870            Expression::Geq(_, _, _) => Some(Domain::bool()),
871            Expression::Leq(_, _, _) => Some(Domain::bool()),
872            Expression::Gt(_, _, _) => Some(Domain::bool()),
873            Expression::Lt(_, _, _) => Some(Domain::bool()),
874            Expression::Factorial(_, _) => None, // not implemented
875            Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
876            Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
877            Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
878            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
879            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
880            Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
881            Expression::Flatten(_, n, m) => {
882                if let Some(expr) = n {
883                    if expr.return_type() == ReturnType::Int {
884                        // TODO: handle flatten with depth argument
885                        return None;
886                    }
887                } else {
888                    // TODO: currently only works for matrices
889                    let dom = m.domain_of()?.resolve()?;
890                    let (val_dom, idx_doms) = match dom.as_ref() {
891                        GroundDomain::Matrix(val, idx) => (val, idx),
892                        _ => return None,
893                    };
894                    let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
895
896                    let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
897                    return Some(Domain::matrix(
898                        val_dom.clone().into(),
899                        vec![new_index_domain],
900                    ));
901                }
902                None
903            }
904            Expression::AllDiff(_, _) => Some(Domain::bool()),
905            Expression::Table(_, _, _) => Some(Domain::bool()),
906            Expression::NegativeTable(_, _, _) => Some(Domain::bool()),
907            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
908            Expression::MinionReify(_, _, _) => Some(Domain::bool()),
909            Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
910            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
911            Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
912            Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
913            Expression::Neg(_, x) => {
914                let dom = x.domain_of()?;
915                let mut ranges = dom.as_int()?;
916
917                ranges = ranges
918                    .into_iter()
919                    .map(|r| match r {
920                        Range::Single(x) => Range::Single(-x),
921                        Range::Bounded(x, y) => Range::Bounded(-y, -x),
922                        Range::UnboundedR(i) => Range::UnboundedL(-i),
923                        Range::UnboundedL(i) => Range::UnboundedR(-i),
924                        Range::Unbounded => Range::Unbounded,
925                    })
926                    .collect();
927
928                Some(Domain::int(ranges))
929            }
930            Expression::Minus(_, a, b) => a
931                .domain_of()?
932                .resolve()?
933                .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
934                .map(DomainPtr::from)
935                .ok(),
936            Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
937            Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
938            Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
939            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
940            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
941            Expression::Abs(_, a) => a
942                .domain_of()?
943                .resolve()?
944                .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
945                .map(DomainPtr::from)
946                .ok(),
947            Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
948            Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
949            Expression::SATInt(_, _, _, (low, high)) => {
950                Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
951            }
952            Expression::PairwiseSum(_, a, b) => a
953                .domain_of()?
954                .resolve()?
955                .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
956                .map(DomainPtr::from)
957                .ok(),
958            Expression::PairwiseProduct(_, a, b) => a
959                .domain_of()?
960                .resolve()?
961                .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
962                .map(DomainPtr::from)
963                .ok(),
964            Expression::Defined(_, function) => get_function_domain(function),
965            Expression::Range(_, function) => get_function_codomain(function),
966            Expression::Image(_, function, _) => get_function_codomain(function),
967            Expression::ImageSet(_, function, _) => get_function_codomain(function),
968            Expression::PreImage(_, function, _) => get_function_domain(function),
969            Expression::Restrict(_, function, new_domain) => {
970                let (attrs, _, codom) = function.domain_of()?.as_function()?;
971                let new_dom = new_domain.domain_of()?;
972                Some(Domain::function(attrs, new_dom, codom))
973            }
974            Expression::Inverse(..) => Some(Domain::bool()),
975            Expression::LexLt(..) => Some(Domain::bool()),
976            Expression::LexLeq(..) => Some(Domain::bool()),
977            Expression::LexGt(..) => Some(Domain::bool()),
978            Expression::LexGeq(..) => Some(Domain::bool()),
979            Expression::FlatLexLt(..) => Some(Domain::bool()),
980            Expression::FlatLexLeq(..) => Some(Domain::bool()),
981        }
982    }
983
    /// Returns a reference to this expression's metadata without cloning.
    ///
    /// Every variant of [`Expression`] stores its [`Metadata`] as its first field, so this is a
    /// single `match` that picks that field out of whichever variant `self` is.
    ///
    /// The generated `match` has no wildcard arm, so it must name every variant: adding a new
    /// variant to [`Expression`] without listing it below is a compile error (non-exhaustive
    /// match), which is intentional.
    pub fn meta_ref(&self) -> &Metadata {
        // Expands to `match self { Expression::V(meta, ..) => meta, ... }`, one arm per listed
        // variant `V`.
        macro_rules! match_meta_ref {
            ($($variant:ident),* $(,)?) => {
                match self {
                    $(Expression::$variant(meta, ..) => meta,)*
                }
            };
        }
        match_meta_ref!(
            AbstractLiteral,
            Root,
            Bubble,
            Comprehension,
            AbstractComprehension,
            DominanceRelation,
            FromSolution,
            Metavar,
            Atomic,
            UnsafeIndex,
            SafeIndex,
            UnsafeSlice,
            SafeSlice,
            InDomain,
            ToInt,
            Abs,
            Sum,
            Product,
            Min,
            Max,
            Not,
            Or,
            And,
            Imply,
            Iff,
            Union,
            In,
            Intersect,
            Supset,
            SupsetEq,
            Subset,
            SubsetEq,
            Eq,
            Neq,
            Geq,
            Leq,
            Gt,
            Lt,
            SafeDiv,
            UnsafeDiv,
            SafeMod,
            UnsafeMod,
            Neg,
            Defined,
            Range,
            UnsafePow,
            SafePow,
            Flatten,
            AllDiff,
            Minus,
            Factorial,
            FlatAbsEq,
            FlatAllDiff,
            FlatSumGeq,
            FlatSumLeq,
            FlatIneq,
            FlatWatchedLiteral,
            FlatWeightedSumLeq,
            FlatWeightedSumGeq,
            FlatMinusEq,
            FlatProductEq,
            MinionDivEqUndefZero,
            MinionModuloEqUndefZero,
            MinionPow,
            MinionReify,
            MinionReifyImply,
            MinionWInIntervalSet,
            MinionWInSet,
            MinionElementOne,
            AuxDeclaration,
            SATInt,
            PairwiseSum,
            PairwiseProduct,
            Image,
            ImageSet,
            PreImage,
            Inverse,
            Restrict,
            LexLt,
            LexLeq,
            LexGt,
            LexGeq,
            FlatLexLt,
            FlatLexLeq,
            NegativeTable,
            Table
        )
    }
1082
1083    pub fn get_meta(&self) -> Metadata {
1084        let metas: VecDeque<Metadata> = self.children_bi();
1085        metas[0].clone()
1086    }
1087
    /// Intended to replace the metadata of this expression with `meta`.
    ///
    /// NOTE(review): as written this appears to have no effect. `transform_bi` (per the uniplate
    /// `Biplate::transform_bi` contract) builds and *returns* a new expression in which every
    /// `Metadata` in the tree, not just the top-level one, is replaced by `meta`. That return
    /// value is discarded here, and `self` is only borrowed immutably. A real fix needs either
    /// `&mut self` (assigning the result back, ideally via `meta_ref`-style access to the
    /// top-level field only) or returning the new expression. Both change the signature, so
    /// callers must be audited first.
    pub fn set_meta(&self, meta: Metadata) {
        self.transform_bi(&|_| meta.clone());
    }
1091
1092    /// Checks whether this expression is safe.
1093    ///
1094    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
1095    ///
1096    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
1097    /// safe through the use of bubble rules.
1098    pub fn is_safe(&self) -> bool {
1099        // TODO: memoise in Metadata
1100        for expr in self.universe() {
1101            match expr {
1102                Expression::UnsafeDiv(_, _, _)
1103                | Expression::UnsafeMod(_, _, _)
1104                | Expression::UnsafePow(_, _, _)
1105                | Expression::UnsafeIndex(_, _, _)
1106                | Expression::Bubble(_, _, _)
1107                | Expression::UnsafeSlice(_, _, _) => {
1108                    return false;
1109                }
1110                _ => {}
1111            }
1112        }
1113        true
1114    }
1115
1116    /// True if the expression is an associative and commutative operator
1117    pub fn is_associative_commutative_operator(&self) -> bool {
1118        TryInto::<ACOperatorKind>::try_into(self).is_ok()
1119    }
1120
1121    /// True if the expression is a matrix literal.
1122    ///
1123    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
1124    /// [`Expression`].
1125    pub fn is_matrix_literal(&self) -> bool {
1126        matches!(
1127            self,
1128            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1129                | Expression::Atomic(
1130                    _,
1131                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1132                )
1133        )
1134    }
1135
1136    /// True iff self and other are both atomic and identical.
1137    ///
1138    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
1139    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
1140    /// entire subtrees of `self` and `other`.
1141    pub fn identical_atom_to(&self, other: &Expression) -> bool {
1142        let atom1: Result<&Atom, _> = self.try_into();
1143        let atom2: Result<&Atom, _> = other.try_into();
1144
1145        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1146            atom2 == atom1
1147        } else {
1148            false
1149        }
1150    }
1151
1152    /// If the expression is a list, returns a *copied* vector of the inner expressions.
1153    ///
1154    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
1155    /// any explicitly specified domain.
1156    pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
1157        match self {
1158            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1159                matrix.unwrap_list().cloned()
1160            }
1161            Expression::Atomic(
1162                _,
1163                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1164            ) => matrix.unwrap_list().map(|elems| {
1165                elems
1166                    .clone()
1167                    .into_iter()
1168                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1169                    .collect_vec()
1170            }),
1171            _ => None,
1172        }
1173    }
1174
1175    /// If the expression is a matrix, gets it elements and index domain.
1176    ///
1177    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
1178    ///
1179    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
1180    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
1181    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
1182    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1183        match self {
1184            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1185                Some((elems, domain))
1186            }
1187            Expression::Atomic(
1188                _,
1189                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1190            ) => Some((
1191                elems
1192                    .into_iter()
1193                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1194                    .collect_vec(),
1195                domain.into(),
1196            )),
1197
1198            _ => None,
1199        }
1200    }
1201
1202    /// For a Root expression, extends the inner vec with the given vec.
1203    ///
1204    /// # Panics
1205    /// Panics if the expression is not Root.
1206    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1207        match self {
1208            Expression::Root(meta, mut children) => {
1209                children.extend(exprs);
1210                Expression::Root(meta, children)
1211            }
1212            _ => panic!("extend_root called on a non-Root expression"),
1213        }
1214    }
1215
1216    /// Converts the expression to a literal, if possible.
1217    pub fn into_literal(self) -> Option<Literal> {
1218        match self {
1219            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1220            Expression::AbstractLiteral(_, abslit) => {
1221                Some(Literal::AbstractLiteral(abslit.into_literals()?))
1222            }
1223            Expression::Neg(_, e) => {
1224                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1225                    bug!("negated literal should be an int");
1226                };
1227
1228                Some(Literal::Int(-i))
1229            }
1230
1231            _ => None,
1232        }
1233    }
1234
1235    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
1236    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1237        TryFrom::try_from(self).ok()
1238    }
1239
1240    /// Returns the categories of all sub-expressions of self.
1241    pub fn universe_categories(&self) -> HashSet<Category> {
1242        self.universe()
1243            .into_iter()
1244            .map(|x| x.category_of())
1245            .collect()
1246    }
1247}
1248
1249pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1250    let function_domain = function.domain_of()?;
1251    match function_domain.resolve().as_ref() {
1252        Some(d) => {
1253            match d.as_ref() {
1254                GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1255                // Not defined for anything other than a function
1256                _ => None,
1257            }
1258        }
1259        None => {
1260            match function_domain.as_unresolved()? {
1261                UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1262                // Not defined for anything other than a function
1263                _ => None,
1264            }
1265        }
1266    }
1267}
1268
1269pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1270    let function_domain = function.domain_of()?;
1271    match function_domain.resolve().as_ref() {
1272        Some(d) => {
1273            match d.as_ref() {
1274                GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1275                // Not defined for anything other than a function
1276                _ => None,
1277            }
1278        }
1279        None => {
1280            match function_domain.as_unresolved()? {
1281                UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1282                // Not defined for anything other than a function
1283                _ => None,
1284            }
1285        }
1286    }
1287}
1288
1289impl TryFrom<&Expression> for i32 {
1290    type Error = ();
1291
1292    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1293        let Expression::Atomic(_, atom) = value else {
1294            return Err(());
1295        };
1296
1297        let Atom::Literal(lit) = atom else {
1298            return Err(());
1299        };
1300
1301        let Literal::Int(i) = lit else {
1302            return Err(());
1303        };
1304
1305        Ok(*i)
1306    }
1307}
1308
1309impl TryFrom<Expression> for i32 {
1310    type Error = ();
1311
1312    fn try_from(value: Expression) -> Result<Self, Self::Error> {
1313        TryFrom::<&Expression>::try_from(&value)
1314    }
1315}
1316impl From<i32> for Expression {
1317    fn from(i: i32) -> Self {
1318        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1319    }
1320}
1321
1322impl From<bool> for Expression {
1323    fn from(b: bool) -> Self {
1324        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1325    }
1326}
1327
1328impl From<Atom> for Expression {
1329    fn from(value: Atom) -> Self {
1330        Expression::Atomic(Metadata::new(), value)
1331    }
1332}
1333
1334impl From<Literal> for Expression {
1335    fn from(value: Literal) -> Self {
1336        Expression::Atomic(Metadata::new(), value.into())
1337    }
1338}
1339
1340impl From<Moo<Expression>> for Expression {
1341    fn from(val: Moo<Expression>) -> Self {
1342        val.as_ref().clone()
1343    }
1344}
1345
1346impl CategoryOf for Expression {
1347    fn category_of(&self) -> Category {
1348        // take highest category of all the expressions children
1349        let category = self.cata(&move |x,children| {
1350
1351            if let Some(max_category) = children.iter().max() {
1352                // if this expression contains subexpressions, return the maximum category of the
1353                // subexpressions
1354                *max_category
1355            } else {
1356                // this expression has no children
1357                let mut max_category = Category::Bottom;
1358
1359                // calculate the category by looking at all atoms, submodels, comprehensions, and
1360                // declarationptrs inside this expression
1361
1362                // this should generically cover all leaf types we currently have in oxide.
1363
1364                // if x contains submodels (including comprehensions)
1365                if !Biplate::<Model>::universe_bi(&x).is_empty() {
1366                    // assume that the category is decision
1367                    return Category::Decision;
1368                }
1369
1370                // if x contains atoms
1371                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1372                // and those atoms have a higher category than we already know about
1373                && max_atom_category > max_category{
1374                    // update category
1375                    max_category = max_atom_category;
1376                }
1377
1378                // if x contains declarationPtrs
1379                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1380                // and those pointers have a higher category than we already know about
1381                && max_declaration_category > max_category{
1382                    // update category
1383                    max_category = max_declaration_category;
1384                }
1385                max_category
1386
1387            }
1388        });
1389
1390        if cfg!(debug_assertions) {
1391            trace!(
1392                category= %category,
1393                expression= %self,
1394                "Called Expression::category_of()"
1395            );
1396        };
1397        category
1398    }
1399}
1400
1401impl Display for Expression {
1402    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1403        match &self {
1404            Expression::Union(_, box1, box2) => {
1405                write!(f, "({} union {})", box1.clone(), box2.clone())
1406            }
1407            Expression::In(_, e1, e2) => {
1408                write!(f, "{e1} in {e2}")
1409            }
1410            Expression::Intersect(_, box1, box2) => {
1411                write!(f, "({} intersect {})", box1.clone(), box2.clone())
1412            }
1413            Expression::Supset(_, box1, box2) => {
1414                write!(f, "({} supset {})", box1.clone(), box2.clone())
1415            }
1416            Expression::SupsetEq(_, box1, box2) => {
1417                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1418            }
1419            Expression::Subset(_, box1, box2) => {
1420                write!(f, "({} subset {})", box1.clone(), box2.clone())
1421            }
1422            Expression::SubsetEq(_, box1, box2) => {
1423                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1424            }
1425
1426            Expression::AbstractLiteral(_, l) => l.fmt(f),
1427            Expression::Comprehension(_, c) => c.fmt(f),
1428            Expression::AbstractComprehension(_, c) => c.fmt(f),
1429            Expression::UnsafeIndex(_, e1, e2) => write!(f, "{e1}{}", pretty_vec(e2)),
1430            Expression::SafeIndex(_, e1, e2) => write!(f, "SafeIndex({e1},{})", pretty_vec(e2)),
1431            Expression::UnsafeSlice(_, e1, es) => {
1432                let args = es
1433                    .iter()
1434                    .map(|x| match x {
1435                        Some(x) => format!("{x}"),
1436                        None => "..".into(),
1437                    })
1438                    .join(",");
1439
1440                write!(f, "{e1}[{args}]")
1441            }
1442            Expression::SafeSlice(_, e1, es) => {
1443                let args = es
1444                    .iter()
1445                    .map(|x| match x {
1446                        Some(x) => format!("{x}"),
1447                        None => "..".into(),
1448                    })
1449                    .join(",");
1450
1451                write!(f, "SafeSlice({e1},[{args}])")
1452            }
1453            Expression::InDomain(_, e, domain) => {
1454                write!(f, "__inDomain({e},{domain})")
1455            }
1456            Expression::Root(_, exprs) => {
1457                write!(f, "{}", pretty_expressions_as_top_level(exprs))
1458            }
1459            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1460            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1461            Expression::Metavar(_, name) => write!(f, "&{name}"),
1462            Expression::Atomic(_, atom) => atom.fmt(f),
1463            Expression::Abs(_, a) => write!(f, "|{a}|"),
1464            Expression::Sum(_, e) => {
1465                write!(f, "sum({e})")
1466            }
1467            Expression::Product(_, e) => {
1468                write!(f, "product({e})")
1469            }
1470            Expression::Min(_, e) => {
1471                write!(f, "min({e})")
1472            }
1473            Expression::Max(_, e) => {
1474                write!(f, "max({e})")
1475            }
1476            Expression::Not(_, expr_box) => {
1477                write!(f, "!({})", expr_box.clone())
1478            }
1479            Expression::Or(_, e) => {
1480                write!(f, "or({e})")
1481            }
1482            Expression::And(_, e) => {
1483                write!(f, "and({e})")
1484            }
1485            Expression::Imply(_, box1, box2) => {
1486                write!(f, "({box1}) -> ({box2})")
1487            }
1488            Expression::Iff(_, box1, box2) => {
1489                write!(f, "({box1}) <-> ({box2})")
1490            }
1491            Expression::Eq(_, box1, box2) => {
1492                write!(f, "({} = {})", box1.clone(), box2.clone())
1493            }
1494            Expression::Neq(_, box1, box2) => {
1495                write!(f, "({} != {})", box1.clone(), box2.clone())
1496            }
1497            Expression::Geq(_, box1, box2) => {
1498                write!(f, "({} >= {})", box1.clone(), box2.clone())
1499            }
1500            Expression::Leq(_, box1, box2) => {
1501                write!(f, "({} <= {})", box1.clone(), box2.clone())
1502            }
1503            Expression::Gt(_, box1, box2) => {
1504                write!(f, "({} > {})", box1.clone(), box2.clone())
1505            }
1506            Expression::Lt(_, box1, box2) => {
1507                write!(f, "({} < {})", box1.clone(), box2.clone())
1508            }
1509            Expression::FlatSumGeq(_, box1, box2) => {
1510                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1511            }
1512            Expression::FlatSumLeq(_, box1, box2) => {
1513                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1514            }
1515            Expression::FlatIneq(_, box1, box2, box3) => write!(
1516                f,
1517                "Ineq({}, {}, {})",
1518                box1.clone(),
1519                box2.clone(),
1520                box3.clone()
1521            ),
1522            Expression::Flatten(_, n, m) => {
1523                if let Some(n) = n {
1524                    write!(f, "flatten({n}, {m})")
1525                } else {
1526                    write!(f, "flatten({m})")
1527                }
1528            }
1529            Expression::AllDiff(_, e) => {
1530                write!(f, "allDiff({e})")
1531            }
1532            Expression::Table(_, tuple_expr, rows_expr) => {
1533                write!(f, "table({tuple_expr}, {rows_expr})")
1534            }
1535            Expression::NegativeTable(_, tuple_expr, rows_expr) => {
1536                write!(f, "negativeTable({tuple_expr}, {rows_expr})")
1537            }
1538            Expression::Bubble(_, box1, box2) => {
1539                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1540            }
1541            Expression::SafeDiv(_, box1, box2) => {
1542                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1543            }
1544            Expression::UnsafeDiv(_, box1, box2) => {
1545                write!(f, "({} / {})", box1.clone(), box2.clone())
1546            }
1547            Expression::UnsafePow(_, box1, box2) => {
1548                write!(f, "({} ** {})", box1.clone(), box2.clone())
1549            }
1550            Expression::SafePow(_, box1, box2) => {
1551                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1552            }
1553            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1554                write!(
1555                    f,
1556                    "DivEq({}, {}, {})",
1557                    box1.clone(),
1558                    box2.clone(),
1559                    box3.clone()
1560                )
1561            }
1562            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1563                write!(
1564                    f,
1565                    "ModEq({}, {}, {})",
1566                    box1.clone(),
1567                    box2.clone(),
1568                    box3.clone()
1569                )
1570            }
1571            Expression::FlatWatchedLiteral(_, x, l) => {
1572                write!(f, "WatchedLiteral({x},{l})")
1573            }
1574            Expression::MinionReify(_, box1, box2) => {
1575                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1576            }
1577            Expression::MinionReifyImply(_, box1, box2) => {
1578                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1579            }
1580            Expression::MinionWInIntervalSet(_, atom, intervals) => {
1581                let intervals = intervals.iter().join(",");
1582                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1583            }
1584            Expression::MinionWInSet(_, atom, values) => {
1585                let values = values.iter().join(",");
1586                write!(f, "__minion_w_inset({atom},[{values}])")
1587            }
1588            Expression::AuxDeclaration(_, reference, e) => {
1589                write!(f, "{} =aux {}", reference, e.clone())
1590            }
1591            Expression::UnsafeMod(_, a, b) => {
1592                write!(f, "{} % {}", a.clone(), b.clone())
1593            }
1594            Expression::SafeMod(_, a, b) => {
1595                write!(f, "SafeMod({},{})", a.clone(), b.clone())
1596            }
1597            Expression::Neg(_, a) => {
1598                write!(f, "-({})", a.clone())
1599            }
1600            Expression::Factorial(_, a) => {
1601                write!(f, "({})!", a.clone())
1602            }
1603            Expression::Minus(_, a, b) => {
1604                write!(f, "({} - {})", a.clone(), b.clone())
1605            }
1606            Expression::FlatAllDiff(_, es) => {
1607                write!(f, "__flat_alldiff({})", pretty_vec(es))
1608            }
1609            Expression::FlatAbsEq(_, a, b) => {
1610                write!(f, "AbsEq({},{})", a.clone(), b.clone())
1611            }
1612            Expression::FlatMinusEq(_, a, b) => {
1613                write!(f, "MinusEq({},{})", a.clone(), b.clone())
1614            }
1615            Expression::FlatProductEq(_, a, b, c) => {
1616                write!(
1617                    f,
1618                    "FlatProductEq({},{},{})",
1619                    a.clone(),
1620                    b.clone(),
1621                    c.clone()
1622                )
1623            }
1624            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1625                write!(
1626                    f,
1627                    "FlatWeightedSumLeq({},{},{})",
1628                    pretty_vec(cs),
1629                    pretty_vec(vs),
1630                    total.clone()
1631                )
1632            }
1633            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1634                write!(
1635                    f,
1636                    "FlatWeightedSumGeq({},{},{})",
1637                    pretty_vec(cs),
1638                    pretty_vec(vs),
1639                    total.clone()
1640                )
1641            }
1642            Expression::MinionPow(_, atom, atom1, atom2) => {
1643                write!(f, "MinionPow({atom},{atom1},{atom2})")
1644            }
1645            Expression::MinionElementOne(_, atoms, atom, atom1) => {
1646                let atoms = atoms.iter().join(",");
1647                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1648            }
1649
1650            Expression::ToInt(_, expr) => {
1651                write!(f, "toInt({expr})")
1652            }
1653
1654            Expression::SATInt(_, encoding, bits, (min, max)) => {
1655                write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
1656            }
1657
1658            Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1659            Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1660
1661            Expression::Defined(_, function) => write!(f, "defined({function})"),
1662            Expression::Range(_, function) => write!(f, "range({function})"),
1663            Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1664            Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1665            Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1666            Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1667            Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1668
1669            Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1670            Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1671            Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1672            Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1673            Expression::FlatLexLt(_, a, b) => {
1674                write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1675            }
1676            Expression::FlatLexLeq(_, a, b) => {
1677                write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1678            }
1679        }
1680    }
1681}
1682
1683impl Typeable for Expression {
1684    fn return_type(&self) -> ReturnType {
1685        match self {
1686            Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1687            Expression::Intersect(_, subject, _) => {
1688                ReturnType::Set(Box::new(subject.return_type()))
1689            }
1690            Expression::In(_, _, _) => ReturnType::Bool,
1691            Expression::Supset(_, _, _) => ReturnType::Bool,
1692            Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1693            Expression::Subset(_, _, _) => ReturnType::Bool,
1694            Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1695            Expression::AbstractLiteral(_, lit) => lit.return_type(),
1696            Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1697                let subject_ty = subject.return_type();
1698                match subject_ty {
1699                    ReturnType::Matrix(_) => {
1700                        // For n-dimensional matrices, unwrap the element type until
1701                        // we either get to the innermost element type or the last index
1702                        let mut elem_typ = subject_ty;
1703                        let mut idx_len = idx.len();
1704                        while idx_len > 0
1705                            && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1706                        {
1707                            elem_typ = *new_elem_typ.clone();
1708                            idx_len -= 1;
1709                        }
1710                        elem_typ
1711                    }
1712                    // TODO: We can implement indexing for these eventually
1713                    ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1714                    _ => bug!(
1715                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1716                    ),
1717                }
1718            }
1719            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1720                ReturnType::Matrix(Box::new(subject.return_type()))
1721            }
1722            Expression::InDomain(_, _, _) => ReturnType::Bool,
1723            Expression::Comprehension(_, comp) => comp.return_type(),
1724            Expression::AbstractComprehension(_, comp) => comp.return_type(),
1725            Expression::Root(_, _) => ReturnType::Bool,
1726            Expression::DominanceRelation(_, _) => ReturnType::Bool,
1727            Expression::FromSolution(_, expr) => expr.return_type(),
1728            Expression::Metavar(_, _) => ReturnType::Unknown,
1729            Expression::Atomic(_, atom) => atom.return_type(),
1730            Expression::Abs(_, _) => ReturnType::Int,
1731            Expression::Sum(_, _) => ReturnType::Int,
1732            Expression::Product(_, _) => ReturnType::Int,
1733            Expression::Min(_, _) => ReturnType::Int,
1734            Expression::Max(_, _) => ReturnType::Int,
1735            Expression::Not(_, _) => ReturnType::Bool,
1736            Expression::Or(_, _) => ReturnType::Bool,
1737            Expression::Imply(_, _, _) => ReturnType::Bool,
1738            Expression::Iff(_, _, _) => ReturnType::Bool,
1739            Expression::And(_, _) => ReturnType::Bool,
1740            Expression::Eq(_, _, _) => ReturnType::Bool,
1741            Expression::Neq(_, _, _) => ReturnType::Bool,
1742            Expression::Geq(_, _, _) => ReturnType::Bool,
1743            Expression::Leq(_, _, _) => ReturnType::Bool,
1744            Expression::Gt(_, _, _) => ReturnType::Bool,
1745            Expression::Lt(_, _, _) => ReturnType::Bool,
1746            Expression::SafeDiv(_, _, _) => ReturnType::Int,
1747            Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1748            Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1749            Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1750            Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1751            Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1752            Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1753            Expression::Flatten(_, _, matrix) => {
1754                let matrix_type = matrix.return_type();
1755                match matrix_type {
1756                    ReturnType::Matrix(_) => {
1757                        // unwrap until we get to innermost element
1758                        let mut elem_type = matrix_type;
1759                        while let ReturnType::Matrix(new_elem_type) = &elem_type {
1760                            elem_type = *new_elem_type.clone();
1761                        }
1762                        ReturnType::Matrix(Box::new(elem_type))
1763                    }
1764                    _ => bug!(
1765                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1766                    ),
1767                }
1768            }
1769            Expression::AllDiff(_, _) => ReturnType::Bool,
1770            Expression::Table(_, _, _) => ReturnType::Bool,
1771            Expression::NegativeTable(_, _, _) => ReturnType::Bool,
1772            Expression::Bubble(_, inner, _) => inner.return_type(),
1773            Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1774            Expression::MinionReify(_, _, _) => ReturnType::Bool,
1775            Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1776            Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1777            Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1778            Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1779            Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1780            Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1781            Expression::SafeMod(_, _, _) => ReturnType::Int,
1782            Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1783            Expression::Neg(_, _) => ReturnType::Int,
1784            Expression::Factorial(_, _) => ReturnType::Int,
1785            Expression::UnsafePow(_, _, _) => ReturnType::Int,
1786            Expression::SafePow(_, _, _) => ReturnType::Int,
1787            Expression::Minus(_, _, _) => ReturnType::Int,
1788            Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1789            Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1790            Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1791            Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1792            Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1793            Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1794            Expression::ToInt(_, _) => ReturnType::Int,
1795            Expression::SATInt(..) => ReturnType::Int,
1796            Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1797            Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1798            Expression::Defined(_, function) => {
1799                let subject = function.return_type();
1800                match subject {
1801                    ReturnType::Function(domain, _) => *domain,
1802                    _ => bug!(
1803                        "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1804                    ),
1805                }
1806            }
1807            Expression::Range(_, function) => {
1808                let subject = function.return_type();
1809                match subject {
1810                    ReturnType::Function(_, codomain) => *codomain,
1811                    _ => bug!(
1812                        "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1813                    ),
1814                }
1815            }
1816            Expression::Image(_, function, _) => {
1817                let subject = function.return_type();
1818                match subject {
1819                    ReturnType::Function(_, codomain) => *codomain,
1820                    _ => bug!(
1821                        "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1822                    ),
1823                }
1824            }
1825            Expression::ImageSet(_, function, _) => {
1826                let subject = function.return_type();
1827                match subject {
1828                    ReturnType::Function(_, codomain) => *codomain,
1829                    _ => bug!(
1830                        "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1831                    ),
1832                }
1833            }
1834            Expression::PreImage(_, function, _) => {
1835                let subject = function.return_type();
1836                match subject {
1837                    ReturnType::Function(domain, _) => *domain,
1838                    _ => bug!(
1839                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1840                    ),
1841                }
1842            }
1843            Expression::Restrict(_, function, new_domain) => {
1844                let subject = function.return_type();
1845                match subject {
1846                    ReturnType::Function(_, codomain) => {
1847                        ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1848                    }
1849                    _ => bug!(
1850                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1851                    ),
1852                }
1853            }
1854            Expression::Inverse(..) => ReturnType::Bool,
1855            Expression::LexLt(..) => ReturnType::Bool,
1856            Expression::LexGt(..) => ReturnType::Bool,
1857            Expression::LexLeq(..) => ReturnType::Bool,
1858            Expression::LexGeq(..) => ReturnType::Bool,
1859            Expression::FlatLexLt(..) => ReturnType::Bool,
1860            Expression::FlatLexLeq(..) => ReturnType::Bool,
1861        }
1862    }
1863}
1864
1865impl Expression {
1866    /// Visit each direct `Expression` child by reference, without cloning.
1867    fn for_each_expr_child(&self, f: &mut impl FnMut(&Expression)) {
1868        match self {
1869            // Special Case
1870            Expression::AbstractLiteral(_, alit) => match alit {
1871                AbstractLiteral::Set(v) | AbstractLiteral::MSet(v) | AbstractLiteral::Tuple(v) => {
1872                    for expr in v {
1873                        f(expr);
1874                    }
1875                }
1876                AbstractLiteral::Matrix(v, _domain) => {
1877                    for expr in v {
1878                        f(expr);
1879                    }
1880                }
1881                AbstractLiteral::Record(rs) => {
1882                    for r in rs {
1883                        f(&r.value);
1884                    }
1885                }
1886                AbstractLiteral::Function(vs) => {
1887                    for (a, b) in vs {
1888                        f(a);
1889                        f(b);
1890                    }
1891                }
1892            },
1893            Expression::Root(_, vs) => {
1894                for expr in vs {
1895                    f(expr);
1896                }
1897            }
1898
1899            // Moo<Expression>
1900            Expression::DominanceRelation(_, m1)
1901            | Expression::ToInt(_, m1)
1902            | Expression::Abs(_, m1)
1903            | Expression::Sum(_, m1)
1904            | Expression::Product(_, m1)
1905            | Expression::Min(_, m1)
1906            | Expression::Max(_, m1)
1907            | Expression::Not(_, m1)
1908            | Expression::Or(_, m1)
1909            | Expression::And(_, m1)
1910            | Expression::Neg(_, m1)
1911            | Expression::Defined(_, m1)
1912            | Expression::AllDiff(_, m1)
1913            | Expression::Factorial(_, m1)
1914            | Expression::Range(_, m1) => {
1915                f(m1);
1916            }
1917
1918            // Moo<Expression> + Moo<Expression>
1919            Expression::Table(_, m1, m2)
1920            | Expression::NegativeTable(_, m1, m2)
1921            | Expression::Bubble(_, m1, m2)
1922            | Expression::Imply(_, m1, m2)
1923            | Expression::Iff(_, m1, m2)
1924            | Expression::Union(_, m1, m2)
1925            | Expression::In(_, m1, m2)
1926            | Expression::Intersect(_, m1, m2)
1927            | Expression::Supset(_, m1, m2)
1928            | Expression::SupsetEq(_, m1, m2)
1929            | Expression::Subset(_, m1, m2)
1930            | Expression::SubsetEq(_, m1, m2)
1931            | Expression::Eq(_, m1, m2)
1932            | Expression::Neq(_, m1, m2)
1933            | Expression::Geq(_, m1, m2)
1934            | Expression::Leq(_, m1, m2)
1935            | Expression::Gt(_, m1, m2)
1936            | Expression::Lt(_, m1, m2)
1937            | Expression::SafeDiv(_, m1, m2)
1938            | Expression::UnsafeDiv(_, m1, m2)
1939            | Expression::SafeMod(_, m1, m2)
1940            | Expression::UnsafeMod(_, m1, m2)
1941            | Expression::UnsafePow(_, m1, m2)
1942            | Expression::SafePow(_, m1, m2)
1943            | Expression::Minus(_, m1, m2)
1944            | Expression::PairwiseSum(_, m1, m2)
1945            | Expression::PairwiseProduct(_, m1, m2)
1946            | Expression::Image(_, m1, m2)
1947            | Expression::ImageSet(_, m1, m2)
1948            | Expression::PreImage(_, m1, m2)
1949            | Expression::Inverse(_, m1, m2)
1950            | Expression::Restrict(_, m1, m2)
1951            | Expression::LexLt(_, m1, m2)
1952            | Expression::LexLeq(_, m1, m2)
1953            | Expression::LexGt(_, m1, m2)
1954            | Expression::LexGeq(_, m1, m2) => {
1955                f(m1);
1956                f(m2);
1957            }
1958
1959            // Moo<Expression> + Vec<Expression>
1960            Expression::UnsafeIndex(_, m, vs) | Expression::SafeIndex(_, m, vs) => {
1961                f(m);
1962                for v in vs {
1963                    f(v);
1964                }
1965            }
1966
1967            // Moo<Expression> + Vec<Option<Expression>>
1968            Expression::UnsafeSlice(_, m, vs) | Expression::SafeSlice(_, m, vs) => {
1969                f(m);
1970                for e in vs.iter().flatten() {
1971                    f(e);
1972                }
1973            }
1974
1975            // Moo<Expression> + DomainPtr
1976            Expression::InDomain(_, m, _) => {
1977                f(m);
1978            }
1979
1980            // Option<Moo<Expression>> + Moo<Expression>
1981            Expression::Flatten(_, opt, m) => {
1982                if let Some(e) = opt {
1983                    f(e);
1984                }
1985                f(m);
1986            }
1987
1988            // Moo<Expression> + Atom
1989            Expression::MinionReify(_, m, _) | Expression::MinionReifyImply(_, m, _) => {
1990                f(m);
1991            }
1992
1993            // Reference + Moo<Expression>
1994            Expression::AuxDeclaration(_, _, m) => {
1995                f(m);
1996            }
1997
1998            // SATIntEncoding + Moo<Expression> + (i32, i32)
1999            Expression::SATInt(_, _, m, _) => {
2000                f(m);
2001            }
2002
2003            // No Expression children
2004            Expression::Comprehension(_, _)
2005            | Expression::AbstractComprehension(_, _)
2006            | Expression::Atomic(_, _)
2007            | Expression::FromSolution(_, _)
2008            | Expression::Metavar(_, _)
2009            | Expression::FlatAbsEq(_, _, _)
2010            | Expression::FlatMinusEq(_, _, _)
2011            | Expression::FlatProductEq(_, _, _, _)
2012            | Expression::MinionDivEqUndefZero(_, _, _, _)
2013            | Expression::MinionModuloEqUndefZero(_, _, _, _)
2014            | Expression::MinionPow(_, _, _, _)
2015            | Expression::FlatAllDiff(_, _)
2016            | Expression::FlatSumGeq(_, _, _)
2017            | Expression::FlatSumLeq(_, _, _)
2018            | Expression::FlatIneq(_, _, _, _)
2019            | Expression::FlatWatchedLiteral(_, _, _)
2020            | Expression::FlatWeightedSumLeq(_, _, _, _)
2021            | Expression::FlatWeightedSumGeq(_, _, _, _)
2022            | Expression::MinionWInIntervalSet(_, _, _)
2023            | Expression::MinionWInSet(_, _, _)
2024            | Expression::MinionElementOne(_, _, _, _)
2025            | Expression::FlatLexLt(_, _, _)
2026            | Expression::FlatLexLeq(_, _, _) => {}
2027        }
2028    }
2029}
2030
2031impl CacheHashable for Expression {
2032    fn invalidate_cache(&self) {
2033        self.meta_ref()
2034            .stored_hash
2035            .store(NO_HASH, Ordering::Relaxed);
2036    }
2037
2038    fn invalidate_cache_recursive(&self) {
2039        self.invalidate_cache();
2040        self.for_each_expr_child(&mut |child| {
2041            child.invalidate_cache_recursive();
2042        });
2043    }
2044
2045    fn get_cached_hash(&self) -> u64 {
2046        let stored = self.meta_ref().stored_hash.load(Ordering::Relaxed);
2047        if stored != NO_HASH {
2048            HASH_HITS.fetch_add(1, Ordering::Relaxed);
2049            return stored;
2050        }
2051        HASH_MISSES.fetch_add(1, Ordering::Relaxed);
2052        self.calculate_hash()
2053    }
2054
2055    fn calculate_hash(&self) -> u64 {
2056        let mut hasher = DefaultHasher::new();
2057        std::mem::discriminant(self).hash(&mut hasher);
2058        match self {
2059            // Special Case
2060            Expression::AbstractLiteral(_, alit) => match alit {
2061                AbstractLiteral::Set(v) | AbstractLiteral::MSet(v) | AbstractLiteral::Tuple(v) => {
2062                    for expr in v {
2063                        expr.get_cached_hash().hash(&mut hasher);
2064                    }
2065                }
2066                AbstractLiteral::Matrix(v, domain) => {
2067                    domain.hash(&mut hasher);
2068                    for expr in v {
2069                        expr.get_cached_hash().hash(&mut hasher);
2070                    }
2071                }
2072                AbstractLiteral::Record(rs) => {
2073                    for r in rs {
2074                        r.name.hash(&mut hasher);
2075                        r.value.get_cached_hash().hash(&mut hasher);
2076                    }
2077                }
2078                AbstractLiteral::Function(vs) => {
2079                    for (a, b) in vs {
2080                        a.get_cached_hash().hash(&mut hasher);
2081                        b.get_cached_hash().hash(&mut hasher);
2082                    }
2083                }
2084            },
2085            Expression::Root(_, vs) => {
2086                for expr in vs {
2087                    expr.get_cached_hash().hash(&mut hasher);
2088                }
2089            }
2090
2091            // Moo<Expression>
2092            Expression::DominanceRelation(_, m1)
2093            | Expression::ToInt(_, m1)
2094            | Expression::Abs(_, m1)
2095            | Expression::Sum(_, m1)
2096            | Expression::Product(_, m1)
2097            | Expression::Min(_, m1)
2098            | Expression::Max(_, m1)
2099            | Expression::Not(_, m1)
2100            | Expression::Or(_, m1)
2101            | Expression::And(_, m1)
2102            | Expression::Neg(_, m1)
2103            | Expression::Defined(_, m1)
2104            | Expression::AllDiff(_, m1)
2105            | Expression::Factorial(_, m1)
2106            | Expression::Range(_, m1) => {
2107                m1.get_cached_hash().hash(&mut hasher);
2108            }
2109
2110            // Moo<Expression> + Moo<Expression>
2111            Expression::Table(_, m1, m2)
2112            | Expression::NegativeTable(_, m1, m2)
2113            | Expression::Bubble(_, m1, m2)
2114            | Expression::Imply(_, m1, m2)
2115            | Expression::Iff(_, m1, m2)
2116            | Expression::Union(_, m1, m2)
2117            | Expression::In(_, m1, m2)
2118            | Expression::Intersect(_, m1, m2)
2119            | Expression::Supset(_, m1, m2)
2120            | Expression::SupsetEq(_, m1, m2)
2121            | Expression::Subset(_, m1, m2)
2122            | Expression::SubsetEq(_, m1, m2)
2123            | Expression::Eq(_, m1, m2)
2124            | Expression::Neq(_, m1, m2)
2125            | Expression::Geq(_, m1, m2)
2126            | Expression::Leq(_, m1, m2)
2127            | Expression::Gt(_, m1, m2)
2128            | Expression::Lt(_, m1, m2)
2129            | Expression::SafeDiv(_, m1, m2)
2130            | Expression::UnsafeDiv(_, m1, m2)
2131            | Expression::SafeMod(_, m1, m2)
2132            | Expression::UnsafeMod(_, m1, m2)
2133            | Expression::UnsafePow(_, m1, m2)
2134            | Expression::SafePow(_, m1, m2)
2135            | Expression::Minus(_, m1, m2)
2136            | Expression::PairwiseSum(_, m1, m2)
2137            | Expression::PairwiseProduct(_, m1, m2)
2138            | Expression::Image(_, m1, m2)
2139            | Expression::ImageSet(_, m1, m2)
2140            | Expression::PreImage(_, m1, m2)
2141            | Expression::Inverse(_, m1, m2)
2142            | Expression::Restrict(_, m1, m2)
2143            | Expression::LexLt(_, m1, m2)
2144            | Expression::LexLeq(_, m1, m2)
2145            | Expression::LexGt(_, m1, m2)
2146            | Expression::LexGeq(_, m1, m2) => {
2147                m1.get_cached_hash().hash(&mut hasher);
2148                m2.get_cached_hash().hash(&mut hasher);
2149            }
2150            // Moo<Expression> + Vec<Expression>
2151            Expression::UnsafeIndex(_, m, vs) | Expression::SafeIndex(_, m, vs) => {
2152                m.get_cached_hash().hash(&mut hasher);
2153                for v in vs {
2154                    v.get_cached_hash().hash(&mut hasher);
2155                }
2156            }
2157
2158            // Moo<Expression> + Vec<Option<Expression>>
2159            Expression::UnsafeSlice(_, m, vs) | Expression::SafeSlice(_, m, vs) => {
2160                m.get_cached_hash().hash(&mut hasher);
2161                for v in vs {
2162                    match v {
2163                        Some(e) => e.get_cached_hash().hash(&mut hasher),
2164                        None => 0u64.hash(&mut hasher),
2165                    }
2166                }
2167            }
2168
2169            // Moo<Expression> + DomainPtr
2170            Expression::InDomain(_, m, d) => {
2171                m.get_cached_hash().hash(&mut hasher);
2172                d.hash(&mut hasher);
2173            }
2174
2175            // Option<Moo<Expression>> + Moo<Expression>
2176            Expression::Flatten(_, opt, m) => {
2177                if let Some(e) = opt {
2178                    e.get_cached_hash().hash(&mut hasher);
2179                }
2180                m.get_cached_hash().hash(&mut hasher);
2181            }
2182
2183            // Moo<Expression> + Atom
2184            Expression::MinionReify(_, m, a) | Expression::MinionReifyImply(_, m, a) => {
2185                m.get_cached_hash().hash(&mut hasher);
2186                a.hash(&mut hasher);
2187            }
2188
2189            // Reference + Moo<Expression>
2190            Expression::AuxDeclaration(_, r, m) => {
2191                r.hash(&mut hasher);
2192                m.get_cached_hash().hash(&mut hasher);
2193            }
2194
2195            // SATIntEncoding + Moo<Expression> + (i32, i32)
2196            Expression::SATInt(_, enc, m, bounds) => {
2197                enc.hash(&mut hasher);
2198                m.get_cached_hash().hash(&mut hasher);
2199                bounds.hash(&mut hasher);
2200            }
2201
2202            // Non-Expression Moo types - hash normally
2203            Expression::Comprehension(_, c) => c.hash(&mut hasher),
2204            Expression::AbstractComprehension(_, c) => c.hash(&mut hasher),
2205
2206            // Leaf types - no Expression children
2207            Expression::Atomic(_, a) => a.hash(&mut hasher),
2208            Expression::FromSolution(_, a) => a.hash(&mut hasher),
2209            Expression::Metavar(_, u) => u.hash(&mut hasher),
2210
2211            // Two Moo<Atom>
2212            Expression::FlatAbsEq(_, a1, a2) | Expression::FlatMinusEq(_, a1, a2) => {
2213                a1.hash(&mut hasher);
2214                a2.hash(&mut hasher);
2215            }
2216
2217            // Three Moo<Atom>
2218            Expression::FlatProductEq(_, a1, a2, a3)
2219            | Expression::MinionDivEqUndefZero(_, a1, a2, a3)
2220            | Expression::MinionModuloEqUndefZero(_, a1, a2, a3)
2221            | Expression::MinionPow(_, a1, a2, a3) => {
2222                a1.hash(&mut hasher);
2223                a2.hash(&mut hasher);
2224                a3.hash(&mut hasher);
2225            }
2226
2227            // Vec<Atom>
2228            Expression::FlatAllDiff(_, vs) => {
2229                for v in vs {
2230                    v.hash(&mut hasher);
2231                }
2232            }
2233
2234            // Vec<Atom> + Atom
2235            Expression::FlatSumGeq(_, vs, a) | Expression::FlatSumLeq(_, vs, a) => {
2236                for v in vs {
2237                    v.hash(&mut hasher);
2238                }
2239                a.hash(&mut hasher);
2240            }
2241
2242            // Moo<Atom> + Moo<Atom> + Box<Literal>
2243            Expression::FlatIneq(_, a1, a2, lit) => {
2244                a1.hash(&mut hasher);
2245                a2.hash(&mut hasher);
2246                lit.hash(&mut hasher);
2247            }
2248
2249            // Reference + Literal
2250            Expression::FlatWatchedLiteral(_, r, l) => {
2251                r.hash(&mut hasher);
2252                l.hash(&mut hasher);
2253            }
2254
2255            // Vec<Literal> + Vec<Atom> + Moo<Atom>
2256            Expression::FlatWeightedSumLeq(_, lits, atoms, a)
2257            | Expression::FlatWeightedSumGeq(_, lits, atoms, a) => {
2258                for l in lits {
2259                    l.hash(&mut hasher);
2260                }
2261                for at in atoms {
2262                    at.hash(&mut hasher);
2263                }
2264                a.hash(&mut hasher);
2265            }
2266
2267            // Atom + Vec<i32>
2268            Expression::MinionWInIntervalSet(_, a, vs) | Expression::MinionWInSet(_, a, vs) => {
2269                a.hash(&mut hasher);
2270                for v in vs {
2271                    v.hash(&mut hasher);
2272                }
2273            }
2274
2275            // Vec<Atom> + Moo<Atom> + Moo<Atom>
2276            Expression::MinionElementOne(_, vs, a1, a2) => {
2277                for v in vs {
2278                    v.hash(&mut hasher);
2279                }
2280                a1.hash(&mut hasher);
2281                a2.hash(&mut hasher);
2282            }
2283
2284            // Vec<Atom> + Vec<Atom>
2285            Expression::FlatLexLt(_, v1, v2) | Expression::FlatLexLeq(_, v1, v2) => {
2286                for v in v1 {
2287                    v.hash(&mut hasher);
2288                }
2289                for v in v2 {
2290                    v.hash(&mut hasher);
2291                }
2292            }
2293        };
2294
2295        let result = hasher.finish();
2296        self.meta_ref().stored_hash.swap(result, Ordering::Relaxed);
2297        result
2298    }
2299}
2300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    /// Builds an atomic constant expression (`Expression::Atomic`) holding `value`,
    /// with fresh metadata.
    fn atomic_literal(value: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(value))
    }

    /// The sum of the integer constants 1 and 2 should have the singleton
    /// integer domain containing only 3.
    #[test]
    fn test_domain_of_constant_sum() {
        let (lhs, rhs) = (
            atomic_literal(Literal::Int(1)),
            atomic_literal(Literal::Int(2)),
        );
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![lhs, rhs]));

        let expected = Some(Domain::int(vec![Range::Single(3)]));
        assert_eq!(sum.domain_of(), expected);
    }

    /// Summing an integer constant with a boolean constant should yield no domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let (lhs, rhs) = (
            atomic_literal(Literal::Int(1)),
            atomic_literal(Literal::Bool(true)),
        );
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![lhs, rhs]));

        assert_eq!(sum.domain_of(), None);
    }

    /// A sum over an empty matrix should yield no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let terms = Moo::new(matrix_expr![]);
        let sum = Expression::Sum(Metadata::new(), terms);

        assert_eq!(sum.domain_of(), None);
    }
}