1
use std::collections::{HashSet, VecDeque};
2
use std::fmt::{Display, Formatter};
3
use tracing::trace;
4

            
5
use conjure_cp_enum_compatibility_macro::document_compatibility;
6
use itertools::Itertools;
7
use serde::{Deserialize, Serialize};
8
use ustr::Ustr;
9

            
10
use polyquine::Quine;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use crate::bug;
14

            
15
use super::abstract_comprehension::AbstractComprehension;
16
use super::ac_operators::ACOperatorKind;
17
use super::categories::{Category, CategoryOf};
18
use super::comprehension::Comprehension;
19
use super::domains::HasDomain as _;
20
use super::pretty::{pretty_expressions_as_top_level, pretty_vec};
21
use super::records::RecordValue;
22
use super::sat_encoding::SATIntEncoding;
23
use super::{
24
    AbstractLiteral, Atom, DeclarationPtr, Domain, DomainPtr, GroundDomain, IntVal, Literal,
25
    Metadata, Model, Moo, Name, Range, Reference, ReturnType, SetAttr, SymbolTable, SymbolTablePtr,
26
    Typeable, UnresolvedDomain, matrix,
27
};
28

            
29
// Ensure that this type doesn't get too big
30
//
31
// If you triggered this assertion, you either made a variant of this enum that is too big, or you
32
// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
33
// stuff in boxes.
34
//
35
// Enums take the size of their largest variant, so an enum with mostly small variants and a few
36
// large ones wastes memory... A larger Expression type also slows down Oxide.
37
//
38
// For more information, and more details on type sizes and how to measure them, see the commit
39
// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
40
//
41
// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
42
//
43
// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e
44

            
45
// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
46
// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
47
// lot bigger still when we start using it for memoisation, so it should really be
48
// boxed ~niklasdewally
49

            
50
// Compile-time guard: the build fails unless `Expression` is exactly 112 bytes, the same size as
// `[u8; 112]`. See the notes above for why the size of this type matters.
static_assertions::assert_eq_size!([u8; 112], Expression);
52

            
53
/// Represents different types of expressions used to define rules and constraints in the model.
54
///
55
/// The `Expression` enum includes operations, constants, and variable references
56
/// used to build rules and conditions for the model.
57
#[document_compatibility]
58
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
59
#[biplate(to=AbstractComprehension)]
60
#[biplate(to=AbstractLiteral<Expression>)]
61
#[biplate(to=AbstractLiteral<Literal>)]
62
#[biplate(to=Atom)]
63
#[biplate(to=Comprehension)]
64
#[biplate(to=DeclarationPtr)]
65
#[biplate(to=DomainPtr)]
66
#[biplate(to=Literal)]
67
#[biplate(to=Metadata)]
68
#[biplate(to=Name)]
69
#[biplate(to=Option<Expression>)]
70
#[biplate(to=RecordValue<Expression>)]
71
#[biplate(to=RecordValue<Literal>)]
72
#[biplate(to=Reference)]
73
#[biplate(to=Model)]
74
#[biplate(to=SymbolTable)]
75
#[biplate(to=SymbolTablePtr)]
76
#[biplate(to=Vec<Expression>)]
77
#[path_prefix(conjure_cp::ast)]
78
pub enum Expression {
79
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
80
    /// The top of the model
81
    Root(Metadata, Vec<Expression>),
82

            
83
    /// An expression representing "A is valid as long as B is true"
84
    /// Turns into a conjunction when it reaches a boolean context
85
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),
86

            
87
    /// A comprehension.
88
    ///
89
    /// The inside of the comprehension opens a new scope.
90
    // todo (gskorokhod): Comprehension contains a symbol table which contains a bunch of pointers.
91
    // This makes implementing Quine tricky (it doesnt support Rc, by design). Skip it for now.
92
    #[polyquine_skip]
93
    Comprehension(Metadata, Moo<Comprehension>),
94

            
95
    /// Higher-level abstract comprehension
96
    #[polyquine_skip] // no idea what this is lol but it stops rustc screaming at me
97
    AbstractComprehension(Metadata, Moo<AbstractComprehension>),
98

            
99
    /// Defines dominance ("Solution A is preferred over Solution B")
100
    DominanceRelation(Metadata, Moo<Expression>),
101
    /// `fromSolution(name)` - Used in dominance relation definitions
102
    FromSolution(Metadata, Moo<Atom>),
103

            
104
    #[polyquine_with(arm = (_, name) => {
105
        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
106
        quote::quote! { #ident.clone().into() }
107
    })]
108
    Metavar(Metadata, Ustr),
109

            
110
    Atomic(Metadata, Atom),
111

            
112
    /// A matrix index.
113
    ///
114
    /// Defined iff the indices are within their respective index domains.
115
    #[compatible(JsonInput)]
116
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
117

            
118
    /// A safe matrix index.
119
    ///
120
    /// See [`Expression::UnsafeIndex`]
121
    #[compatible(SMT)]
122
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
123

            
124
    /// A matrix slice: `a[indices]`.
125
    ///
126
    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
127
    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
128
    /// `Some(1),None,Some(2)`.
129
    ///
130
    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
131
    ///
132
    /// Defined iff the defined indices are within their respective index domains.
133
    #[compatible(JsonInput)]
134
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
135

            
136
    /// A safe matrix slice: `a[indices]`.
137
    ///
138
    /// See [`Expression::UnsafeSlice`].
139
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
140

            
141
    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
142
    ///
143
    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
144
    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
145
    /// `SafeSlice` respectively.
146
    InDomain(Metadata, Moo<Expression>, DomainPtr),
147

            
148
    /// `toInt(b)` casts boolean expression b to an integer.
149
    ///
150
    /// - If b is false, then `toInt(b) == 0`
151
    ///
152
    /// - If b is true, then `toInt(b) == 1`
153
    #[compatible(SMT)]
154
    ToInt(Metadata, Moo<Expression>),
155

            
156
    /// `|x|` - absolute value of `x`
157
    #[compatible(JsonInput, SMT)]
158
    Abs(Metadata, Moo<Expression>),
159

            
160
    /// `sum(<vec_expr>)`
161
    #[compatible(JsonInput, SMT)]
162
    Sum(Metadata, Moo<Expression>),
163

            
164
    /// `a * b * c * ...`
165
    #[compatible(JsonInput, SMT)]
166
    Product(Metadata, Moo<Expression>),
167

            
168
    /// `min(<vec_expr>)`
169
    #[compatible(JsonInput, SMT)]
170
    Min(Metadata, Moo<Expression>),
171

            
172
    /// `max(<vec_expr>)`
173
    #[compatible(JsonInput, SMT)]
174
    Max(Metadata, Moo<Expression>),
175

            
176
    /// `not(a)`
177
    #[compatible(JsonInput, SAT, SMT)]
178
    Not(Metadata, Moo<Expression>),
179

            
180
    /// `or(<vec_expr>)`
181
    #[compatible(JsonInput, SAT, SMT)]
182
    Or(Metadata, Moo<Expression>),
183

            
184
    /// `and(<vec_expr>)`
185
    #[compatible(JsonInput, SAT, SMT)]
186
    And(Metadata, Moo<Expression>),
187

            
188
    /// Ensures that `a->b` (material implication).
