Skip to main content

conjure_cp_core/ast/
expressions.rs

1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use tracing::trace;
4
5use conjure_cp_enum_compatibility_macro::document_compatibility;
6use itertools::Itertools;
7use serde::{Deserialize, Serialize};
8use ustr::Ustr;
9
10use polyquine::Quine;
11use uniplate::{Biplate, Uniplate};
12
13use crate::bug;
14
15use super::abstract_comprehension::AbstractComprehension;
16use super::ac_operators::ACOperatorKind;
17use super::categories::{Category, CategoryOf};
18use super::comprehension::Comprehension;
19use super::domains::HasDomain as _;
20use super::pretty::{pretty_expressions_as_top_level, pretty_vec};
21use super::records::RecordValue;
22use super::sat_encoding::SATIntEncoding;
23use super::{
24    AbstractLiteral, Atom, DeclarationPtr, Domain, DomainPtr, GroundDomain, IntVal, Literal,
25    Metadata, Moo, Name, Range, Reference, ReturnType, SetAttr, SubModel, SymbolTable,
26    SymbolTablePtr, Typeable, UnresolvedDomain, matrix,
27};
28
29// Ensure that this type doesn't get too big
30//
31// If you triggered this assertion, you either made a variant of this enum that is too big, or you
32// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
33// stuff in boxes.
34//
35// Enums take the size of their largest variant, so an enum with mostly small variants and a few
36// large ones wastes memory... A larger Expression type also slows down Oxide.
37//
38// For more information, and more details on type sizes and how to measure them, see the commit
39// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
40//
41// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
42//
43// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e
44
45// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
46// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
47// lot bigger still when we start using it for memoisation, so it should really be
48// boxed ~niklasdewally
49
50// expect size of Expression to be 112 bytes
51static_assertions::assert_eq_size!([u8; 104], Expression);
52
53/// Represents different types of expressions used to define rules and constraints in the model.
54///
55/// The `Expression` enum includes operations, constants, and variable references
56/// used to build rules and conditions for the model.
57#[document_compatibility]
58#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
59#[biplate(to=AbstractComprehension)]
60#[biplate(to=AbstractLiteral<Expression>)]
61#[biplate(to=AbstractLiteral<Literal>)]
62#[biplate(to=Atom)]
63#[biplate(to=Comprehension)]
64#[biplate(to=DeclarationPtr)]
65#[biplate(to=DomainPtr)]
66#[biplate(to=Literal)]
67#[biplate(to=Metadata)]
68#[biplate(to=Name)]
69#[biplate(to=Option<Expression>)]
70#[biplate(to=RecordValue<Expression>)]
71#[biplate(to=RecordValue<Literal>)]
72#[biplate(to=Reference)]
73#[biplate(to=SubModel)]
74#[biplate(to=SymbolTable)]
75#[biplate(to=SymbolTablePtr)]
76#[biplate(to=Vec<Expression>)]
77#[path_prefix(conjure_cp::ast)]
78pub enum Expression {
79    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
80    /// The top of the model
81    Root(Metadata, Vec<Expression>),
82
83    /// An expression representing "A is valid as long as B is true"
84    /// Turns into a conjunction when it reaches a boolean context
85    Bubble(Metadata, Moo<Expression>, Moo<Expression>),
86
87    /// A comprehension.
88    ///
89    /// The inside of the comprehension opens a new scope.
90    // todo (gskorokhod): Comprehension contains a SubModel which contains a bunch of Rc pointers.
91    // This makes implementing Quine tricky (it doesnt support Rc, by design). Skip it for now.
92    #[polyquine_skip]
93    Comprehension(Metadata, Moo<Comprehension>),
94
95    /// Higher-level abstract comprehension
96    #[polyquine_skip] // no idea what this is lol but it stops rustc screaming at me
97    AbstractComprehension(Metadata, Moo<AbstractComprehension>),
98
99    /// Defines dominance ("Solution A is preferred over Solution B")
100    DominanceRelation(Metadata, Moo<Expression>),
101    /// `fromSolution(name)` - Used in dominance relation definitions
102    FromSolution(Metadata, Moo<Atom>),
103
104    #[polyquine_with(arm = (_, name) => {
105        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
106        quote::quote! { #ident.clone().into() }
107    })]
108    Metavar(Metadata, Ustr),
109
110    Atomic(Metadata, Atom),
111
112    /// A matrix index.
113    ///
114    /// Defined iff the indices are within their respective index domains.
115    #[compatible(JsonInput)]
116    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
117
118    /// A safe matrix index.
119    ///
120    /// See [`Expression::UnsafeIndex`]
121    #[compatible(SMT)]
122    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
123
124    /// A matrix slice: `a[indices]`.
125    ///
126    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
127    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
128    /// `Some(1),None,Some(2)`.
129    ///
130    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
131    ///
132    /// Defined iff the defined indices are within their respective index domains.
133    #[compatible(JsonInput)]
134    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
135
136    /// A safe matrix slice: `a[indices]`.
137    ///
138    /// See [`Expression::UnsafeSlice`].
139    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
140
141    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
142    ///
143    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
144    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
145    /// `SafeSlice` respectively.
146    InDomain(Metadata, Moo<Expression>, DomainPtr),
147
148    /// `toInt(b)` casts boolean expression b to an integer.
149    ///
150    /// - If b is false, then `toInt(b) == 0`
151    ///
152    /// - If b is true, then `toInt(b) == 1`
153    #[compatible(SMT)]
154    ToInt(Metadata, Moo<Expression>),
155
156    // todo (gskorokhod): Same reason as for Comprehension
157    #[polyquine_skip]
158    Scope(Metadata, Moo<SubModel>),
159
160    /// `|x|` - absolute value of `x`
161    #[compatible(JsonInput, SMT)]
162    Abs(Metadata, Moo<Expression>),
163
164    /// `sum(<vec_expr>)`
165    #[compatible(JsonInput, SMT)]
166    Sum(Metadata, Moo<Expression>),
167
168    /// `a * b * c * ...`
169    #[compatible(JsonInput, SMT)]
170    Product(Metadata, Moo<Expression>),
171
172    /// `min(<vec_expr>)`
173    #[compatible(JsonInput, SMT)]
174    Min(Metadata, Moo<Expression>),
175
176    /// `max(<vec_expr>)`
177    #[compatible(JsonInput, SMT)]
178    Max(Metadata, Moo<Expression>),
179
180    /// `not(a)`
181    #[compatible(JsonInput, SAT, SMT)]
182    Not(Metadata, Moo<Expression>),
183
184    /// `or(<vec_expr>)`
185    #[compatible(JsonInput, SAT, SMT)]
186    Or(Metadata, Moo<Expression>),
187
188    /// `and(<vec_expr>)`
189    #[compatible(JsonInput, SAT, SMT)]
190    And(Metadata, Moo<Expression>),
191
192    /// Ensures that `a->b` (material implication).
193    #[compatible(JsonInput, SMT)]
194    Imply(Metadata, Moo<Expression>, Moo<Expression>),
195
196    /// `iff(a, b)` a <-> b
197    #[compatible(JsonInput, SMT)]
198    Iff(Metadata, Moo<Expression>, Moo<Expression>),
199
200    #[compatible(JsonInput)]
201    Union(Metadata, Moo<Expression>, Moo<Expression>),
202
203    #[compatible(JsonInput)]
204    In(Metadata, Moo<Expression>, Moo<Expression>),
205
206    #[compatible(JsonInput)]
207    Intersect(Metadata, Moo<Expression>, Moo<Expression>),
208
209    #[compatible(JsonInput)]
210    Supset(Metadata, Moo<Expression>, Moo<Expression>),
211
212    #[compatible(JsonInput)]
213    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),
214
215    #[compatible(JsonInput)]
216    Subset(Metadata, Moo<Expression>, Moo<Expression>),
217
218    #[compatible(JsonInput)]
219    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),
220
221    #[compatible(JsonInput, SMT)]
222    Eq(Metadata, Moo<Expression>, Moo<Expression>),
223
224    #[compatible(JsonInput, SMT)]
225    Neq(Metadata, Moo<Expression>, Moo<Expression>),
226
227    #[compatible(JsonInput, SMT)]
228    Geq(Metadata, Moo<Expression>, Moo<Expression>),
229
230    #[compatible(JsonInput, SMT)]
231    Leq(Metadata, Moo<Expression>, Moo<Expression>),
232
233    #[compatible(JsonInput, SMT)]
234    Gt(Metadata, Moo<Expression>, Moo<Expression>),
235
236    #[compatible(JsonInput, SMT)]
237    Lt(Metadata, Moo<Expression>, Moo<Expression>),
238
239    /// Division after preventing division by zero, usually with a bubble
240    #[compatible(SMT)]
241    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
242
243    /// Division with a possibly undefined value (division by 0)
244    #[compatible(JsonInput)]
245    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
246
247    /// Modulo after preventing mod 0, usually with a bubble
248    #[compatible(SMT)]
249    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),
250
251    /// Modulo with a possibly undefined value (mod 0)
252    #[compatible(JsonInput)]
253    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),
254
255    /// Negation: `-x`
256    #[compatible(JsonInput, SMT)]
257    Neg(Metadata, Moo<Expression>),
258
259    /// Set of domain values function is defined for
260    #[compatible(JsonInput)]
261    Defined(Metadata, Moo<Expression>),
262
263    /// Set of codomain values function is defined for
264    #[compatible(JsonInput)]
265    Range(Metadata, Moo<Expression>),
266
267    /// Unsafe power`x**y` (possibly undefined)
268    ///
269    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
270    #[compatible(JsonInput)]
271    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),
272
273    /// `UnsafePow` after preventing undefinedness
274    SafePow(Metadata, Moo<Expression>, Moo<Expression>),
275
276    /// Flatten matrix operator
277    /// `flatten(M)` or `flatten(n, M)`
278    /// where M is a matrix and n is an optional integer argument indicating depth of flattening
279    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),
280
281    /// `allDiff(<vec_expr>)`
282    #[compatible(JsonInput)]
283    AllDiff(Metadata, Moo<Expression>),
284
285    /// Binary subtraction operator
286    ///
287    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
288    /// TODO: make this compatible with Set Difference calculations - need to change return type and domain for this expression and write a set comprehension rule.
289    /// have already edited minus_to_sum to prevent this from applying to sets
290    #[compatible(JsonInput)]
291    Minus(Metadata, Moo<Expression>, Moo<Expression>),
292
293    /// Ensures that x=|y| i.e. x is the absolute value of y.
294    ///
295    /// Low-level Minion constraint.
296    ///
297    /// # See also
298    ///
299    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
300    #[compatible(Minion)]
301    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),
302
303    /// Ensures that `alldiff([a,b,...])`.
304    ///
305    /// Low-level Minion constraint.
306    ///
307    /// # See also
308    ///
309    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
310    #[compatible(Minion)]
311    FlatAllDiff(Metadata, Vec<Atom>),
312
313    /// Ensures that sum(vec) >= x.
314    ///
315    /// Low-level Minion constraint.
316    ///
317    /// # See also
318    ///
319    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
320    #[compatible(Minion)]
321    FlatSumGeq(Metadata, Vec<Atom>, Atom),
322
323    /// Ensures that sum(vec) <= x.
324    ///
325    /// Low-level Minion constraint.
326    ///
327    /// # See also
328    ///
329    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
330    #[compatible(Minion)]
331    FlatSumLeq(Metadata, Vec<Atom>, Atom),
332
333    /// `ineq(x,y,k)` ensures that x <= y + k.
334    ///
335    /// Low-level Minion constraint.
336    ///
337    /// # See also
338    ///
339    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
340    #[compatible(Minion)]
341    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),
342
343    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
344    ///
345    /// Low-level Minion constraint.
346    ///
347    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
348    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
349    /// watched-and and watched-or.
350    ///
351    /// # See also
352    ///
353    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
354    /// + `rules::minion::boolean_literal_to_wliteral`.
355    #[compatible(Minion)]
356    #[polyquine_skip]
357    FlatWatchedLiteral(Metadata, Reference, Literal),
358
359    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
360    /// product of cs and xs.
361    ///
362    /// Low-level Minion constraint.
363    ///
364    /// Represents a weighted sum of the form `ax + by + cz + ...`
365    ///
366    /// # See also
367    ///
368    /// + [Minion
369    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
370    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
371
372    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
373    /// product of cs and xs.
374    ///
375    /// Low-level Minion constraint.
376    ///
377    /// Represents a weighted sum of the form `ax + by + cz + ...`
378    ///
379    /// # See also
380    ///
381    /// + [Minion
382    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
383    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
384
385    /// Ensures that x =-y, where x and y are atoms.
386    ///
387    /// Low-level Minion constraint.
388    ///
389    /// # See also
390    ///
391    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
392    #[compatible(Minion)]
393    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),
394
395    /// Ensures that x*y=z.
396    ///
397    /// Low-level Minion constraint.
398    ///
399    /// # See also
400    ///
401    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
402    #[compatible(Minion)]
403    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
404
405    /// Ensures that floor(x/y)=z. Always true when y=0.
406    ///
407    /// Low-level Minion constraint.
408    ///
409    /// # See also
410    ///
411    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
412    #[compatible(Minion)]
413    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
414
415    /// Ensures that x%y=z. Always true when y=0.
416    ///
417    /// Low-level Minion constraint.
418    ///
419    /// # See also
420    ///
421    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
422    #[compatible(Minion)]
423    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
424
425    /// Ensures that `x**y = z`.
426    ///
427    /// Low-level Minion constraint.
428    ///
429    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
430    /// is odd and z is -1 if y is even).
431    ///
432    /// # See also
433    ///
434    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
435    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
436    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
437
438    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
439    /// variable.
440    ///
441    /// Low-level Minion constraint.
442    ///
443    /// # See also
444    ///
445    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
446    #[compatible(Minion)]
447    MinionReify(Metadata, Moo<Expression>, Atom),
448
449    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
450    /// variable.
451    ///
452    /// Low-level Minion constraint.
453    ///
454    /// # See also
455    ///
456    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
457    #[compatible(Minion)]
458    MinionReifyImply(Metadata, Moo<Expression>, Atom),
459
460    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
461    /// intervals {a1,…,a2}, {b1,…,b2} etc.
462    ///
463    /// The list of intervals must be given in numerical order.
464    ///
465    /// Low-level Minion constraint.
466    ///
467    /// # See also
468    ///>
469    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
470    #[compatible(Minion)]
471    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
472
473    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
474    ///
475    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
476    ///
477    /// The list of values must be given in numerical order.
478    ///
479    /// Low-level Minion constraint.
480    ///
481    /// # See also
482    ///
483    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
484    #[compatible(Minion)]
485    MinionWInSet(Metadata, Atom, Vec<i32>),
486
487    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
488    /// in the range `[1..len(vec)]`.
489    ///
490    /// Low-level Minion constraint.
491    ///
492    /// # See also
493    ///
494    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
495    #[compatible(Minion)]
496    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),
497
498    /// Declaration of an auxiliary variable.
499    ///
500    /// As with Savile Row, we semantically distinguish this from `Eq`.
501    #[compatible(Minion)]
502    #[polyquine_skip]
503    AuxDeclaration(Metadata, Reference, Moo<Expression>),
504
505    /// This expression is for encoding ints for the SAT solver, it stores the encoding type, the vector of booleans and the min/max for the int.
506    #[compatible(SAT)]
507    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),
508
509    /// Addition over a pair of expressions (i.e. a + b) rather than a vec-expr like Expression::Sum.
510    /// This is for compatibility with backends that do not support addition over vectors.
511    #[compatible(SMT)]
512    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),
513
514    /// Multiplication over a pair of expressions (i.e. a * b) rather than a vec-expr like Expression::Product.
515    /// This is for compatibility with backends that do not support multiplication over vectors.
516    #[compatible(SMT)]
517    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),
518
519    #[compatible(JsonInput)]
520    Image(Metadata, Moo<Expression>, Moo<Expression>),
521
522    #[compatible(JsonInput)]
523    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),
524
525    #[compatible(JsonInput)]
526    PreImage(Metadata, Moo<Expression>, Moo<Expression>),
527
528    #[compatible(JsonInput)]
529    Inverse(Metadata, Moo<Expression>, Moo<Expression>),
530
531    #[compatible(JsonInput)]
532    Restrict(Metadata, Moo<Expression>, Moo<Expression>),
533
534    /// Lexicographical < between two matrices.
535    ///
536    /// A <lex B iff: A[i] < B[i] for some i /\ (A[j] > B[j] for some j -> i < j)
537    /// I.e. A must be less than B at some index i, and if it is greater than B at another index j,
538    /// then j comes after i.
539    /// I.e. A must be greater than B at the first index where they differ.
540    ///
541    /// E.g. [1, 1] <lex [2, 1] and [1, 1] <lex [1, 2]
542    LexLt(Metadata, Moo<Expression>, Moo<Expression>),
543
544    /// Lexicographical <= between two matrices
545    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),
546
547    /// Lexicographical > between two matrices
548    /// This is a parser-level construct, and is immediately normalised to LexLt(b, a)
549    LexGt(Metadata, Moo<Expression>, Moo<Expression>),
550
551    /// Lexicographical >= between two matrices
552    /// This is a parser-level construct, and is immediately normalised to LexLeq(b, a)
553    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),
554
555    /// Low-level minion constraint. See Expression::LexLt
556    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),
557
558    /// Low-level minion constraint. See Expression::LexLeq
559    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
560}
561
562// for the given matrix literal, return a bounded domain from the min to max of applying op to each
563// child expression.
564//
565// Op must be monotonic.
566//
567// Returns none if unbounded
568fn bounded_i32_domain_for_matrix_literal_monotonic(
569    e: &Expression,
570    op: fn(i32, i32) -> Option<i32>,
571) -> Option<DomainPtr> {
572    // only care about the elements, not the indices
573    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
574
575    // fold each element's domain into one using op.
576    //
577    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
578    // the ranges [a1,a2], [b1,b2] will be
579    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
580    //
581    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
582    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
583    // memory (on the hakank_eprime_xkcd test)...
584    //Int
585    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
586    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
587    //
588    // +,-,/,* are all monotone, so this assumption should be fine for now...
589
590    let expr = exprs.pop()?;
591    let dom = expr.domain_of()?;
592    let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
593        return None;
594    };
595
596    let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
597
598    for expr in exprs {
599        let dom = expr.domain_of()?;
600        let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
601            return None;
602        };
603
604        let (min, max) = range_vec_bounds_i32(ranges)?;
605
606        // all the possible new values for current_min / current_max
607        let minmax = op(min, current_max)?;
608        let minmin = op(min, current_min)?;
609        let maxmin = op(max, current_min)?;
610        let maxmax = op(max, current_max)?;
611        let vals = [minmax, minmin, maxmin, maxmax];
612
613        current_min = *vals
614            .iter()
615            .min()
616            .expect("vals iterator should not be empty, and should have a minimum.");
617        current_max = *vals
618            .iter()
619            .max()
620            .expect("vals iterator should not be empty, and should have a maximum.");
621    }
622
623    if current_min == current_max {
624        Some(Domain::int(vec![Range::Single(current_min)]))
625    } else {
626        Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
627    }
628}
629
630// Returns none if unbounded
631fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
632    let mut min = i32::MAX;
633    let mut max = i32::MIN;
634    for r in ranges {
635        match r {
636            Range::Single(i) => {
637                if *i < min {
638                    min = *i;
639                }
640                if *i > max {
641                    max = *i;
642                }
643            }
644            Range::Bounded(i, j) => {
645                if *i < min {
646                    min = *i;
647                }
648                if *j > max {
649                    max = *j;
650                }
651            }
652            Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
653        }
654    }
655    Some((min, max))
656}
657
658impl Expression {
659    /// Returns the possible values of the expression, recursing to leaf expressions
660    pub fn domain_of(&self) -> Option<DomainPtr> {
661        match self {
662            Expression::Union(_, a, b) => Some(Domain::set(
663                SetAttr::<IntVal>::default(),
664                a.domain_of()?.union(&b.domain_of()?).ok()?,
665            )),
666            Expression::Intersect(_, a, b) => Some(Domain::set(
667                SetAttr::<IntVal>::default(),
668                a.domain_of()?.intersect(&b.domain_of()?).ok()?,
669            )),
670            Expression::In(_, _, _) => Some(Domain::bool()),
671            Expression::Supset(_, _, _) => Some(Domain::bool()),
672            Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
673            Expression::Subset(_, _, _) => Some(Domain::bool()),
674            Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
675            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
676            Expression::DominanceRelation(_, _) => Some(Domain::bool()),
677            Expression::FromSolution(_, expr) => Some(expr.domain_of()),
678            Expression::Metavar(_, _) => None,
679            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
680            Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
681            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
682                let dom = matrix.domain_of()?;
683                if let Some((elem_domain, _)) = dom.as_matrix() {
684                    return Some(elem_domain);
685                }
686
687                // may actually use the value in the future
688                #[allow(clippy::redundant_pattern_matching)]
689                if let Some(_) = dom.as_tuple() {
690                    // TODO: We can implement proper indexing for tuples
691                    return None;
692                }
693
694                // may actually use the value in the future
695                #[allow(clippy::redundant_pattern_matching)]
696                if let Some(_) = dom.as_record() {
697                    // TODO: We can implement proper indexing for records
698                    return None;
699                }
700
701                bug!("subject of an index operation should support indexing")
702            }
703            Expression::UnsafeSlice(_, matrix, indices)
704            | Expression::SafeSlice(_, matrix, indices) => {
705                let sliced_dimension = indices.iter().position(Option::is_none);
706
707                let dom = matrix.domain_of()?;
708                let Some((elem_domain, index_domains)) = dom.as_matrix() else {
709                    bug!("subject of an index operation should be a matrix");
710                };
711
712                match sliced_dimension {
713                    Some(dimension) => Some(Domain::matrix(
714                        elem_domain,
715                        vec![index_domains[dimension].clone()],
716                    )),
717
718                    // same as index
719                    None => Some(elem_domain),
720                }
721            }
722            Expression::InDomain(_, _, _) => Some(Domain::bool()),
723            Expression::Atomic(_, atom) => Some(atom.domain_of()),
724            Expression::Scope(_, _) => Some(Domain::bool()),
725            Expression::Sum(_, e) => {
726                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
727            }
728            Expression::Product(_, e) => {
729                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
730            }
731            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
732                Some(if x < y { x } else { y })
733            }),
734            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
735                Some(if x > y { x } else { y })
736            }),
737            Expression::UnsafeDiv(_, a, b) => a
738                .domain_of()?
739                .resolve()?
740                .apply_i32(
741                    // rust integer division is truncating; however, we want to always round down,
742                    // including for negative numbers.
743                    |x, y| {
744                        if y != 0 {
745                            Some((x as f32 / y as f32).floor() as i32)
746                        } else {
747                            None
748                        }
749                    },
750                    b.domain_of()?.resolve()?.as_ref(),
751                )
752                .map(DomainPtr::from)
753                .ok(),
754            Expression::SafeDiv(_, a, b) => {
755                // rust integer division is truncating; however, we want to always round down
756                // including for negative numbers.
757                let domain = a
758                    .domain_of()?
759                    .resolve()?
760                    .apply_i32(
761                        |x, y| {
762                            if y != 0 {
763                                Some((x as f32 / y as f32).floor() as i32)
764                            } else {
765                                None
766                            }
767                        },
768                        b.domain_of()?.resolve()?.as_ref(),
769                    )
770                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
771
772                if let GroundDomain::Int(ranges) = domain {
773                    let mut ranges = ranges;
774                    ranges.push(Range::Single(0));
775                    Some(Domain::int(ranges))
776                } else {
777                    bug!("Domain of {self} was not integer")
778                }
779            }
780            Expression::UnsafeMod(_, a, b) => a
781                .domain_of()?
782                .resolve()?
783                .apply_i32(
784                    |x, y| if y != 0 { Some(x % y) } else { None },
785                    b.domain_of()?.resolve()?.as_ref(),
786                )
787                .map(DomainPtr::from)
788                .ok(),
789            Expression::SafeMod(_, a, b) => {
790                let domain = a
791                    .domain_of()?
792                    .resolve()?
793                    .apply_i32(
794                        |x, y| if y != 0 { Some(x % y) } else { None },
795                        b.domain_of()?.resolve()?.as_ref(),
796                    )
797                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
798
799                if let GroundDomain::Int(ranges) = domain {
800                    let mut ranges = ranges;
801                    ranges.push(Range::Single(0));
802                    Some(Domain::int(ranges))
803                } else {
804                    bug!("Domain of {self} was not integer")
805                }
806            }
807            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
808                .domain_of()?
809                .resolve()?
810                .apply_i32(
811                    |x, y| {
812                        if (x != 0 || y != 0) && y >= 0 {
813                            Some(x.pow(y as u32))
814                        } else {
815                            None
816                        }
817                    },
818                    b.domain_of()?.resolve()?.as_ref(),
819                )
820                .map(DomainPtr::from)
821                .ok(),
822            Expression::Root(_, _) => None,
823            Expression::Bubble(_, inner, _) => inner.domain_of(),
824            Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
825            Expression::And(_, _) => Some(Domain::bool()),
826            Expression::Not(_, _) => Some(Domain::bool()),
827            Expression::Or(_, _) => Some(Domain::bool()),
828            Expression::Imply(_, _, _) => Some(Domain::bool()),
829            Expression::Iff(_, _, _) => Some(Domain::bool()),
830            Expression::Eq(_, _, _) => Some(Domain::bool()),
831            Expression::Neq(_, _, _) => Some(Domain::bool()),
832            Expression::Geq(_, _, _) => Some(Domain::bool()),
833            Expression::Leq(_, _, _) => Some(Domain::bool()),
834            Expression::Gt(_, _, _) => Some(Domain::bool()),
835            Expression::Lt(_, _, _) => Some(Domain::bool()),
836            Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
837            Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
838            Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
839            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
840            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
841            Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
842            Expression::Flatten(_, n, m) => {
843                if let Some(expr) = n {
844                    if expr.return_type() == ReturnType::Int {
845                        // TODO: handle flatten with depth argument
846                        return None;
847                    }
848                } else {
849                    // TODO: currently only works for matrices
850                    let dom = m.domain_of()?.resolve()?;
851                    let (val_dom, idx_doms) = match dom.as_ref() {
852                        GroundDomain::Matrix(val, idx) => (val, idx),
853                        _ => return None,
854                    };
855                    let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
856
857                    let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
858                    return Some(Domain::matrix(
859                        val_dom.clone().into(),
860                        vec![new_index_domain],
861                    ));
862                }
863                None
864            }
865            Expression::AllDiff(_, _) => Some(Domain::bool()),
866            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
867            Expression::MinionReify(_, _, _) => Some(Domain::bool()),
868            Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
869            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
870            Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
871            Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
872            Expression::Neg(_, x) => {
873                let dom = x.domain_of()?;
874                let mut ranges = dom.as_int()?;
875
876                ranges = ranges
877                    .into_iter()
878                    .map(|r| match r {
879                        Range::Single(x) => Range::Single(-x),
880                        Range::Bounded(x, y) => Range::Bounded(-y, -x),
881                        Range::UnboundedR(i) => Range::UnboundedL(-i),
882                        Range::UnboundedL(i) => Range::UnboundedR(-i),
883                        Range::Unbounded => Range::Unbounded,
884                    })
885                    .collect();
886
887                Some(Domain::int(ranges))
888            }
889            Expression::Minus(_, a, b) => a
890                .domain_of()?
891                .resolve()?
892                .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
893                .map(DomainPtr::from)
894                .ok(),
895            Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
896            Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
897            Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
898            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
899            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
900            Expression::Abs(_, a) => a
901                .domain_of()?
902                .resolve()?
903                .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
904                .map(DomainPtr::from)
905                .ok(),
906            Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
907            Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
908            Expression::SATInt(_, _, _, (low, high)) => {
909                Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
910            }
911            Expression::PairwiseSum(_, a, b) => a
912                .domain_of()?
913                .resolve()?
914                .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
915                .map(DomainPtr::from)
916                .ok(),
917            Expression::PairwiseProduct(_, a, b) => a
918                .domain_of()?
919                .resolve()?
920                .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
921                .map(DomainPtr::from)
922                .ok(),
923            Expression::Defined(_, function) => get_function_domain(function),
924            Expression::Range(_, function) => get_function_codomain(function),
925            Expression::Image(_, function, _) => get_function_codomain(function),
926            Expression::ImageSet(_, function, _) => get_function_codomain(function),
927            Expression::PreImage(_, function, _) => get_function_domain(function),
928            Expression::Restrict(_, function, new_domain) => {
929                let (attrs, _, codom) = function.domain_of()?.as_function()?;
930                let new_dom = new_domain.domain_of()?;
931                Some(Domain::function(attrs, new_dom, codom))
932            }
933            Expression::Inverse(..) => Some(Domain::bool()),
934            Expression::LexLt(..) => Some(Domain::bool()),
935            Expression::LexLeq(..) => Some(Domain::bool()),
936            Expression::LexGt(..) => Some(Domain::bool()),
937            Expression::LexGeq(..) => Some(Domain::bool()),
938            Expression::FlatLexLt(..) => Some(Domain::bool()),
939            Expression::FlatLexLeq(..) => Some(Domain::bool()),
940        }
941    }
942
943    pub fn get_meta(&self) -> Metadata {
944        let metas: VecDeque<Metadata> = self.children_bi();
945        metas[0].clone()
946    }
947
948    pub fn set_meta(&self, meta: Metadata) {
949        self.transform_bi(&|_| meta.clone());
950    }
951
952    /// Checks whether this expression is safe.
953    ///
954    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
955    ///
956    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
957    /// safe through the use of bubble rules.
958    pub fn is_safe(&self) -> bool {
959        // TODO: memoise in Metadata
960        for expr in self.universe() {
961            match expr {
962                Expression::UnsafeDiv(_, _, _)
963                | Expression::UnsafeMod(_, _, _)
964                | Expression::UnsafePow(_, _, _)
965                | Expression::UnsafeIndex(_, _, _)
966                | Expression::Bubble(_, _, _)
967                | Expression::UnsafeSlice(_, _, _) => {
968                    return false;
969                }
970                _ => {}
971            }
972        }
973        true
974    }
975
976    pub fn is_clean(&self) -> bool {
977        let metadata = self.get_meta();
978        metadata.clean
979    }
980
981    pub fn set_clean(&mut self, bool_value: bool) {
982        let mut metadata = self.get_meta();
983        metadata.clean = bool_value;
984        self.set_meta(metadata);
985    }
986
987    /// True if the expression is an associative and commutative operator
988    pub fn is_associative_commutative_operator(&self) -> bool {
989        TryInto::<ACOperatorKind>::try_into(self).is_ok()
990    }
991
992    /// True if the expression is a matrix literal.
993    ///
994    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
995    /// [`Expression`].
996    pub fn is_matrix_literal(&self) -> bool {
997        matches!(
998            self,
999            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1000                | Expression::Atomic(
1001                    _,
1002                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1003                )
1004        )
1005    }
1006
1007    /// True iff self and other are both atomic and identical.
1008    ///
1009    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
1010    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
1011    /// entire subtrees of `self` and `other`.
1012    pub fn identical_atom_to(&self, other: &Expression) -> bool {
1013        let atom1: Result<&Atom, _> = self.try_into();
1014        let atom2: Result<&Atom, _> = other.try_into();
1015
1016        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1017            atom2 == atom1
1018        } else {
1019            false
1020        }
1021    }
1022
1023    /// If the expression is a list, returns a *copied* vector of the inner expressions.
1024    ///
1025    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
1026    /// any explicitly specified domain.
1027    pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
1028        match self {
1029            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1030                matrix.unwrap_list().cloned()
1031            }
1032            Expression::Atomic(
1033                _,
1034                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1035            ) => matrix.unwrap_list().map(|elems| {
1036                elems
1037                    .clone()
1038                    .into_iter()
1039                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1040                    .collect_vec()
1041            }),
1042            _ => None,
1043        }
1044    }
1045
1046    /// If the expression is a matrix, gets it elements and index domain.
1047    ///
1048    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
1049    ///
1050    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
1051    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
1052    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
1053    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1054        match self {
1055            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1056                Some((elems, domain))
1057            }
1058            Expression::Atomic(
1059                _,
1060                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1061            ) => Some((
1062                elems
1063                    .into_iter()
1064                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1065                    .collect_vec(),
1066                domain.into(),
1067            )),
1068
1069            _ => None,
1070        }
1071    }
1072
1073    /// For a Root expression, extends the inner vec with the given vec.
1074    ///
1075    /// # Panics
1076    /// Panics if the expression is not Root.
1077    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1078        match self {
1079            Expression::Root(meta, mut children) => {
1080                children.extend(exprs);
1081                Expression::Root(meta, children)
1082            }
1083            _ => panic!("extend_root called on a non-Root expression"),
1084        }
1085    }
1086
1087    /// Converts the expression to a literal, if possible.
1088    pub fn into_literal(self) -> Option<Literal> {
1089        match self {
1090            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1091            Expression::AbstractLiteral(_, abslit) => {
1092                Some(Literal::AbstractLiteral(abslit.into_literals()?))
1093            }
1094            Expression::Neg(_, e) => {
1095                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1096                    bug!("negated literal should be an int");
1097                };
1098
1099                Some(Literal::Int(-i))
1100            }
1101
1102            _ => None,
1103        }
1104    }
1105
1106    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
1107    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1108        TryFrom::try_from(self).ok()
1109    }
1110
1111    /// Returns the categories of all sub-expressions of self.
1112    pub fn universe_categories(&self) -> HashSet<Category> {
1113        self.universe()
1114            .into_iter()
1115            .map(|x| x.category_of())
1116            .collect()
1117    }
1118}
1119
1120pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1121    let function_domain = function.domain_of()?;
1122    match function_domain.resolve().as_ref() {
1123        Some(d) => {
1124            match d.as_ref() {
1125                GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1126                // Not defined for anything other than a function
1127                _ => None,
1128            }
1129        }
1130        None => {
1131            match function_domain.as_unresolved()? {
1132                UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1133                // Not defined for anything other than a function
1134                _ => None,
1135            }
1136        }
1137    }
1138}
1139
1140pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1141    let function_domain = function.domain_of()?;
1142    match function_domain.resolve().as_ref() {
1143        Some(d) => {
1144            match d.as_ref() {
1145                GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1146                // Not defined for anything other than a function
1147                _ => None,
1148            }
1149        }
1150        None => {
1151            match function_domain.as_unresolved()? {
1152                UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1153                // Not defined for anything other than a function
1154                _ => None,
1155            }
1156        }
1157    }
1158}
1159
1160impl TryFrom<&Expression> for i32 {
1161    type Error = ();
1162
1163    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1164        let Expression::Atomic(_, atom) = value else {
1165            return Err(());
1166        };
1167
1168        let Atom::Literal(lit) = atom else {
1169            return Err(());
1170        };
1171
1172        let Literal::Int(i) = lit else {
1173            return Err(());
1174        };
1175
1176        Ok(*i)
1177    }
1178}
1179
1180impl TryFrom<Expression> for i32 {
1181    type Error = ();
1182
1183    fn try_from(value: Expression) -> Result<Self, Self::Error> {
1184        TryFrom::<&Expression>::try_from(&value)
1185    }
1186}
1187impl From<i32> for Expression {
1188    fn from(i: i32) -> Self {
1189        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1190    }
1191}
1192
1193impl From<bool> for Expression {
1194    fn from(b: bool) -> Self {
1195        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1196    }
1197}
1198
1199impl From<Atom> for Expression {
1200    fn from(value: Atom) -> Self {
1201        Expression::Atomic(Metadata::new(), value)
1202    }
1203}
1204
1205impl From<Literal> for Expression {
1206    fn from(value: Literal) -> Self {
1207        Expression::Atomic(Metadata::new(), value.into())
1208    }
1209}
1210
1211impl From<Moo<Expression>> for Expression {
1212    fn from(val: Moo<Expression>) -> Self {
1213        val.as_ref().clone()
1214    }
1215}
1216
1217impl CategoryOf for Expression {
1218    fn category_of(&self) -> Category {
1219        // take highest category of all the expressions children
1220        let category = self.cata(&move |x,children| {
1221
1222            if let Some(max_category) = children.iter().max() {
1223                // if this expression contains subexpressions, return the maximum category of the
1224                // subexpressions
1225                *max_category
1226            } else {
1227                // this expression has no children
1228                let mut max_category = Category::Bottom;
1229
1230                // calculate the category by looking at all atoms, submodels, comprehensions, and
1231                // declarationptrs inside this expression
1232
1233                // this should generically cover all leaf types we currently have in oxide.
1234
1235                // if x contains submodels (including comprehensions)
1236                if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
1237                    // assume that the category is decision
1238                    return Category::Decision;
1239                }
1240
1241                // if x contains atoms
1242                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1243                // and those atoms have a higher category than we already know about
1244                && max_atom_category > max_category{
1245                    // update category 
1246                    max_category = max_atom_category;
1247                }
1248
1249                // if x contains declarationPtrs
1250                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1251                // and those pointers have a higher category than we already know about
1252                && max_declaration_category > max_category{
1253                    // update category 
1254                    max_category = max_declaration_category;
1255                }
1256                max_category
1257
1258            }
1259        });
1260
1261        if cfg!(debug_assertions) {
1262            trace!(
1263                category= %category,
1264                expression= %self,
1265                "Called Expression::category_of()"
1266            );
1267        };
1268        category
1269    }
1270}
1271
1272impl Display for Expression {
1273    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1274        match &self {
1275            Expression::Union(_, box1, box2) => {
1276                write!(f, "({} union {})", box1.clone(), box2.clone())
1277            }
1278            Expression::In(_, e1, e2) => {
1279                write!(f, "{e1} in {e2}")
1280            }
1281            Expression::Intersect(_, box1, box2) => {
1282                write!(f, "({} intersect {})", box1.clone(), box2.clone())
1283            }
1284            Expression::Supset(_, box1, box2) => {
1285                write!(f, "({} supset {})", box1.clone(), box2.clone())
1286            }
1287            Expression::SupsetEq(_, box1, box2) => {
1288                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1289            }
1290            Expression::Subset(_, box1, box2) => {
1291                write!(f, "({} subset {})", box1.clone(), box2.clone())
1292            }
1293            Expression::SubsetEq(_, box1, box2) => {
1294                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1295            }
1296
1297            Expression::AbstractLiteral(_, l) => l.fmt(f),
1298            Expression::Comprehension(_, c) => c.fmt(f),
1299            Expression::AbstractComprehension(_, c) => c.fmt(f),
1300            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1301                write!(f, "{e1}{}", pretty_vec(e2))
1302            }
1303            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1304                let args = es
1305                    .iter()
1306                    .map(|x| match x {
1307                        Some(x) => format!("{x}"),
1308                        None => "..".into(),
1309                    })
1310                    .join(",");
1311
1312                write!(f, "{e1}[{args}]")
1313            }
1314            Expression::InDomain(_, e, domain) => {
1315                write!(f, "__inDomain({e},{domain})")
1316            }
1317            Expression::Root(_, exprs) => {
1318                write!(f, "{}", pretty_expressions_as_top_level(exprs))
1319            }
1320            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1321            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1322            Expression::Metavar(_, name) => write!(f, "&{name}"),
1323            Expression::Atomic(_, atom) => atom.fmt(f),
1324            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1325            Expression::Abs(_, a) => write!(f, "|{a}|"),
1326            Expression::Sum(_, e) => {
1327                write!(f, "sum({e})")
1328            }
1329            Expression::Product(_, e) => {
1330                write!(f, "product({e})")
1331            }
1332            Expression::Min(_, e) => {
1333                write!(f, "min({e})")
1334            }
1335            Expression::Max(_, e) => {
1336                write!(f, "max({e})")
1337            }
1338            Expression::Not(_, expr_box) => {
1339                write!(f, "!({})", expr_box.clone())
1340            }
1341            Expression::Or(_, e) => {
1342                write!(f, "or({e})")
1343            }
1344            Expression::And(_, e) => {
1345                write!(f, "and({e})")
1346            }
1347            Expression::Imply(_, box1, box2) => {
1348                write!(f, "({box1}) -> ({box2})")
1349            }
1350            Expression::Iff(_, box1, box2) => {
1351                write!(f, "({box1}) <-> ({box2})")
1352            }
1353            Expression::Eq(_, box1, box2) => {
1354                write!(f, "({} = {})", box1.clone(), box2.clone())
1355            }
1356            Expression::Neq(_, box1, box2) => {
1357                write!(f, "({} != {})", box1.clone(), box2.clone())
1358            }
1359            Expression::Geq(_, box1, box2) => {
1360                write!(f, "({} >= {})", box1.clone(), box2.clone())
1361            }
1362            Expression::Leq(_, box1, box2) => {
1363                write!(f, "({} <= {})", box1.clone(), box2.clone())
1364            }
1365            Expression::Gt(_, box1, box2) => {
1366                write!(f, "({} > {})", box1.clone(), box2.clone())
1367            }
1368            Expression::Lt(_, box1, box2) => {
1369                write!(f, "({} < {})", box1.clone(), box2.clone())
1370            }
1371            Expression::FlatSumGeq(_, box1, box2) => {
1372                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1373            }
1374            Expression::FlatSumLeq(_, box1, box2) => {
1375                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1376            }
1377            Expression::FlatIneq(_, box1, box2, box3) => write!(
1378                f,
1379                "Ineq({}, {}, {})",
1380                box1.clone(),
1381                box2.clone(),
1382                box3.clone()
1383            ),
1384            Expression::Flatten(_, n, m) => {
1385                if let Some(n) = n {
1386                    write!(f, "flatten({n}, {m})")
1387                } else {
1388                    write!(f, "flatten({m})")
1389                }
1390            }
1391            Expression::AllDiff(_, e) => {
1392                write!(f, "allDiff({e})")
1393            }
1394            Expression::Bubble(_, box1, box2) => {
1395                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1396            }
1397            Expression::SafeDiv(_, box1, box2) => {
1398                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1399            }
1400            Expression::UnsafeDiv(_, box1, box2) => {
1401                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1402            }
1403            Expression::UnsafePow(_, box1, box2) => {
1404                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1405            }
1406            Expression::SafePow(_, box1, box2) => {
1407                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1408            }
1409            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1410                write!(
1411                    f,
1412                    "DivEq({}, {}, {})",
1413                    box1.clone(),
1414                    box2.clone(),
1415                    box3.clone()
1416                )
1417            }
1418            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1419                write!(
1420                    f,
1421                    "ModEq({}, {}, {})",
1422                    box1.clone(),
1423                    box2.clone(),
1424                    box3.clone()
1425                )
1426            }
1427            Expression::FlatWatchedLiteral(_, x, l) => {
1428                write!(f, "WatchedLiteral({x},{l})")
1429            }
1430            Expression::MinionReify(_, box1, box2) => {
1431                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1432            }
1433            Expression::MinionReifyImply(_, box1, box2) => {
1434                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1435            }
1436            Expression::MinionWInIntervalSet(_, atom, intervals) => {
1437                let intervals = intervals.iter().join(",");
1438                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1439            }
1440            Expression::MinionWInSet(_, atom, values) => {
1441                let values = values.iter().join(",");
1442                write!(f, "__minion_w_inset({atom},[{values}])")
1443            }
1444            Expression::AuxDeclaration(_, reference, e) => {
1445                write!(f, "{} =aux {}", reference, e.clone())
1446            }
1447            Expression::UnsafeMod(_, a, b) => {
1448                write!(f, "{} % {}", a.clone(), b.clone())
1449            }
1450            Expression::SafeMod(_, a, b) => {
1451                write!(f, "SafeMod({},{})", a.clone(), b.clone())
1452            }
1453            Expression::Neg(_, a) => {
1454                write!(f, "-({})", a.clone())
1455            }
1456            Expression::Minus(_, a, b) => {
1457                write!(f, "({} - {})", a.clone(), b.clone())
1458            }
1459            Expression::FlatAllDiff(_, es) => {
1460                write!(f, "__flat_alldiff({})", pretty_vec(es))
1461            }
1462            Expression::FlatAbsEq(_, a, b) => {
1463                write!(f, "AbsEq({},{})", a.clone(), b.clone())
1464            }
1465            Expression::FlatMinusEq(_, a, b) => {
1466                write!(f, "MinusEq({},{})", a.clone(), b.clone())
1467            }
1468            Expression::FlatProductEq(_, a, b, c) => {
1469                write!(
1470                    f,
1471                    "FlatProductEq({},{},{})",
1472                    a.clone(),
1473                    b.clone(),
1474                    c.clone()
1475                )
1476            }
1477            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1478                write!(
1479                    f,
1480                    "FlatWeightedSumLeq({},{},{})",
1481                    pretty_vec(cs),
1482                    pretty_vec(vs),
1483                    total.clone()
1484                )
1485            }
1486            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1487                write!(
1488                    f,
1489                    "FlatWeightedSumGeq({},{},{})",
1490                    pretty_vec(cs),
1491                    pretty_vec(vs),
1492                    total.clone()
1493                )
1494            }
1495            Expression::MinionPow(_, atom, atom1, atom2) => {
1496                write!(f, "MinionPow({atom},{atom1},{atom2})")
1497            }
1498            Expression::MinionElementOne(_, atoms, atom, atom1) => {
1499                let atoms = atoms.iter().join(",");
1500                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1501            }
1502
1503            Expression::ToInt(_, expr) => {
1504                write!(f, "toInt({expr})")
1505            }
1506
1507            Expression::SATInt(_, encoding, bits, (min, max)) => {
1508                write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
1509            }
1510
1511            Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1512            Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1513
1514            Expression::Defined(_, function) => write!(f, "defined({function})"),
1515            Expression::Range(_, function) => write!(f, "range({function})"),
1516            Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1517            Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1518            Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1519            Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1520            Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1521
1522            Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1523            Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1524            Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1525            Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1526            Expression::FlatLexLt(_, a, b) => {
1527                write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1528            }
1529            Expression::FlatLexLeq(_, a, b) => {
1530                write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1531            }
1532        }
1533    }
1534}
1535
1536impl Typeable for Expression {
1537    fn return_type(&self) -> ReturnType {
1538        match self {
1539            Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1540            Expression::Intersect(_, subject, _) => {
1541                ReturnType::Set(Box::new(subject.return_type()))
1542            }
1543            Expression::In(_, _, _) => ReturnType::Bool,
1544            Expression::Supset(_, _, _) => ReturnType::Bool,
1545            Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1546            Expression::Subset(_, _, _) => ReturnType::Bool,
1547            Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1548            Expression::AbstractLiteral(_, lit) => lit.return_type(),
1549            Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1550                let subject_ty = subject.return_type();
1551                match subject_ty {
1552                    ReturnType::Matrix(_) => {
1553                        // For n-dimensional matrices, unwrap the element type until
1554                        // we either get to the innermost element type or the last index
1555                        let mut elem_typ = subject_ty;
1556                        let mut idx_len = idx.len();
1557                        while idx_len > 0
1558                            && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1559                        {
1560                            elem_typ = *new_elem_typ.clone();
1561                            idx_len -= 1;
1562                        }
1563                        elem_typ
1564                    }
1565                    // TODO: We can implement indexing for these eventually
1566                    ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1567                    _ => bug!(
1568                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1569                    ),
1570                }
1571            }
1572            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1573                ReturnType::Matrix(Box::new(subject.return_type()))
1574            }
1575            Expression::InDomain(_, _, _) => ReturnType::Bool,
1576            Expression::Comprehension(_, comp) => comp.return_type(),
1577            Expression::AbstractComprehension(_, comp) => comp.return_type(),
1578            Expression::Root(_, _) => ReturnType::Bool,
1579            Expression::DominanceRelation(_, _) => ReturnType::Bool,
1580            Expression::FromSolution(_, expr) => expr.return_type(),
1581            Expression::Metavar(_, _) => ReturnType::Unknown,
1582            Expression::Atomic(_, atom) => atom.return_type(),
1583            Expression::Scope(_, scope) => scope.return_type(),
1584            Expression::Abs(_, _) => ReturnType::Int,
1585            Expression::Sum(_, _) => ReturnType::Int,
1586            Expression::Product(_, _) => ReturnType::Int,
1587            Expression::Min(_, _) => ReturnType::Int,
1588            Expression::Max(_, _) => ReturnType::Int,
1589            Expression::Not(_, _) => ReturnType::Bool,
1590            Expression::Or(_, _) => ReturnType::Bool,
1591            Expression::Imply(_, _, _) => ReturnType::Bool,
1592            Expression::Iff(_, _, _) => ReturnType::Bool,
1593            Expression::And(_, _) => ReturnType::Bool,
1594            Expression::Eq(_, _, _) => ReturnType::Bool,
1595            Expression::Neq(_, _, _) => ReturnType::Bool,
1596            Expression::Geq(_, _, _) => ReturnType::Bool,
1597            Expression::Leq(_, _, _) => ReturnType::Bool,
1598            Expression::Gt(_, _, _) => ReturnType::Bool,
1599            Expression::Lt(_, _, _) => ReturnType::Bool,
1600            Expression::SafeDiv(_, _, _) => ReturnType::Int,
1601            Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1602            Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1603            Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1604            Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1605            Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1606            Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1607            Expression::Flatten(_, _, matrix) => {
1608                let matrix_type = matrix.return_type();
1609                match matrix_type {
1610                    ReturnType::Matrix(_) => {
1611                        // unwrap until we get to innermost element
1612                        let mut elem_type = matrix_type;
1613                        while let ReturnType::Matrix(new_elem_type) = &elem_type {
1614                            elem_type = *new_elem_type.clone();
1615                        }
1616                        ReturnType::Matrix(Box::new(elem_type))
1617                    }
1618                    _ => bug!(
1619                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1620                    ),
1621                }
1622            }
1623            Expression::AllDiff(_, _) => ReturnType::Bool,
1624            Expression::Bubble(_, inner, _) => inner.return_type(),
1625            Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1626            Expression::MinionReify(_, _, _) => ReturnType::Bool,
1627            Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1628            Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1629            Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1630            Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1631            Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1632            Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1633            Expression::SafeMod(_, _, _) => ReturnType::Int,
1634            Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1635            Expression::Neg(_, _) => ReturnType::Int,
1636            Expression::UnsafePow(_, _, _) => ReturnType::Int,
1637            Expression::SafePow(_, _, _) => ReturnType::Int,
1638            Expression::Minus(_, _, _) => ReturnType::Int,
1639            Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1640            Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1641            Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1642            Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1643            Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1644            Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1645            Expression::ToInt(_, _) => ReturnType::Int,
1646            Expression::SATInt(..) => ReturnType::Int,
1647            Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1648            Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1649            Expression::Defined(_, function) => {
1650                let subject = function.return_type();
1651                match subject {
1652                    ReturnType::Function(domain, _) => *domain,
1653                    _ => bug!(
1654                        "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1655                    ),
1656                }
1657            }
1658            Expression::Range(_, function) => {
1659                let subject = function.return_type();
1660                match subject {
1661                    ReturnType::Function(_, codomain) => *codomain,
1662                    _ => bug!(
1663                        "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1664                    ),
1665                }
1666            }
1667            Expression::Image(_, function, _) => {
1668                let subject = function.return_type();
1669                match subject {
1670                    ReturnType::Function(_, codomain) => *codomain,
1671                    _ => bug!(
1672                        "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1673                    ),
1674                }
1675            }
1676            Expression::ImageSet(_, function, _) => {
1677                let subject = function.return_type();
1678                match subject {
1679                    ReturnType::Function(_, codomain) => *codomain,
1680                    _ => bug!(
1681                        "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1682                    ),
1683                }
1684            }
1685            Expression::PreImage(_, function, _) => {
1686                let subject = function.return_type();
1687                match subject {
1688                    ReturnType::Function(domain, _) => *domain,
1689                    _ => bug!(
1690                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1691                    ),
1692                }
1693            }
1694            Expression::Restrict(_, function, new_domain) => {
1695                let subject = function.return_type();
1696                match subject {
1697                    ReturnType::Function(_, codomain) => {
1698                        ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1699                    }
1700                    _ => bug!(
1701                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1702                    ),
1703                }
1704            }
1705            Expression::Inverse(..) => ReturnType::Bool,
1706            Expression::LexLt(..) => ReturnType::Bool,
1707            Expression::LexGt(..) => ReturnType::Bool,
1708            Expression::LexLeq(..) => ReturnType::Bool,
1709            Expression::LexGeq(..) => ReturnType::Bool,
1710            Expression::FlatLexLt(..) => ReturnType::Bool,
1711            Expression::FlatLexLeq(..) => ReturnType::Bool,
1712        }
1713    }
1714}
1715
#[cfg(test)]
mod tests {
    use crate::matrix_expr;

    use super::*;

    /// Builds an atomic expression holding `value`, with fresh metadata.
    fn literal(value: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(value))
    }

    /// `sum([1, 2])` has the singleton integer domain `{3}`.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal(Literal::Int(1));
        let two = literal(Literal::Int(2));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, two]));

        let expected = Domain::int(vec![Range::Single(3)]);
        assert_eq!(total.domain_of(), Some(expected));
    }

    /// Summing an integer with a boolean has no well-defined domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let one = literal(Literal::Int(1));
        let truth = literal(Literal::Bool(true));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, truth]));

        assert_eq!(total.domain_of(), None);
    }

    /// An empty sum has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));

        assert_eq!(total.domain_of(), None);
    }
}