189
    #[compatible(JsonInput, SMT)]
190
    Imply(Metadata, Moo<Expression>, Moo<Expression>),
191

            
192
    /// `iff(a, b)` a <-> b
193
    #[compatible(JsonInput, SMT)]
194
    Iff(Metadata, Moo<Expression>, Moo<Expression>),
195

            
196
    #[compatible(JsonInput)]
197
    Union(Metadata, Moo<Expression>, Moo<Expression>),
198

            
199
    #[compatible(JsonInput)]
200
    In(Metadata, Moo<Expression>, Moo<Expression>),
201

            
202
    #[compatible(JsonInput)]
203
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),
204

            
205
    #[compatible(JsonInput)]
206
    Supset(Metadata, Moo<Expression>, Moo<Expression>),
207

            
208
    #[compatible(JsonInput)]
209
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),
210

            
211
    #[compatible(JsonInput)]
212
    Subset(Metadata, Moo<Expression>, Moo<Expression>),
213

            
214
    #[compatible(JsonInput)]
215
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),
216

            
217
    #[compatible(JsonInput, SMT)]
218
    Eq(Metadata, Moo<Expression>, Moo<Expression>),
219

            
220
    #[compatible(JsonInput, SMT)]
221
    Neq(Metadata, Moo<Expression>, Moo<Expression>),
222

            
223
    #[compatible(JsonInput, SMT)]
224
    Geq(Metadata, Moo<Expression>, Moo<Expression>),
225

            
226
    #[compatible(JsonInput, SMT)]
227
    Leq(Metadata, Moo<Expression>, Moo<Expression>),
228

            
229
    #[compatible(JsonInput, SMT)]
230
    Gt(Metadata, Moo<Expression>, Moo<Expression>),
231

            
232
    #[compatible(JsonInput, SMT)]
233
    Lt(Metadata, Moo<Expression>, Moo<Expression>),
234

            
235
    /// Division after preventing division by zero, usually with a bubble
236
    #[compatible(SMT)]
237
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
238

            
239
    /// Division with a possibly undefined value (division by 0)
240
    #[compatible(JsonInput)]
241
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
242

            
243
    /// Modulo after preventing mod 0, usually with a bubble
244
    #[compatible(SMT)]
245
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),
246

            
247
    /// Modulo with a possibly undefined value (mod 0)
248
    #[compatible(JsonInput)]
249
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),
250

            
251
    /// Negation: `-x`
252
    #[compatible(JsonInput, SMT)]
253
    Neg(Metadata, Moo<Expression>),
254

            
255
    /// Set of domain values function is defined for
256
    #[compatible(JsonInput)]
257
    Defined(Metadata, Moo<Expression>),
258

            
259
    /// Set of codomain values function is defined for
260
    #[compatible(JsonInput)]
261
    Range(Metadata, Moo<Expression>),
262

            
263
    /// Unsafe power`x**y` (possibly undefined)
264
    ///
265
    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
266
    #[compatible(JsonInput)]
267
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),
268

            
269
    /// `UnsafePow` after preventing undefinedness
270
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),
271

            
272
    /// Flatten matrix operator
273
    /// `flatten(M)` or `flatten(n, M)`
274
    /// where M is a matrix and n is an optional integer argument indicating depth of flattening
275
    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),
276

            
277
    /// `allDiff(<vec_expr>)`
278
    #[compatible(JsonInput)]
279
    AllDiff(Metadata, Moo<Expression>),
280

            
281
    /// `table([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
282
    ///
283
    /// Represents a positive table constraint: the tuple `[x1, x2, ...]` must match one of the
284
    /// allowed rows.
285
    #[compatible(JsonInput)]
286
    Table(Metadata, Moo<Expression>, Moo<Expression>),
287

            
288
    /// `negativeTable([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
289
    ///
290
    /// Represents a negative table constraint: the tuple `[x1, x2, ...]` must NOT match any of the
291
    /// forbidden rows.
292
    #[compatible(JsonInput)]
293
    NegativeTable(Metadata, Moo<Expression>, Moo<Expression>),
294
    /// Binary subtraction operator
295
    ///
296
    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
297
    /// TODO: make this compatible with Set Difference calculations - need to change return type and domain for this expression and write a set comprehension rule.
298
    /// have already edited minus_to_sum to prevent this from applying to sets
299
    #[compatible(JsonInput)]
300
    Minus(Metadata, Moo<Expression>, Moo<Expression>),
301

            
302
    /// Ensures that x=|y| i.e. x is the absolute value of y.
303
    ///
304
    /// Low-level Minion constraint.
305
    ///
306
    /// # See also
307
    ///
308
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
309
    #[compatible(Minion)]
310
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),
311

            
312
    /// Ensures that `alldiff([a,b,...])`.
313
    ///
314
    /// Low-level Minion constraint.
315
    ///
316
    /// # See also
317
    ///
318
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
319
    #[compatible(Minion)]
320
    FlatAllDiff(Metadata, Vec<Atom>),
321

            
322
    /// Ensures that sum(vec) >= x.
323
    ///
324
    /// Low-level Minion constraint.
325
    ///
326
    /// # See also
327
    ///
328
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
329
    #[compatible(Minion)]
330
    FlatSumGeq(Metadata, Vec<Atom>, Atom),
331

            
332
    /// Ensures that sum(vec) <= x.
333
    ///
334
    /// Low-level Minion constraint.
335
    ///
336
    /// # See also
337
    ///
338
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
339
    #[compatible(Minion)]
340
    FlatSumLeq(Metadata, Vec<Atom>, Atom),
341

            
342
    /// `ineq(x,y,k)` ensures that x <= y + k.
343
    ///
344
    /// Low-level Minion constraint.
345
    ///
346
    /// # See also
347
    ///
348
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
349
    #[compatible(Minion)]
350
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),
351

            
352
    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
353
    ///
354
    /// Low-level Minion constraint.
355
    ///
356
    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
357
    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
358
    /// watched-and and watched-or.
359
    ///
360
    /// # See also
361
    ///
362
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
363
    /// + `rules::minion::boolean_literal_to_wliteral`.
364
    #[compatible(Minion)]
365
    #[polyquine_skip]
366
    FlatWatchedLiteral(Metadata, Reference, Literal),
367

            
368
    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
369
    /// product of cs and xs.
370
    ///
371
    /// Low-level Minion constraint.
372
    ///
373
    /// Represents a weighted sum of the form `ax + by + cz + ...`
374
    ///
375
    /// # See also
376
    ///
377
    /// + [Minion
378
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
379
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
380

            
381
    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
382
    /// product of cs and xs.
383
    ///
384
    /// Low-level Minion constraint.
385
    ///
386
    /// Represents a weighted sum of the form `ax + by + cz + ...`
387
    ///
388
    /// # See also
389
    ///
390
    /// + [Minion
391
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
392
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
393

            
394
    /// Ensures that x =-y, where x and y are atoms.
395
    ///
396
    /// Low-level Minion constraint.
397
    ///
398
    /// # See also
399
    ///
400
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
401
    #[compatible(Minion)]
402
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),
403

            
404
    /// Ensures that x*y=z.
405
    ///
406
    /// Low-level Minion constraint.
407
    ///
408
    /// # See also
409
    ///
410
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
411
    #[compatible(Minion)]
412
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
413

            
414
    /// Ensures that floor(x/y)=z. Always true when y=0.
415
    ///
416
    /// Low-level Minion constraint.
417
    ///
418
    /// # See also
419
    ///
420
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
421
    #[compatible(Minion)]
422
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
423

            
424
    /// Ensures that x%y=z. Always true when y=0.
425
    ///
426
    /// Low-level Minion constraint.
427
    ///
428
    /// # See also
429
    ///
430
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
431
    #[compatible(Minion)]
432
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
433

            
434
    /// Ensures that `x**y = z`.
435
    ///
436
    /// Low-level Minion constraint.
437
    ///
438
    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
439
    /// is odd and z is -1 if y is even).
440
    ///
441
    /// # See also
442
    ///
443
    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
444
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
445
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
446

            
447
    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
448
    /// variable.
449
    ///
450
    /// Low-level Minion constraint.
451
    ///
452
    /// # See also
453
    ///
454
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
455
    #[compatible(Minion)]
456
    MinionReify(Metadata, Moo<Expression>, Atom),
457

            
458
    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
459
    /// variable.
460
    ///
461
    /// Low-level Minion constraint.
462
    ///
463
    /// # See also
464
    ///
465
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
466
    #[compatible(Minion)]
467
    MinionReifyImply(Metadata, Moo<Expression>, Atom),
468

            
469
    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
470
    /// intervals {a1,…,a2}, {b1,…,b2} etc.
471
    ///
472
    /// The list of intervals must be given in numerical order.
473
    ///
474
    /// Low-level Minion constraint.
475
    ///
476
    /// # See also
477
    ///>
478
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
479
    #[compatible(Minion)]
480
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
481

            
482
    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
483
    ///
484
    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
485
    ///
486
    /// The list of values must be given in numerical order.
487
    ///
488
    /// Low-level Minion constraint.
489
    ///
490
    /// # See also
491
    ///
492
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
493
    #[compatible(Minion)]
494
    MinionWInSet(Metadata, Atom, Vec<i32>),
495

            
496
    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
497
    /// in the range `[1..len(vec)]`.
498
    ///
499
    /// Low-level Minion constraint.
500
    ///
501
    /// # See also
502
    ///
503
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
504
    #[compatible(Minion)]
505
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),
506

            
507
    /// Declaration of an auxiliary variable.
508
    ///
509
    /// As with Savile Row, we semantically distinguish this from `Eq`.
510
    #[compatible(Minion)]
511
    #[polyquine_skip]
512
    AuxDeclaration(Metadata, Reference, Moo<Expression>),
513

            
514
    /// This expression is for encoding ints for the SAT solver, it stores the encoding type, the vector of booleans and the min/max for the int.
515
    #[compatible(SAT)]
516
    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),
517

            
518
    /// Addition over a pair of expressions (i.e. a + b) rather than a vec-expr like Expression::Sum.
519
    /// This is for compatibility with backends that do not support addition over vectors.
520
    #[compatible(SMT)]
521
    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),
522

            
523
    /// Multiplication over a pair of expressions (i.e. a * b) rather than a vec-expr like Expression::Product.
524
    /// This is for compatibility with backends that do not support multiplication over vectors.
525
    #[compatible(SMT)]
526
    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),
527

            
528
    #[compatible(JsonInput)]
529
    Image(Metadata, Moo<Expression>, Moo<Expression>),
530

            
531
    #[compatible(JsonInput)]
532
    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),
533

            
534
    #[compatible(JsonInput)]
535
    PreImage(Metadata, Moo<Expression>, Moo<Expression>),
536

            
537
    #[compatible(JsonInput)]
538
    Inverse(Metadata, Moo<Expression>, Moo<Expression>),
539

            
540
    #[compatible(JsonInput)]
541
    Restrict(Metadata, Moo<Expression>, Moo<Expression>),
542

            
543
    /// Lexicographical < between two matrices.
544
    ///
545
    /// A <lex B iff: A[i] < B[i] for some i /\ (A[j] > B[j] for some j -> i < j)
546
    /// I.e. A must be less than B at some index i, and if it is greater than B at another index j,
547
    /// then j comes after i.
548
    /// I.e. A must be greater than B at the first index where they differ.
549
    ///
550
    /// E.g. [1, 1] <lex [2, 1] and [1, 1] <lex [1, 2]
551
    LexLt(Metadata, Moo<Expression>, Moo<Expression>),
552

            
553
    /// Lexicographical <= between two matrices
554
    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),
555

            
556
    /// Lexicographical > between two matrices
557
    /// This is a parser-level construct, and is immediately normalised to LexLt(b, a)
558
    LexGt(Metadata, Moo<Expression>, Moo<Expression>),
559

            
560
    /// Lexicographical >= between two matrices
561
    /// This is a parser-level construct, and is immediately normalised to LexLeq(b, a)
562
    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),
563

            
564
    /// Low-level minion constraint. See Expression::LexLt
565
    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),
566

            
567
    /// Low-level minion constraint. See Expression::LexLeq
568
    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
569
}
570

            
571
// for the given matrix literal, return a bounded domain from the min to max of applying op to each
572
// child expression.
573
//
574
// Op must be monotonic.
575
//
576
// Returns none if unbounded
577
409785
fn bounded_i32_domain_for_matrix_literal_monotonic(
578
409785
    e: &Expression,
579
409785
    op: fn(i32, i32) -> Option<i32>,
580
409785
) -> Option<DomainPtr> {
581
    // only care about the elements, not the indices
582
409785
    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
583

            
584
    // fold each element's domain into one using op.
585
    //
586
    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
587
    // the ranges [a1,a2], [b1,b2] will be
588
    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
589
    //
590
    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
591
    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
592
    // memory (on the hakank_eprime_xkcd test)...
593
    //Int
594
    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
595
    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
596
    //
597
    // +,-,/,* are all monotone, so this assumption should be fine for now...
598

            
599
398591
    let expr = exprs.pop()?;
600
398588
    let dom = expr.domain_of()?;
601
396674
    let resolved = dom.resolve()?;
602
396674
    let GroundDomain::Int(ranges) = resolved.as_ref() else {
603
3
        return None;
604
    };
605

            
606
396671
    let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
607

            
608
462269
    for expr in exprs {
609
462269
        let dom = expr.domain_of()?;
610
462153
        let resolved = dom.resolve()?;
611
462153
        let GroundDomain::Int(ranges) = resolved.as_ref() else {
612
            return None;
613
        };
614

            
615
462153
        let (min, max) = range_vec_bounds_i32(ranges)?;
616

            
617
        // all the possible new values for current_min / current_max
618
462153
        let minmax = op(min, current_max)?;
619
462153
        let minmin = op(min, current_min)?;
620
462153
        let maxmin = op(max, current_min)?;
621
462153
        let maxmax = op(max, current_max)?;
622
462153
        let vals = [minmax, minmin, maxmin, maxmax];
623

            
624
462153
        current_min = *vals
625
462153
            .iter()
626
462153
            .min()
627
462153
            .expect("vals iterator should not be empty, and should have a minimum.");
628
462153
        current_max = *vals
629
462153
            .iter()
630
462153
            .max()
631
462153
            .expect("vals iterator should not be empty, and should have a maximum.");
632
    }
633

            
634
396555
    if current_min == current_max {
635
255493
        Some(Domain::int(vec![Range::Single(current_min)]))
636
    } else {
637
141062
        Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
638
    }
639
409785
}
640

            
641
// Returns none if unbounded
642
858824
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
643
858824
    let mut min = i32::MAX;
644
858824
    let mut max = i32::MIN;
645
867802
    for r in ranges {
646
867802
        match r {
647
551696
            Range::Single(i) => {
648
551696
                if *i < min {
649
542718
                    min = *i;
650
542758
                }
651
551696
                if *i > max {
652
551696
                    max = *i;
653
551696
                }
654
            }
655
316106
            Range::Bounded(i, j) => {
656
316106
                if *i < min {
657
316106
                    min = *i;
658
316106
                }
659
316106
                if *j > max {
660
316106
                    max = *j;
661
316106
                }
662
            }
663
            Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
664
        }
665
    }
666
858824
    Some((min, max))
667
858824
}
668

            
669
impl Expression {
670
    /// Returns the possible values of the expression, recursing to leaf expressions
671
2289637
    pub fn domain_of(&self) -> Option<DomainPtr> {
672
2289637
        match self {
673
15
            Expression::Union(_, a, b) => Some(Domain::set(
674
15
                SetAttr::<IntVal>::default(),
675
15
                a.domain_of()?.union(&b.domain_of()?).ok()?,
676
            )),
677
15
            Expression::Intersect(_, a, b) => Some(Domain::set(
678
15
                SetAttr::<IntVal>::default(),
679
15
                a.domain_of()?.intersect(&b.domain_of()?).ok()?,
680
            )),
681
            Expression::In(_, _, _) => Some(Domain::bool()),
682
            Expression::Supset(_, _, _) => Some(Domain::bool()),
683
            Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
684
            Expression::Subset(_, _, _) => Some(Domain::bool()),
685
9
            Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
686
31419
            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
687
            Expression::DominanceRelation(_, _) => Some(Domain::bool()),
688
            Expression::FromSolution(_, expr) => Some(expr.domain_of()),
689
            Expression::Metavar(_, _) => None,
690
            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
691
            Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
692
171129
            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
693
183453
                let dom = matrix.domain_of()?;
694
183453
                if let Some((elem_domain, _)) = dom.as_matrix() {
695
183453
                    return Some(elem_domain);
696
                }
697

            
698
                // may actually use the value in the future
699
                #[allow(clippy::redundant_pattern_matching)]
700
                if let Some(_) = dom.as_tuple() {
701
                    // TODO: We can implement proper indexing for tuples
702
                    return None;
703
                }
704

            
705
                // may actually use the value in the future
706
                #[allow(clippy::redundant_pattern_matching)]
707
                if let Some(_) = dom.as_record() {
708
                    // TODO: We can implement proper indexing for records
709
                    return None;
710
                }
711

            
712
                bug!("subject of an index operation should support indexing")
713
            }
714
            Expression::UnsafeSlice(_, matrix, indices)
715
312
            | Expression::SafeSlice(_, matrix, indices) => {
716
312
                let sliced_dimension = indices.iter().position(Option::is_none);
717

            
718
312
                let dom = matrix.domain_of()?;
719
312
                let Some((elem_domain, index_domains)) = dom.as_matrix() else {
720
                    bug!("subject of an index operation should be a matrix");
721
                };
722

            
723
312
                match sliced_dimension {
724
312
                    Some(dimension) => Some(Domain::matrix(
725
312
                        elem_domain,
726
312
                        vec![index_domains[dimension].clone()],
727
312
                    )),
728

            
729
                    // same as index
730
                    None => Some(elem_domain),
731
                }
732
            }
733
            Expression::InDomain(_, _, _) => Some(Domain::bool()),
734
1525531
            Expression::Atomic(_, atom) => Some(atom.domain_of()),
735
71190
            Expression::Sum(_, e) => {
736
422648
                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
737
            }
738
201552
            Expression::Product(_, e) => {
739
806520
                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
740
            }
741
8268
            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
742
8268
                Some(if x < y { x } else { y })
743
8268
            }),
744
5616
            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
745
5616
                Some(if x > y { x } else { y })
746
5616
            }),
747
85722
            Expression::UnsafeDiv(_, a, b) => a
748
85722
                .domain_of()?
749
85722
                .resolve()?
750
85722
                .apply_i32(
751
                    // rust integer division is truncating; however, we want to always round down,
752
                    // including for negative numbers.
753
86931
                    |x, y| {
754
86931
                        if y != 0 {
755
86736
                            Some((x as f32 / y as f32).floor() as i32)
756
                        } else {
757
195
                            None
758
                        }
759
86931
                    },
760
85722
                    b.domain_of()?.resolve()?.as_ref(),
761
                )
762
85722
                .map(DomainPtr::from)
763
85722
                .ok(),
764
2769
            Expression::SafeDiv(_, a, b) => {
765
                // rust integer division is truncating; however, we want to always round down
766
                // including for negative numbers.
767
2769
                let domain = a
768
2769
                    .domain_of()?
769
2769
                    .resolve()?
770
2769
                    .apply_i32(
771
140049
                        |x, y| {
772
140049
                            if y != 0 {
773
131274
                                Some((x as f32 / y as f32).floor() as i32)
774
                            } else {
775
8775
                                None
776
                            }
777
140049
                        },
778
2769
                        b.domain_of()?.resolve()?.as_ref(),
779
                    )
780
                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
781

            
782
2769
                if let GroundDomain::Int(ranges) = domain {
783
2769
                    let mut ranges = ranges;
784
2769
                    ranges.push(Range::Single(0));
785
2769
                    Some(Domain::int(ranges))
786
                } else {
787
                    bug!("Domain of {self} was not integer")
788
                }
789
            }
790
            Expression::UnsafeMod(_, a, b) => a
791
                .domain_of()?
792
                .resolve()?
793
                .apply_i32(
794
                    |x, y| if y != 0 { Some(x % y) } else { None },
795
                    b.domain_of()?.resolve()?.as_ref(),
796
                )
797
                .map(DomainPtr::from)
798
                .ok(),
799
858
            Expression::SafeMod(_, a, b) => {
800
858
                let domain = a
801
858
                    .domain_of()?
802
858
                    .resolve()?
803
858
                    .apply_i32(
804
23439
                        |x, y| if y != 0 { Some(x % y) } else { None },
805
858
                        b.domain_of()?.resolve()?.as_ref(),
806
                    )
807
                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
808

            
809
858
                if let GroundDomain::Int(ranges) = domain {
810
858
                    let mut ranges = ranges;
811
858
                    ranges.push(Range::Single(0));
812
858
                    Some(Domain::int(ranges))
813
                } else {
814
                    bug!("Domain of {self} was not integer")
815
                }
816
            }
817
810
            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
818
810
                .domain_of()?
819
810
                .resolve()?
820
810
                .apply_i32(
821
20427
                    |x, y| {
822
20427
                        if (x != 0 || y != 0) && y >= 0 {
823
19647
                            Some(x.pow(y as u32))
824
                        } else {
825
780
                            None
826
                        }
827
20427
                    },
828
810
                    b.domain_of()?.resolve()?.as_ref(),
829
                )
830
810
                .map(DomainPtr::from)
831
810
                .ok(),
832
            Expression::Root(_, _) => None,
833
195
            Expression::Bubble(_, inner, _) => inner.domain_of(),
834
            Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
835
2034
            Expression::And(_, _) => Some(Domain::bool()),
836
468
            Expression::Not(_, _) => Some(Domain::bool()),
837
39
            Expression::Or(_, _) => Some(Domain::bool()),
838
204
            Expression::Imply(_, _, _) => Some(Domain::bool()),
839
            Expression::Iff(_, _, _) => Some(Domain::bool()),
840
1176
            Expression::Eq(_, _, _) => Some(Domain::bool()),
841
            Expression::Neq(_, _, _) => Some(Domain::bool()),
842
39
            Expression::Geq(_, _, _) => Some(Domain::bool()),
843
1014
            Expression::Leq(_, _, _) => Some(Domain::bool()),
844
237
            Expression::Gt(_, _, _) => Some(Domain::bool()),
845
3
            Expression::Lt(_, _, _) => Some(Domain::bool()),
846
            Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
847
39
            Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
848
            Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
849
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
850
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
851
78
            Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
852
390
            Expression::Flatten(_, n, m) => {
853
390
                if let Some(expr) = n {
854
                    if expr.return_type() == ReturnType::Int {
855
                        // TODO: handle flatten with depth argument
856
                        return None;
857
                    }
858
                } else {
859
                    // TODO: currently only works for matrices
860
390
                    let dom = m.domain_of()?.resolve()?;
861
390
                    let (val_dom, idx_doms) = match dom.as_ref() {
862
390
                        GroundDomain::Matrix(val, idx) => (val, idx),
863
                        _ => return None,
864
                    };
865
390
                    let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
866

            
867
390
                    let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
868
390
                    return Some(Domain::matrix(
869
390
                        val_dom.clone().into(),
870
390
                        vec![new_index_domain],
871
390
                    ));
872
                }
873
                None
874
            }
875
            Expression::AllDiff(_, _) => Some(Domain::bool()),
876
            Expression::Table(_, _, _) => Some(Domain::bool()),
877
            Expression::NegativeTable(_, _, _) => Some(Domain::bool()),
878
            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
879
            Expression::MinionReify(_, _, _) => Some(Domain::bool()),
880
39
            Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
881
            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
882
            Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
883
            Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
884
2379
            Expression::Neg(_, x) => {
885
2379
                let dom = x.domain_of()?;
886
2379
                let mut ranges = dom.as_int()?;
887

            
888
1014
                ranges = ranges
889
1014
                    .into_iter()
890
1014
                    .map(|r| match r {
891
156
                        Range::Single(x) => Range::Single(-x),
892
858
                        Range::Bounded(x, y) => Range::Bounded(-y, -x),
893
                        Range::UnboundedR(i) => Range::UnboundedL(-i),
894
                        Range::UnboundedL(i) => Range::UnboundedR(-i),
895
                        Range::Unbounded => Range::Unbounded,
896
1014
                    })
897
1014
                    .collect();
898

            
899
1014
                Some(Domain::int(ranges))
900
            }
901
171330
            Expression::Minus(_, a, b) => a
902
171330
                .domain_of()?
903
171330
                .resolve()?
904
187209
                .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
905
171330
                .map(DomainPtr::from)
906
171330
                .ok(),
907
            Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
908
            Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
909
            Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
910
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
911
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
912
1131
            Expression::Abs(_, a) => a
913
1131
                .domain_of()?
914
1131
                .resolve()?
915
188877
                .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
916
1131
                .map(DomainPtr::from)
917
1131
                .ok(),
918
            Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
919
819
            Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
920
1560
            Expression::SATInt(_, _, _, (low, high)) => {
921
1560
                Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
922
            }
923
            Expression::PairwiseSum(_, a, b) => a
924
                .domain_of()?
925
                .resolve()?
926
                .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
927
                .map(DomainPtr::from)
928
                .ok(),
929
            Expression::PairwiseProduct(_, a, b) => a
930
                .domain_of()?
931
                .resolve()?
932
                .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
933
                .map(DomainPtr::from)
934
                .ok(),
935
            Expression::Defined(_, function) => get_function_domain(function),
936
            Expression::Range(_, function) => get_function_codomain(function),
937
            Expression::Image(_, function, _) => get_function_codomain(function),
938
            Expression::ImageSet(_, function, _) => get_function_codomain(function),
939
            Expression::PreImage(_, function, _) => get_function_domain(function),
940
            Expression::Restrict(_, function, new_domain) => {
941
                let (attrs, _, codom) = function.domain_of()?.as_function()?;
942
                let new_dom = new_domain.domain_of()?;
943
                Some(Domain::function(attrs, new_dom, codom))
944
            }
945
            Expression::Inverse(..) => Some(Domain::bool()),
946
            Expression::LexLt(..) => Some(Domain::bool()),
947
            Expression::LexLeq(..) => Some(Domain::bool()),
948
            Expression::LexGt(..) => Some(Domain::bool()),
949
            Expression::LexGeq(..) => Some(Domain::bool()),
950
            Expression::FlatLexLt(..) => Some(Domain::bool()),
951
            Expression::FlatLexLeq(..) => Some(Domain::bool()),
952
        }
953
2289637
    }
954

            
955
    pub fn get_meta(&self) -> Metadata {
956
        let metas: VecDeque<Metadata> = self.children_bi();
957
        metas[0].clone()
958
    }
959

            
960
    pub fn set_meta(&self, meta: Metadata) {
961
        self.transform_bi(&|_| meta.clone());
962
    }
963

            
964
    /// Checks whether this expression is safe.
965
    ///
966
    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
967
    ///
968
    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
969
    /// safe through the use of bubble rules.
970
95282
    pub fn is_safe(&self) -> bool {
971
        // TODO: memoise in Metadata
972
814962
        for expr in self.universe() {
973
814962
            match expr {
974
                Expression::UnsafeDiv(_, _, _)
975
                | Expression::UnsafeMod(_, _, _)
976
                | Expression::UnsafePow(_, _, _)
977
                | Expression::UnsafeIndex(_, _, _)
978
                | Expression::Bubble(_, _, _)
979
                | Expression::UnsafeSlice(_, _, _) => {
980
3886
                    return false;
981
                }
982
811076
                _ => {}
983
            }
984
        }
985
91396
        true
986
95282
    }
987

            
988
    pub fn is_clean(&self) -> bool {
989
        let metadata = self.get_meta();
990
        metadata.clean
991
    }
992

            
993
    pub fn set_clean(&mut self, bool_value: bool) {
994
        let mut metadata = self.get_meta();
995
        metadata.clean = bool_value;
996
        self.set_meta(metadata);
997
    }
998

            
999
    /// True if the expression is an associative and commutative operator
9193444
    pub fn is_associative_commutative_operator(&self) -> bool {
9193444
        TryInto::<ACOperatorKind>::try_into(self).is_ok()
9193444
    }
    /// True if the expression is a matrix literal.
    ///
    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
    /// [`Expression`].
8700
    pub fn is_matrix_literal(&self) -> bool {
        matches!(
8700
            self,
            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
                | Expression::Atomic(
                    _,
                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
                )
        )
8700
    }
    /// True iff self and other are both atomic and identical.
    ///
    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
    /// entire subtrees of `self` and `other`.
75772
    pub fn identical_atom_to(&self, other: &Expression) -> bool {
75772
        let atom1: Result<&Atom, _> = self.try_into();
75772
        let atom2: Result<&Atom, _> = other.try_into();
75772
        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
5336
            atom2 == atom1
        } else {
70436
            false
        }
75772
    }
    /// If the expression is a list, returns a *copied* vector of the inner expressions.
    ///
    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
    /// any explicitly specified domain.
4478792
    pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
3537720
        match self {
3537720
            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
3537720
                matrix.unwrap_list().cloned()
            }
            Expression::Atomic(
                _,
16804
                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
16804
            ) => matrix.unwrap_list().map(|elems| {
11932
                elems
11932
                    .clone()
11932
                    .into_iter()
36818
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
11932
                    .collect_vec()
11932
            }),
924268
            _ => None,
        }
4478792
    }
    /// If the expression is a matrix, gets it elements and index domain.
    ///
    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
    ///
    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
4822385
    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
2507689
        match self {
2507689
            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
2507689
                Some((elems, domain))
            }
            Expression::Atomic(
                _,
390868
                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
            ) => Some((
390868
                elems
390868
                    .into_iter()
775410
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
390868
                    .collect_vec(),
390868
                domain.into(),
            )),
1923828
            _ => None,
        }
4822385
    }
    /// For a Root expression, extends the inner vec with the given vec.
    ///
    /// # Panics
    /// Panics if the expression is not Root.
    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
        match self {
            Expression::Root(meta, mut children) => {
                children.extend(exprs);
                Expression::Root(meta, children)
            }
            _ => panic!("extend_root called on a non-Root expression"),
        }
    }
    /// Converts the expression to a literal, if possible.
129220
    pub fn into_literal(self) -> Option<Literal> {
120211
        match self {
111922
            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
            Expression::AbstractLiteral(_, abslit) => {
                Some(Literal::AbstractLiteral(abslit.into_literals()?))
            }
5070
            Expression::Neg(_, e) => {
5070
                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
                    bug!("negated literal should be an int");
                };
5070
                Some(Literal::Int(-i))
            }
12228
            _ => None,
        }
129220
    }
    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
10320282
    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
10320282
        TryFrom::try_from(self).ok()
10320282
    }
    /// Returns the categories of all sub-expressions of self.
65026
    pub fn universe_categories(&self) -> HashSet<Category> {
65026
        self.universe()
65026
            .into_iter()
642536
            .map(|x| x.category_of())
65026
            .collect()
65026
    }
}
pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
    let function_domain = function.domain_of()?;
    match function_domain.resolve().as_ref() {
        Some(d) => {
            match d.as_ref() {
                GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
        None => {
            match function_domain.as_unresolved()? {
                UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
    }
}
pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
    let function_domain = function.domain_of()?;
    match function_domain.resolve().as_ref() {
        Some(d) => {
            match d.as_ref() {
                GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
        None => {
            match function_domain.as_unresolved()? {
                UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
    }
}
impl TryFrom<&Expression> for i32 {
    type Error = ();
280444
    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
280444
        let Expression::Atomic(_, atom) = value else {
186148
            return Err(());
        };
94296
        let Atom::Literal(lit) = atom else {
94296
            return Err(());
        };
        let Literal::Int(i) = lit else {
            return Err(());
        };
        Ok(*i)
280444
    }
}
impl TryFrom<Expression> for i32 {
    type Error = ();
    fn try_from(value: Expression) -> Result<Self, Self::Error> {
        TryFrom::<&Expression>::try_from(&value)
    }
}
impl From<i32> for Expression {
40282
    fn from(i: i32) -> Self {
40282
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
40282
    }
}
impl From<bool> for Expression {
21331
    fn from(b: bool) -> Self {
21331
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
21331
    }
}
impl From<Atom> for Expression {
3494
    fn from(value: Atom) -> Self {
3494
        Expression::Atomic(Metadata::new(), value)
3494
    }
}
impl From<Literal> for Expression {
6264
    fn from(value: Literal) -> Self {
6264
        Expression::Atomic(Metadata::new(), value.into())
6264
    }
}
impl From<Moo<Expression>> for Expression {
51738
    fn from(val: Moo<Expression>) -> Self {
51738
        val.as_ref().clone()
51738
    }
}
impl CategoryOf for Expression {
829924
    fn category_of(&self) -> Category {
        // take highest category of all the expressions children
3364938
        let category = self.cata(&move |x,children| {
3364938
            if let Some(max_category) = children.iter().max() {
                // if this expression contains subexpressions, return the maximum category of the
                // subexpressions
1021368
                *max_category
            } else {
                // this expression has no children
2343570
                let mut max_category = Category::Bottom;
                // calculate the category by looking at all atoms, submodels, comprehensions, and
                // declarationptrs inside this expression
                // this should generically cover all leaf types we currently have in oxide.
                // if x contains submodels (including comprehensions)
2343570
                if !Biplate::<Model>::universe_bi(&x).is_empty() {
                    // assume that the category is decision
                    return Category::Decision;
2343570
                }
                // if x contains atoms
2345774
                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
                // and those atoms have a higher category than we already know about
2343570
                && max_atom_category > max_category{
                    // update category
2343570
                    max_category = max_atom_category;
2343570
                }
                // if x contains declarationPtrs
2343570
                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
                // and those pointers have a higher category than we already know about
1782576
                && max_declaration_category > max_category{
                    // update category
                    max_category = max_declaration_category;
2343570
                }
2343570
                max_category
            }
3364938
        });
829924
        if cfg!(debug_assertions) {
829924
            trace!(
                category= %category,
                expression= %self,
                "Called Expression::category_of()"
            );
        };
829924
        category
829924
    }
}
impl Display for Expression {
2448512884
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
2448512884
        match &self {
1432
            Expression::Union(_, box1, box2) => {
1432
                write!(f, "({} union {})", box1.clone(), box2.clone())
            }
365961
            Expression::In(_, e1, e2) => {
365961
                write!(f, "{e1} in {e2}")
            }
1432
            Expression::Intersect(_, box1, box2) => {
1432
                write!(f, "({} intersect {})", box1.clone(), box2.clone())
            }
1856
            Expression::Supset(_, box1, box2) => {
1856
                write!(f, "({} supset {})", box1.clone(), box2.clone())
            }
1856
            Expression::SupsetEq(_, box1, box2) => {
1856
                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
            }
2320
            Expression::Subset(_, box1, box2) => {
2320
                write!(f, "({} subset {})", box1.clone(), box2.clone())
            }
4652
            Expression::SubsetEq(_, box1, box2) => {
4652
                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
            }
262090825
            Expression::AbstractLiteral(_, l) => l.fmt(f),
7056046
            Expression::Comprehension(_, c) => c.fmt(f),
17052
            Expression::AbstractComprehension(_, c) => c.fmt(f),
118423562
            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
158602084
                write!(f, "{e1}{}", pretty_vec(e2))
            }
10344300
            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
16161526
                let args = es
16161526
                    .iter()
32065822
                    .map(|x| match x {
15904296
                        Some(x) => format!("{x}"),
16161526
                        None => "..".into(),
32065822
                    })
16161526
                    .join(",");
16161526
                write!(f, "{e1}[{args}]")
            }
2858248
            Expression::InDomain(_, e, domain) => {
2858248
                write!(f, "__inDomain({e},{domain})")
            }
18209022
            Expression::Root(_, exprs) => {
18209022
                write!(f, "{}", pretty_expressions_as_top_level(exprs))
            }
            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
            Expression::Metavar(_, name) => write!(f, "&{name}"),
1516306110
            Expression::Atomic(_, atom) => atom.fmt(f),
1097998
            Expression::Abs(_, a) => write!(f, "|{a}|"),
90716108
            Expression::Sum(_, e) => {
90716108
                write!(f, "sum({e})")
            }
32160304
            Expression::Product(_, e) => {
32160304
                write!(f, "product({e})")
            }
1733504
            Expression::Min(_, e) => {
1733504
                write!(f, "min({e})")
            }
1426684
            Expression::Max(_, e) => {
1426684
                write!(f, "max({e})")
            }
1930004
            Expression::Not(_, expr_box) => {
1930004
                write!(f, "!({})", expr_box.clone())
            }
9388790
            Expression::Or(_, e) => {
9388790
                write!(f, "or({e})")
            }
17255732
            Expression::And(_, e) => {
17255732
                write!(f, "and({e})")
            }
3958056
            Expression::Imply(_, box1, box2) => {
3958056
                write!(f, "({box1}) -> ({box2})")
            }
119016
            Expression::Iff(_, box1, box2) => {
119016
                write!(f, "({box1}) <-> ({box2})")
            }
64998487
            Expression::Eq(_, box1, box2) => {
64998487
                write!(f, "({} = {})", box1.clone(), box2.clone())
            }
17408038
            Expression::Neq(_, box1, box2) => {
17408038
                write!(f, "({} != {})", box1.clone(), box2.clone())
            }
9258168
            Expression::Geq(_, box1, box2) => {
9258168
                write!(f, "({} >= {})", box1.clone(), box2.clone())
            }
20126956
            Expression::Leq(_, box1, box2) => {
20126956
                write!(f, "({} <= {})", box1.clone(), box2.clone())
            }
628728
            Expression::Gt(_, box1, box2) => {
628728
                write!(f, "({} > {})", box1.clone(), box2.clone())
            }
5614292
            Expression::Lt(_, box1, box2) => {
5614292
                write!(f, "({} < {})", box1.clone(), box2.clone())
            }
19399562
            Expression::FlatSumGeq(_, box1, box2) => {
19399562
                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
            }
19228810
            Expression::FlatSumLeq(_, box1, box2) => {
19228810
                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
            }
11786232
            Expression::FlatIneq(_, box1, box2, box3) => write!(
11786232
                f,
                "Ineq({}, {}, {})",
11786232
                box1.clone(),
11786232
                box2.clone(),
11786232
                box3.clone()
            ),
648904
            Expression::Flatten(_, n, m) => {
648904
                if let Some(n) = n {
                    write!(f, "flatten({n}, {m})")
                } else {
648904
                    write!(f, "flatten({m})")
                }
            }
4286606
            Expression::AllDiff(_, e) => {
4286606
                write!(f, "allDiff({e})")
            }
135024
            Expression::Table(_, tuple_expr, rows_expr) => {
135024
                write!(f, "table({tuple_expr}, {rows_expr})")
            }
22504
            Expression::NegativeTable(_, tuple_expr, rows_expr) => {
22504
                write!(f, "negativeTable({tuple_expr}, {rows_expr})")
            }
823076
            Expression::Bubble(_, box1, box2) => {
823076
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
            }
3300896
            Expression::SafeDiv(_, box1, box2) => {
3300896
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
            }
2222618
            Expression::UnsafeDiv(_, box1, box2) => {
2222618
                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
            }
477558
            Expression::UnsafePow(_, box1, box2) => {
477558
                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
            }
1674634
            Expression::SafePow(_, box1, box2) => {
1674634
                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
            }
870116
            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
870116
                write!(
870116
                    f,
                    "DivEq({}, {}, {})",
870116
                    box1.clone(),
870116
                    box2.clone(),
870116
                    box3.clone()
                )
            }
330020
            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
330020
                write!(
330020
                    f,
                    "ModEq({}, {}, {})",
330020
                    box1.clone(),
330020
                    box2.clone(),
330020
                    box3.clone()
                )
            }
953752
            Expression::FlatWatchedLiteral(_, x, l) => {
953752
                write!(f, "WatchedLiteral({x},{l})")
            }
7996018
            Expression::MinionReify(_, box1, box2) => {
7996018
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
            }
3269262
            Expression::MinionReifyImply(_, box1, box2) => {
3269262
                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
            }
209438
            Expression::MinionWInIntervalSet(_, atom, intervals) => {
209438
                let intervals = intervals.iter().join(",");
209438
                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
            }
89088
            Expression::MinionWInSet(_, atom, values) => {
89088
                let values = values.iter().join(",");
89088
                write!(f, "__minion_w_inset({atom},[{values}])")
            }
8901212
            Expression::AuxDeclaration(_, reference, e) => {
8901212
                write!(f, "{} =aux {}", reference, e.clone())
            }
334138
            Expression::UnsafeMod(_, a, b) => {
334138
                write!(f, "{} % {}", a.clone(), b.clone())
            }
1004560
            Expression::SafeMod(_, a, b) => {
1004560
                write!(f, "SafeMod({},{})", a.clone(), b.clone())
            }
8195206
            Expression::Neg(_, a) => {
8195206
                write!(f, "-({})", a.clone())
            }
14344244
            Expression::Minus(_, a, b) => {
14344244
                write!(f, "({} - {})", a.clone(), b.clone())
            }
2834054
            Expression::FlatAllDiff(_, es) => {
2834054
                write!(f, "__flat_alldiff({})", pretty_vec(es))
            }
190414
            Expression::FlatAbsEq(_, a, b) => {
190414
                write!(f, "AbsEq({},{})", a.clone(), b.clone())
            }
400954
            Expression::FlatMinusEq(_, a, b) => {
400954
                write!(f, "MinusEq({},{})", a.clone(), b.clone())
            }
300498
            Expression::FlatProductEq(_, a, b, c) => {
300498
                write!(
300498
                    f,
                    "FlatProductEq({},{},{})",
300498
                    a.clone(),
300498
                    b.clone(),
300498
                    c.clone()
                )
            }
3869702
            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
3869702
                write!(
3869702
                    f,
                    "FlatWeightedSumLeq({},{},{})",
3869702
                    pretty_vec(cs),
3869702
                    pretty_vec(vs),
3869702
                    total.clone()
                )
            }
3985470
            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
3985470
                write!(
3985470
                    f,
                    "FlatWeightedSumGeq({},{},{})",
3985470
                    pretty_vec(cs),
3985470
                    pretty_vec(vs),
3985470
                    total.clone()
                )
            }
417404
            Expression::MinionPow(_, atom, atom1, atom2) => {
417404
                write!(f, "MinionPow({atom},{atom1},{atom2})")
            }
4352812
            Expression::MinionElementOne(_, atoms, atom, atom1) => {
4352812
                let atoms = atoms.iter().join(",");
4352812
                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
            }
284664
            Expression::ToInt(_, expr) => {
284664
                write!(f, "toInt({expr})")
            }
55901240
            Expression::SATInt(_, encoding, bits, (min, max)) => {
55901240
                write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
            }
            Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
            Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
116
            Expression::Defined(_, function) => write!(f, "defined({function})"),
116
            Expression::Range(_, function) => write!(f, "range({function})"),
116
            Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
116
            Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
116
            Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
116
            Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
97
            Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
88624
            Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
5814674
            Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
            Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
            Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
22272
            Expression::FlatLexLt(_, a, b) => {
22272
                write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
            }
44544
            Expression::FlatLexLeq(_, a, b) => {
44544
                write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
            }
        }
2448512884
    }
}
impl Typeable for Expression {
711751
    fn return_type(&self) -> ReturnType {
711751
        match self {
            Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
            Expression::Intersect(_, subject, _) => {
                ReturnType::Set(Box::new(subject.return_type()))
            }
390
            Expression::In(_, _, _) => ReturnType::Bool,
            Expression::Supset(_, _, _) => ReturnType::Bool,
            Expression::SupsetEq(_, _, _) => ReturnType::Bool,
            Expression::Subset(_, _, _) => ReturnType::Bool,
            Expression::SubsetEq(_, _, _) => ReturnType::Bool,
24540
            Expression::AbstractLiteral(_, lit) => lit.return_type(),
68409
            Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
120567
                let subject_ty = subject.return_type();
120567
                match subject_ty {
                    ReturnType::Matrix(_) => {
                        // For n-dimensional matrices, unwrap the element type until
                        // we either get to the innermost element type or the last index
117525
                        let mut elem_typ = subject_ty;
117525
                        let mut idx_len = idx.len();
235284
                        while idx_len > 0
139989
                            && let ReturnType::Matrix(new_elem_typ) = &elem_typ
117759
                        {
117759
                            elem_typ = *new_elem_typ.clone();
117759
                            idx_len -= 1;
117759
                        }
117525
                        elem_typ
                    }
                    // TODO: We can implement indexing for these eventually
3042
                    ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
                    _ => bug!(
                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
                    ),
                }
            }
1872
            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1872
                ReturnType::Matrix(Box::new(subject.return_type()))
            }
            Expression::InDomain(_, _, _) => ReturnType::Bool,
            Expression::Comprehension(_, comp) => comp.return_type(),
            Expression::AbstractComprehension(_, comp) => comp.return_type(),
            Expression::Root(_, _) => ReturnType::Bool,
            Expression::DominanceRelation(_, _) => ReturnType::Bool,
            Expression::FromSolution(_, expr) => expr.return_type(),
            Expression::Metavar(_, _) => ReturnType::Unknown,
426220
            Expression::Atomic(_, atom) => atom.return_type(),
429
            Expression::Abs(_, _) => ReturnType::Int,
81009
            Expression::Sum(_, _) => ReturnType::Int,
6552
            Expression::Product(_, _) => ReturnType::Int,
858
            Expression::Min(_, _) => ReturnType::Int,
819
            Expression::Max(_, _) => ReturnType::Int,
117
            Expression::Not(_, _) => ReturnType::Bool,
366
            Expression::Or(_, _) => ReturnType::Bool,
870
            Expression::Imply(_, _, _) => ReturnType::Bool,
            Expression::Iff(_, _, _) => ReturnType::Bool,
2385
            Expression::And(_, _) => ReturnType::Bool,
6240
            Expression::Eq(_, _, _) => ReturnType::Bool,
390
            Expression::Neq(_, _, _) => ReturnType::Bool,
78
            Expression::Geq(_, _, _) => ReturnType::Bool,
1872
            Expression::Leq(_, _, _) => ReturnType::Bool,
78
            Expression::Gt(_, _, _) => ReturnType::Bool,
            Expression::Lt(_, _, _) => ReturnType::Bool,
7371
            Expression::SafeDiv(_, _, _) => ReturnType::Int,
5148
            Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
            Expression::FlatAllDiff(_, _) => ReturnType::Bool,
39
            Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
            Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
            Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
78
            Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
936
            Expression::Flatten(_, _, matrix) => {
936
                let matrix_type = matrix.return_type();
936
                match matrix_type {
                    ReturnType::Matrix(_) => {
                        // unwrap until we get to innermost element
936
                        let mut elem_type = matrix_type;
1872
                        while let ReturnType::Matrix(new_elem_type) = &elem_type {
936
                            elem_type = *new_elem_type.clone();
936
                        }
936
                        ReturnType::Matrix(Box::new(elem_type))
                    }
                    _ => bug!(
                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
                    ),
                }
            }
195
            Expression::AllDiff(_, _) => ReturnType::Bool,
            Expression::Table(_, _, _) => ReturnType::Bool,
            Expression::NegativeTable(_, _, _) => ReturnType::Bool,
2028
            Expression::Bubble(_, inner, _) => inner.return_type(),
            Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
            Expression::MinionReify(_, _, _) => ReturnType::Bool,
39
            Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
            Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
            Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
            Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
            Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
819
            Expression::UnsafeMod(_, _, _) => ReturnType::Int,
3315
            Expression::SafeMod(_, _, _) => ReturnType::Int,
            Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
3783
            Expression::Neg(_, _) => ReturnType::Int,
468
            Expression::UnsafePow(_, _, _) => ReturnType::Int,
2676
            Expression::SafePow(_, _, _) => ReturnType::Int,
741
            Expression::Minus(_, _, _) => ReturnType::Int,
            Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
            Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
            Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
            Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
            Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
            Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
156
            Expression::ToInt(_, _) => ReturnType::Int,
6513
            Expression::SATInt(..) => ReturnType::Int,
            Expression::PairwiseSum(_, _, _) => ReturnType::Int,
            Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
            Expression::Defined(_, function) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(domain, _) => *domain,
                    _ => bug!(
                        "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Range(_, function) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Image(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::ImageSet(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::PreImage(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(domain, _) => *domain,
                    _ => bug!(
                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Restrict(_, function, new_domain) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => {
                        ReturnType::Function(Box::new(new_domain.return_type()), codomain)
                    }
                    _ => bug!(
                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Inverse(..) => ReturnType::Bool,
            Expression::LexLt(..) => ReturnType::Bool,
            Expression::LexGt(..) => ReturnType::Bool,
1794
            Expression::LexLeq(..) => ReturnType::Bool,
            Expression::LexGeq(..) => ReturnType::Bool,
            Expression::FlatLexLt(..) => ReturnType::Bool,
            Expression::FlatLexLeq(..) => ReturnType::Bool,
        }
711751
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    /// Wraps a constant in an `Expression::Atomic` with fresh metadata.
    fn literal_expr(value: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(value))
    }

    /// `sum([1, 2])` has the singleton integer domain `{3}`.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal_expr(Literal::Int(1));
        let two = literal_expr(Literal::Int(2));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, two]));

        let expected = Domain::int(vec![Range::Single(3)]);
        assert_eq!(total.domain_of(), Some(expected));
    }

    /// Summing an integer with a boolean has no well-defined domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let one = literal_expr(Literal::Int(1));
        let truth = literal_expr(Literal::Bool(true));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, truth]));

        assert_eq!(total.domain_of(), None);
    }

    /// The sum of an empty matrix has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));

        assert_eq!(total.domain_of(), None);
    }
}