1
use std::collections::{HashSet, VecDeque};
2
use std::fmt::{Display, Formatter};
3
use std::hash::{DefaultHasher, Hash, Hasher};
4
use std::sync::atomic::{AtomicU64, Ordering};
5

            
6
/// Running count of expression-hash hits, reported as `hits` by [`print_hash_stats`].
static HASH_HITS: AtomicU64 = AtomicU64::new(0);
/// Running count of expression-hash misses, reported as `misses` by [`print_hash_stats`].
static HASH_MISSES: AtomicU64 = AtomicU64::new(0);

/// Prints the current expression-hash hit and miss counters to standard output.
///
/// The two counters are loaded independently (with `Ordering::Relaxed`), so the printed pair is
/// not an atomic snapshot if other threads are updating the counters at the same time.
pub fn print_hash_stats() {
    let hits = HASH_HITS.load(Ordering::Relaxed);
    let misses = HASH_MISSES.load(Ordering::Relaxed);
    println!("Expression hash stats: hits={}, misses={}", hits, misses);
}
16
use tracing::trace;
17

            
18
use conjure_cp_enum_compatibility_macro::{document_compatibility, generate_discriminants};
19
use itertools::Itertools;
20
use serde::{Deserialize, Serialize};
21
use tree_morph::cache::CacheHashable;
22
use ustr::Ustr;
23

            
24
use polyquine::Quine;
25
use uniplate::{Biplate, Uniplate};
26

            
27
use crate::ast::metadata::NO_HASH;
28
use crate::bug;
29

            
30
use super::abstract_comprehension::AbstractComprehension;
31
use super::ac_operators::ACOperatorKind;
32
use super::categories::{Category, CategoryOf};
33
use super::comprehension::Comprehension;
34
use super::domains::HasDomain as _;
35
use super::pretty::{pretty_expressions_as_top_level, pretty_vec};
36
use super::records::RecordValue;
37
use super::sat_encoding::SATIntEncoding;
38
use super::{
39
    AbstractLiteral, Atom, DeclarationPtr, Domain, DomainPtr, GroundDomain, IntVal, Literal,
40
    Metadata, Model, Moo, Name, Range, Reference, ReturnType, SetAttr, SymbolTable, SymbolTablePtr,
41
    Typeable, UnresolvedDomain, matrix,
42
};
43

            
44
// Ensure that this type doesn't get too big.
//
// If you triggered this assertion, you either made a variant of this enum that is too big, or you
// made Name, Literal, AbstractLiteral or Atom bigger, which made this bigger! To fix this, put
// some stuff in boxes.
//
// Enums take the size of their largest variant, so an enum with mostly small variants and a few
// large ones wastes memory... A larger Expression type also slows down Oxide.
//
// For more information, and more details on type sizes and how to measure them, see the commit
// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
//
// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
//
// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e

// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
// lot bigger still when we start using it for memoisation, so it should really be
// boxed ~niklasdewally

// Compile-time check: `Expression` must be exactly 112 bytes (the same size as `[u8; 112]`).
static_assertions::assert_eq_size!([u8; 112], Expression);
67

            
68
/// Represents different types of expressions used to define rules and constraints in the model.
///
/// The `Expression` enum includes operations, constants, and variable references
/// used to build rules and conditions for the model.
///
/// Every variant stores a [`Metadata`] as its first field. The size of this enum is checked at
/// compile time (see the `assert_eq_size!` on `Expression` in this module), so large payloads
/// should be boxed.
#[generate_discriminants]
#[document_compatibility]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
#[biplate(to=AbstractComprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=Atom)]
#[biplate(to=Comprehension)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=DomainPtr)]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Name)]
#[biplate(to=Option<Expression>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Reference)]
#[biplate(to=Model)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[biplate(to=Vec<Expression>)]
#[path_prefix(conjure_cp::ast)]
pub enum Expression {
    /// An abstract literal (such as a matrix literal) whose elements are expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// The top of the model; holds the model's top-level expressions.
    Root(Metadata, Vec<Expression>),

    /// An expression representing "A is valid as long as B is true".
    ///
    /// Turns into a conjunction when it reaches a boolean context.
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    /// A comprehension.
    ///
    /// The inside of the comprehension opens a new scope.
    // todo (gskorokhod): Comprehension contains a symbol table which contains a bunch of pointers.
    // This makes implementing Quine tricky (it doesn't support Rc, by design). Skip it for now.
    #[polyquine_skip]
    Comprehension(Metadata, Moo<Comprehension>),

    /// Higher-level abstract comprehension.
    // Skipped from the `Quine` derivation: the original author found this was needed for the
    // derive to compile. NOTE(review): presumably for the same reason as `Comprehension` (it may
    // hold symbol-table pointers) — confirm.
    #[polyquine_skip]
    AbstractComprehension(Metadata, Moo<AbstractComprehension>),

    /// Defines dominance ("Solution A is preferred over Solution B").
    DominanceRelation(Metadata, Moo<Expression>),
    /// `fromSolution(name)` - used in dominance relation definitions.
    FromSolution(Metadata, Moo<Atom>),

    /// A metavariable, identified by name.
    ///
    /// When quoted via `Quine`, a metavariable becomes a reference to the Rust identifier with
    /// the same name (converted with `.clone().into()`).
    #[polyquine_with(arm = (_, name) => {
        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
        quote::quote! { #ident.clone().into() }
    })]
    Metavar(Metadata, Ustr),

    /// An atomic expression (see [`Atom`]).
    Atomic(Metadata, Atom),

    /// A matrix index.
    ///
    /// Defined iff the indices are within their respective index domains.
    #[compatible(JsonInput)]
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// A safe matrix index.
    ///
    /// See [`Expression::UnsafeIndex`].
    #[compatible(SMT)]
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// A matrix slice: `a[indices]`.
    ///
    /// One of the indices may be `None`, representing the dimension of the matrix we want to take
    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
    /// `Some(1),None,Some(2)`.
    ///
    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
    ///
    /// Defined iff the defined indices are within their respective index domains.
    #[compatible(JsonInput)]
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// A safe matrix slice: `a[indices]`.
    ///
    /// See [`Expression::UnsafeSlice`].
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
    ///
    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
    /// `SafeSlice` respectively.
    InDomain(Metadata, Moo<Expression>, DomainPtr),

    /// `toInt(b)` casts boolean expression b to an integer.
    ///
    /// - If b is false, then `toInt(b) == 0`
    ///
    /// - If b is true, then `toInt(b) == 1`
    #[compatible(SMT)]
    ToInt(Metadata, Moo<Expression>),

    /// `|x|` - absolute value of `x`
    #[compatible(JsonInput, SMT)]
    Abs(Metadata, Moo<Expression>),

    /// `sum(<vec_expr>)`
    #[compatible(JsonInput, SMT)]
    Sum(Metadata, Moo<Expression>),

    /// `a * b * c * ...`
    #[compatible(JsonInput, SMT)]
    Product(Metadata, Moo<Expression>),

    /// `min(<vec_expr>)`
    #[compatible(JsonInput, SMT)]
    Min(Metadata, Moo<Expression>),

    /// `max(<vec_expr>)`
    #[compatible(JsonInput, SMT)]
    Max(Metadata, Moo<Expression>),

    /// `not(a)`
    #[compatible(JsonInput, SAT, SMT)]
    Not(Metadata, Moo<Expression>),

    /// `or(<vec_expr>)`
    #[compatible(JsonInput, SAT, SMT)]
    Or(Metadata, Moo<Expression>),

    /// `and(<vec_expr>)`
    #[compatible(JsonInput, SAT, SMT)]
    And(Metadata, Moo<Expression>),

    /// Ensures that `a->b` (material implication).
    #[compatible(JsonInput, SMT)]
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    /// `iff(a, b)` a <-> b
    #[compatible(JsonInput, SMT)]
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a union b`
    #[compatible(JsonInput)]
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a in b`
    #[compatible(JsonInput)]
    In(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a intersect b`
    #[compatible(JsonInput)]
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a supset b`
    #[compatible(JsonInput)]
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a supsetEq b`
    #[compatible(JsonInput)]
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a subset b`
    #[compatible(JsonInput)]
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a subsetEq b`
    #[compatible(JsonInput)]
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a = b`
    #[compatible(JsonInput, SMT)]
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a != b`
    #[compatible(JsonInput, SMT)]
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a >= b`
    #[compatible(JsonInput, SMT)]
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a <= b`
    #[compatible(JsonInput, SMT)]
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a > b`
    #[compatible(JsonInput, SMT)]
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a < b`
    #[compatible(JsonInput, SMT)]
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Division after preventing division by zero, usually with a bubble.
    #[compatible(SMT)]
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// Division with a possibly undefined value (division by 0).
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// Modulo after preventing mod 0, usually with a bubble.
    #[compatible(SMT)]
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// Modulo with a possibly undefined value (mod 0).
    #[compatible(JsonInput)]
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// Negation: `-x`
    #[compatible(JsonInput, SMT)]
    Neg(Metadata, Moo<Expression>),

    /// Factorial: `x!` or `factorial(x)`
    #[compatible(JsonInput)]
    Factorial(Metadata, Moo<Expression>),

    /// The set of domain values that the function is defined for: `defined(f)`.
    #[compatible(JsonInput)]
    Defined(Metadata, Moo<Expression>),

    /// The set of codomain values that the function maps to: `range(f)`.
    #[compatible(JsonInput)]
    Range(Metadata, Moo<Expression>),

    /// Unsafe power `x**y` (possibly undefined).
    ///
    /// Defined when `(x != 0 \/ y != 0) /\ y >= 0`.
    #[compatible(JsonInput)]
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// [`Expression::UnsafePow`] after preventing undefinedness.
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// Flatten matrix operator: `flatten(M)` or `flatten(n, M)`.
    ///
    /// `M` is a matrix and `n` is an optional integer argument (the `Option` field) indicating
    /// the depth of flattening.
    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),

    /// `allDiff(<vec_expr>)`
    #[compatible(JsonInput)]
    AllDiff(Metadata, Moo<Expression>),

    /// `table([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
    ///
    /// Represents a positive table constraint: the tuple `[x1, x2, ...]` must match one of the
    /// allowed rows.
    #[compatible(JsonInput)]
    Table(Metadata, Moo<Expression>, Moo<Expression>),

    /// `negativeTable([x1, x2, ...], [[r11, r12, ...], [r21, r22, ...], ...])`
    ///
    /// Represents a negative table constraint: the tuple `[x1, x2, ...]` must NOT match any of the
    /// forbidden rows.
    #[compatible(JsonInput)]
    NegativeTable(Metadata, Moo<Expression>, Moo<Expression>),

    /// Binary subtraction operator.
    ///
    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
    ///
    /// TODO: make this compatible with set difference calculations - this needs a change to the
    /// return type and domain of this expression, and a set comprehension rule. `minus_to_sum`
    /// has already been edited so that it does not apply to sets.
    #[compatible(JsonInput)]
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    /// Ensures that x=|y| i.e. x is the absolute value of y.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
    #[compatible(Minion)]
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Ensures that `alldiff([a,b,...])`.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
    #[compatible(Minion)]
    FlatAllDiff(Metadata, Vec<Atom>),

    /// Ensures that sum(vec) >= x.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
    #[compatible(Minion)]
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    /// Ensures that sum(vec) <= x.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
    #[compatible(Minion)]
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    /// `ineq(x,y,k)` ensures that x <= y + k.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
    #[compatible(Minion)]
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
    ///
    /// Low-level Minion constraint.
    ///
    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
    /// watched-and and watched-or.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-literal)
    /// + `rules::minion::boolean_literal_to_wliteral`.
    #[compatible(Minion)]
    #[polyquine_skip]
    FlatWatchedLiteral(Metadata, Reference, Literal),

    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
    /// product of cs and xs.
    ///
    /// Low-level Minion constraint.
    ///
    /// Represents a weighted sum of the form `ax + by + cz + ...`
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
    /// product of cs and xs.
    ///
    /// Low-level Minion constraint.
    ///
    /// Represents a weighted sum of the form `ax + by + cz + ...`
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumgeq)
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Ensures that x =-y, where x and y are atoms.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
    #[compatible(Minion)]
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Ensures that x*y=z.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
    #[compatible(Minion)]
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Ensures that `floor(x/y)=z`. Always true when `y=0`.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
    #[compatible(Minion)]
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Ensures that `x%y=z`. Always true when `y=0`.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
    #[compatible(Minion)]
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Ensures that `x**y = z`.
    ///
    /// Low-level Minion constraint.
    ///
    /// This constraint is false when `y<0`, except for `1**y=1` and `(-1)**y=z` (where `z` is 1
    /// if `y` is even and `z` is -1 if `y` is odd).
    ///
    /// # See also
    ///
    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
    /// variable.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
    #[compatible(Minion)]
    MinionReify(Metadata, Moo<Expression>, Atom),

    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
    #[compatible(Minion)]
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
    /// intervals {a1,…,a2}, {b1,…,b2} etc.
    ///
    /// The list of intervals must be given in numerical order.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
    #[compatible(Minion)]
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given
    /// values `v1`, `v2`, etc.
    ///
    /// This constraint enforces membership in a specific set of discrete values rather than
    /// intervals.
    ///
    /// The list of values must be given in numerical order.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
    #[compatible(Minion)]
    MinionWInSet(Metadata, Atom, Vec<i32>),

    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
    /// in the range `[1..len(vec)]`.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
    #[compatible(Minion)]
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    /// Declaration of an auxiliary variable.
    ///
    /// As with Savile Row, we semantically distinguish this from `Eq`.
    #[compatible(Minion)]
    #[polyquine_skip]
    AuxDeclaration(Metadata, Reference, Moo<Expression>),

    /// Encodes an integer for the SAT solver.
    ///
    /// Stores the encoding type, the vector of booleans, and the `(min, max)` of the integer.
    #[compatible(SAT)]
    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),

    /// Addition over a pair of expressions (i.e. a + b) rather than a vec-expr like
    /// [`Expression::Sum`].
    ///
    /// This is for compatibility with backends that do not support addition over vectors.
    #[compatible(SMT)]
    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),

    /// Multiplication over a pair of expressions (i.e. a * b) rather than a vec-expr like
    /// [`Expression::Product`].
    ///
    /// This is for compatibility with backends that do not support multiplication over vectors.
    #[compatible(SMT)]
    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Image(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    PreImage(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Inverse(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Restrict(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographical `<` between two matrices.
    ///
    /// `A <lex B` iff: `A[i] < B[i] for some i /\ (A[j] > B[j] for some j -> i < j)`.
    /// I.e. A must be less than B at some index i, and if it is greater than B at another index j,
    /// then j comes after i.
    /// I.e. A must be less than B at the first index where they differ.
    ///
    /// E.g. `[1, 1] <lex [2, 1]` and `[1, 1] <lex [1, 2]`.
    LexLt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographical `<=` between two matrices.
    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographical `>` between two matrices.
    ///
    /// This is a parser-level construct, and is immediately normalised to `LexLt(b, a)`.
    LexGt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographical `>=` between two matrices.
    ///
    /// This is a parser-level construct, and is immediately normalised to `LexLeq(b, a)`.
    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Low-level Minion constraint. See [`Expression::LexLt`].
    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),

    /// Low-level Minion constraint. See [`Expression::LexLeq`].
    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
}
590

            
591
// for the given matrix literal, return a bounded domain from the min to max of applying op to each
592
// child expression.
593
//
594
// Op must be monotonic.
595
//
596
// Returns none if unbounded
597
1824455
fn bounded_i32_domain_for_matrix_literal_monotonic(
598
1824455
    e: &Expression,
599
1824455
    op: fn(i32, i32) -> Option<i32>,
600
1824455
) -> Option<DomainPtr> {
601
    // only care about the elements, not the indices
602
1824455
    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
603

            
604
    // fold each element's domain into one using op.
605
    //
606
    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
607
    // the ranges [a1,a2], [b1,b2] will be
608
    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
609
    //
610
    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
611
    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
612
    // memory (on the hakank_eprime_xkcd test)...
613
    //Int
614
    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
615
    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
616
    //
617
    // +,-,/,* are all monotone, so this assumption should be fine for now...
618

            
619
1802957
    let expr = exprs.pop()?;
620
1802956
    let dom = expr.domain_of()?;
621
1800196
    let resolved = dom.resolve()?;
622
1799876
    let GroundDomain::Int(ranges) = resolved.as_ref() else {
623
13
        return None;
624
    };
625

            
626
1799863
    let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
627

            
628
2048357
    for expr in exprs {
629
2048357
        let dom = expr.domain_of()?;
630
2048037
        let resolved = dom.resolve()?;
631
2047917
        let GroundDomain::Int(ranges) = resolved.as_ref() else {
632
            return None;
633
        };
634

            
635
2047917
        let (min, max) = range_vec_bounds_i32(ranges)?;
636

            
637
        // all the possible new values for current_min / current_max
638
2047917
        let minmax = op(min, current_max)?;
639
2047917
        let minmin = op(min, current_min)?;
640
2047917
        let maxmin = op(max, current_min)?;
641
2047917
        let maxmax = op(max, current_max)?;
642
2047917
        let vals = [minmax, minmin, maxmin, maxmax];
643

            
644
2047917
        current_min = *vals
645
2047917
            .iter()
646
2047917
            .min()
647
2047917
            .expect("vals iterator should not be empty, and should have a minimum.");
648
2047917
        current_max = *vals
649
2047917
            .iter()
650
2047917
            .max()
651
2047917
            .expect("vals iterator should not be empty, and should have a maximum.");
652
    }
653

            
654
1799423
    if current_min == current_max {
655
941
        Some(Domain::int(vec![Range::Single(current_min)]))
656
    } else {
657
1798482
        Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
658
    }
659
1824455
}
660

            
661
10
/// Returns the element domain of `e`'s matrix domain, but only when that element domain is an
/// integer domain (as reported by `as_int`).
///
/// Returns `None` if `e` has no known domain, if that domain is not a matrix domain, or if the
/// element domain is not an integer domain.
fn matrix_element_domain(e: &Expression) -> Option<DomainPtr> {
    let matrix_domain = e.domain_of()?;
    let (element_domain, _) = matrix_domain.as_matrix()?;
    if element_domain.as_ref().as_int().is_none() {
        return None;
    }
    Some(element_domain)
}
666

            
667
// Returns none if unbounded
668
3847780
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
669
3847780
    let mut min = i32::MAX;
670
3847780
    let mut max = i32::MIN;
671
3865488
    for r in ranges {
672
3865488
        match r {
673
1381226
            Range::Single(i) => {
674
1381226
                if *i < min {
675
1363518
                    min = *i;
676
1363518
                }
677
1381226
                if *i > max {
678
1381226
                    max = *i;
679
1381226
                }
680
            }
681
2484262
            Range::Bounded(i, j) => {
682
2484262
                if *i < min {
683
2484262
                    min = *i;
684
2484262
                }
685
2484262
                if *j > max {
686
2484262
                    max = *j;
687
2484262
                }
688
            }
689
            Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
690
        }
691
    }
692
3847780
    Some((min, max))
693
3847780
}
694

            
695
impl Expression {
696
    /// Returns the possible values of the expression, recursing to leaf expressions
697
9064618
    pub fn domain_of(&self) -> Option<DomainPtr> {
698
9064618
        match self {
699
96
            Expression::Union(_, a, b) => Some(Domain::set(
700
96
                SetAttr::<IntVal>::default(),
701
96
                a.domain_of()?.union(&b.domain_of()?).ok()?,
702
            )),
703
90
            Expression::Intersect(_, a, b) => Some(Domain::set(
704
90
                SetAttr::<IntVal>::default(),
705
90
                a.domain_of()?.intersect(&b.domain_of()?).ok()?,
706
            )),
707
            Expression::In(_, _, _) => Some(Domain::bool()),
708
            Expression::Supset(_, _, _) => Some(Domain::bool()),
709
            Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
710
            Expression::Subset(_, _, _) => Some(Domain::bool()),
711
6
            Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
712
197266
            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
713
            Expression::DominanceRelation(_, _) => Some(Domain::bool()),
714
            Expression::FromSolution(_, expr) => Some(expr.domain_of()),
715
            Expression::Metavar(_, _) => None,
716
26
            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
717
            Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
718
729804
            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
719
750322
                let dom = matrix.domain_of()?;
720
750322
                if let Some((elem_domain, _)) = dom.as_matrix() {
721
750322
                    return Some(elem_domain);
722
                }
723

            
724
                // may actually use the value in the future
725
                #[allow(clippy::redundant_pattern_matching)]
726
                if let Some(_) = dom.as_tuple() {
727
                    // TODO: We can implement proper indexing for tuples
728
                    return None;
729
                }
730

            
731
                // may actually use the value in the future
732
                #[allow(clippy::redundant_pattern_matching)]
733
                if let Some(_) = dom.as_record() {
734
                    // TODO: We can implement proper indexing for records
735
                    return None;
736
                }
737

            
738
                bug!("subject of an index operation should support indexing")
739
            }
740
            Expression::UnsafeSlice(_, matrix, indices)
741
480
            | Expression::SafeSlice(_, matrix, indices) => {
742
480
                let sliced_dimension = indices.iter().position(Option::is_none);
743

            
744
480
                let dom = matrix.domain_of()?;
745
480
                let Some((elem_domain, index_domains)) = dom.as_matrix() else {
746
                    bug!("subject of an index operation should be a matrix");
747
                };
748

            
749
480
                match sliced_dimension {
750
480
                    Some(dimension) => Some(Domain::matrix(
751
480
                        elem_domain,
752
480
                        vec![index_domains[dimension].clone()],
753
480
                    )),
754

            
755
                    // same as index
756
                    None => Some(elem_domain),
757
                }
758
            }
759
            Expression::InDomain(_, _, _) => Some(Domain::bool()),
760
5801705
            Expression::Atomic(_, atom) => Some(atom.domain_of()),
761
1225971
            Expression::Sum(_, e) => {
762
5831236
                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
763
            }
764
591114
            Expression::Product(_, e) => {
765
2319792
                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
766
            }
767
24640
            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
768
24640
                Some(if x < y { x } else { y })
769
24640
            })
770
4200
            .or_else(|| matrix_element_domain(e)),
771
16010
            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
772
16000
                Some(if x > y { x } else { y })
773
16000
            })
774
3170
            .or_else(|| matrix_element_domain(e)),
775
520
            Expression::UnsafeDiv(_, a, b) => a
776
520
                .domain_of()?
777
520
                .resolve()?
778
520
                .apply_i32(
779
                    // rust integer division is truncating; however, we want to always round down,
780
                    // including for negative numbers.
781
3000
                    |x, y| {
782
3000
                        if y != 0 {
783
2600
                            Some((x as f32 / y as f32).floor() as i32)
784
                        } else {
785
400
                            None
786
                        }
787
3000
                    },
788
520
                    b.domain_of()?.resolve()?.as_ref(),
789
                )
790
520
                .map(DomainPtr::from)
791
520
                .ok(),
792
21720
            Expression::SafeDiv(_, a, b) => {
793
                // rust integer division is truncating; however, we want to always round down
794
                // including for negative numbers.
795
21720
                let domain = a
796
21720
                    .domain_of()?
797
21720
                    .resolve()?
798
21720
                    .apply_i32(
799
897680
                        |x, y| {
800
897680
                            if y != 0 {
801
780040
                                Some((x as f32 / y as f32).floor() as i32)
802
                            } else {
803
117640
                                None
804
                            }
805
897680
                        },
806
21720
                        b.domain_of()?.resolve()?.as_ref(),
807
                    )
808
                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
809

            
810
21720
                if let GroundDomain::Int(ranges) = domain {
811
21720
                    let mut ranges = ranges;
812
21720
                    ranges.push(Range::Single(0));
813
21720
                    Some(Domain::int(ranges))
814
                } else {
815
                    bug!("Domain of {self} was not integer")
816
                }
817
            }
818
            Expression::UnsafeMod(_, a, b) => a
819
                .domain_of()?
820
                .resolve()?
821
                .apply_i32(
822
                    |x, y| if y != 0 { Some(x % y) } else { None },
823
                    b.domain_of()?.resolve()?.as_ref(),
824
                )
825
                .map(DomainPtr::from)
826
                .ok(),
827
6960
            Expression::SafeMod(_, a, b) => {
828
6960
                let domain = a
829
6960
                    .domain_of()?
830
6960
                    .resolve()?
831
6960
                    .apply_i32(
832
250320
                        |x, y| if y != 0 { Some(x % y) } else { None },
833
6960
                        b.domain_of()?.resolve()?.as_ref(),
834
                    )
835
                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
836

            
837
6960
                if let GroundDomain::Int(ranges) = domain {
838
6960
                    let mut ranges = ranges;
839
6960
                    ranges.push(Range::Single(0));
840
6960
                    Some(Domain::int(ranges))
841
                } else {
842
                    bug!("Domain of {self} was not integer")
843
                }
844
            }
845
4270
            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
846
4270
                .domain_of()?
847
4270
                .resolve()?
848
4270
                .apply_i32(
849
154360
                    |x, y| {
850
154360
                        if (x != 0 || y != 0) && y >= 0 {
851
139960
                            Some(x.pow(y as u32))
852
                        } else {
853
14400
                            None
854
                        }
855
154360
                    },
856
4270
                    b.domain_of()?.resolve()?.as_ref(),
857
                )
858
4270
                .map(DomainPtr::from)
859
4270
                .ok(),
860
            Expression::Root(_, _) => None,
861
320
            Expression::Bubble(_, inner, _) => inner.domain_of(),
862
            Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
863
29286
            Expression::And(_, _) => Some(Domain::bool()),
864
400
            Expression::Not(_, _) => Some(Domain::bool()),
865
320
            Expression::Or(_, _) => Some(Domain::bool()),
866
4228
            Expression::Imply(_, _, _) => Some(Domain::bool()),
867
            Expression::Iff(_, _, _) => Some(Domain::bool()),
868
16052
            Expression::Eq(_, _, _) => Some(Domain::bool()),
869
            Expression::Neq(_, _, _) => Some(Domain::bool()),
870
            Expression::Geq(_, _, _) => Some(Domain::bool()),
871
1720
            Expression::Leq(_, _, _) => Some(Domain::bool()),
872
82
            Expression::Gt(_, _, _) => Some(Domain::bool()),
873
2
            Expression::Lt(_, _, _) => Some(Domain::bool()),
874
            Expression::Factorial(_, _) => None, // not implemented
875
            Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
876
80
            Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
877
            Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
878
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
879
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
880
360
            Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
881
804
            Expression::Flatten(_, n, m) => {
882
804
                if let Some(expr) = n {
883
                    if expr.return_type() == ReturnType::Int {
884
                        // TODO: handle flatten with depth argument
885
                        return None;
886
                    }
887
                } else {
888
                    // TODO: currently only works for matrices
889
804
                    let dom = m.domain_of()?.resolve()?;
890
800
                    let (val_dom, idx_doms) = match dom.as_ref() {
891
800
                        GroundDomain::Matrix(val, idx) => (val, idx),
892
                        _ => return None,
893
                    };
894
800
                    let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
895

            
896
800
                    let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
897
800
                    return Some(Domain::matrix(
898
800
                        val_dom.clone().into(),
899
800
                        vec![new_index_domain],
900
800
                    ));
901
                }
902
                None
903
            }
904
            Expression::AllDiff(_, _) => Some(Domain::bool()),
905
            Expression::Table(_, _, _) => Some(Domain::bool()),
906
            Expression::NegativeTable(_, _, _) => Some(Domain::bool()),
907
            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
908
            Expression::MinionReify(_, _, _) => Some(Domain::bool()),
909
4820
            Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
910
            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
911
            Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
912
            Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
913
6408
            Expression::Neg(_, x) => {
914
6408
                let dom = x.domain_of()?;
915
6408
                let mut ranges = dom.as_int()?;
916

            
917
3648
                ranges = ranges
918
3648
                    .into_iter()
919
3648
                    .map(|r| match r {
920
320
                        Range::Single(x) => Range::Single(-x),
921
3328
                        Range::Bounded(x, y) => Range::Bounded(-y, -x),
922
                        Range::UnboundedR(i) => Range::UnboundedL(-i),
923
                        Range::UnboundedL(i) => Range::UnboundedR(-i),
924
                        Range::Unbounded => Range::Unbounded,
925
3648
                    })
926
3648
                    .collect();
927

            
928
3648
                Some(Domain::int(ranges))
929
            }
930
375652
            Expression::Minus(_, a, b) => a
931
375652
                .domain_of()?
932
375652
                .resolve()?
933
1159350
                .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
934
375652
                .map(DomainPtr::from)
935
375652
                .ok(),
936
            Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
937
            Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
938
            Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
939
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
940
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
941
5320
            Expression::Abs(_, a) => a
942
5320
                .domain_of()?
943
5320
                .resolve()?
944
914120
                .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
945
5320
                .map(DomainPtr::from)
946
5320
                .ok(),
947
            Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
948
3928
            Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
949
3840
            Expression::SATInt(_, _, _, (low, high)) => {
950
3840
                Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
951
            }
952
            Expression::PairwiseSum(_, a, b) => a
953
                .domain_of()?
954
                .resolve()?
955
                .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
956
                .map(DomainPtr::from)
957
                .ok(),
958
            Expression::PairwiseProduct(_, a, b) => a
959
                .domain_of()?
960
                .resolve()?
961
                .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
962
                .map(DomainPtr::from)
963
                .ok(),
964
            Expression::Defined(_, function) => get_function_domain(function),
965
            Expression::Range(_, function) => get_function_codomain(function),
966
            Expression::Image(_, function, _) => get_function_codomain(function),
967
            Expression::ImageSet(_, function, _) => get_function_codomain(function),
968
            Expression::PreImage(_, function, _) => get_function_domain(function),
969
            Expression::Restrict(_, function, new_domain) => {
970
                let (attrs, _, codom) = function.domain_of()?.as_function()?;
971
                let new_dom = new_domain.domain_of()?;
972
                Some(Domain::function(attrs, new_dom, codom))
973
            }
974
            Expression::Inverse(..) => Some(Domain::bool()),
975
            Expression::LexLt(..) => Some(Domain::bool()),
976
3080
            Expression::LexLeq(..) => Some(Domain::bool()),
977
            Expression::LexGt(..) => Some(Domain::bool()),
978
            Expression::LexGeq(..) => Some(Domain::bool()),
979
            Expression::FlatLexLt(..) => Some(Domain::bool()),
980
            Expression::FlatLexLeq(..) => Some(Domain::bool()),
981
        }
982
9064618
    }
983

            
984
    /// Returns a reference to this expression's metadata without cloning.
985
    pub fn meta_ref(&self) -> &Metadata {
986
        macro_rules! match_meta_ref {
987
            ($($variant:ident),* $(,)?) => {
988
                match self {
989
                    $(Expression::$variant(meta, ..) => meta,)*
990
                }
991
            };
992
        }
993
        match_meta_ref!(
994
            AbstractLiteral,
995
            Root,
996
            Bubble,
997
            Comprehension,
998
            AbstractComprehension,
999
            DominanceRelation,
            FromSolution,
            Metavar,
            Atomic,
            UnsafeIndex,
            SafeIndex,
            UnsafeSlice,
            SafeSlice,
            InDomain,
            ToInt,
            Abs,
            Sum,
            Product,
            Min,
            Max,
            Not,
            Or,
            And,
            Imply,
            Iff,
            Union,
            In,
            Intersect,
            Supset,
            SupsetEq,
            Subset,
            SubsetEq,
            Eq,
            Neq,
            Geq,
            Leq,
            Gt,
            Lt,
            SafeDiv,
            UnsafeDiv,
            SafeMod,
            UnsafeMod,
            Neg,
            Defined,
            Range,
            UnsafePow,
            SafePow,
            Flatten,
            AllDiff,
            Minus,
            Factorial,
            FlatAbsEq,
            FlatAllDiff,
            FlatSumGeq,
            FlatSumLeq,
            FlatIneq,
            FlatWatchedLiteral,
            FlatWeightedSumLeq,
            FlatWeightedSumGeq,
            FlatMinusEq,
            FlatProductEq,
            MinionDivEqUndefZero,
            MinionModuloEqUndefZero,
            MinionPow,
            MinionReify,
            MinionReifyImply,
            MinionWInIntervalSet,
            MinionWInSet,
            MinionElementOne,
            AuxDeclaration,
            SATInt,
            PairwiseSum,
            PairwiseProduct,
            Image,
            ImageSet,
            PreImage,
            Inverse,
            Restrict,
            LexLt,
            LexLeq,
            LexGt,
            LexGeq,
            FlatLexLt,
            FlatLexLeq,
            NegativeTable,
            Table
        )
    }
    pub fn get_meta(&self) -> Metadata {
        let metas: VecDeque<Metadata> = self.children_bi();
        metas[0].clone()
    }
    pub fn set_meta(&self, meta: Metadata) {
        self.transform_bi(&|_| meta.clone());
    }
    /// Checks whether this expression is safe.
    ///
    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
    ///
    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
    /// safe through the use of bubble rules.
5250494
    pub fn is_safe(&self) -> bool {
        // TODO: memoise in Metadata
42660686
        for expr in self.universe() {
42660686
            match expr {
                Expression::UnsafeDiv(_, _, _)
                | Expression::UnsafeMod(_, _, _)
                | Expression::UnsafePow(_, _, _)
                | Expression::UnsafeIndex(_, _, _)
                | Expression::Bubble(_, _, _)
                | Expression::UnsafeSlice(_, _, _) => {
807154
                    return false;
                }
41853532
                _ => {}
            }
        }
4443340
        true
5250494
    }
    /// True if the expression is an associative and commutative operator
14649122
    pub fn is_associative_commutative_operator(&self) -> bool {
14649122
        TryInto::<ACOperatorKind>::try_into(self).is_ok()
14649122
    }
    /// True if the expression is a matrix literal.
    ///
    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
    /// [`Expression`].
8280
    pub fn is_matrix_literal(&self) -> bool {
        matches!(
8280
            self,
            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
                | Expression::Atomic(
                    _,
                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
                )
        )
8280
    }
    /// True iff self and other are both atomic and identical.
    ///
    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
    /// entire subtrees of `self` and `other`.
2517236
    pub fn identical_atom_to(&self, other: &Expression) -> bool {
2517236
        let atom1: Result<&Atom, _> = self.try_into();
2517236
        let atom2: Result<&Atom, _> = other.try_into();
2517236
        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
561538
            atom2 == atom1
        } else {
1955698
            false
        }
2517236
    }
    /// If the expression is a list, returns a *copied* vector of the inner expressions.
    ///
    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
    /// any explicitly specified domain.
8662760
    pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
7453768
        match self {
7453768
            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
7453768
                matrix.unwrap_list().cloned()
            }
            Expression::Atomic(
                _,
27344
                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
27344
            ) => matrix.unwrap_list().map(|elems| {
23164
                elems
23164
                    .clone()
23164
                    .into_iter()
63998
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
23164
                    .collect_vec()
23164
            }),
1181648
            _ => None,
        }
8662760
    }
    /// If the expression is a matrix, gets it elements and index domain.
    ///
    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
    ///
    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
11931913
    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
8274123
        match self {
8274123
            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
8274123
                Some((elems, domain))
            }
            Expression::Atomic(
                _,
440356
                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
            ) => Some((
440356
                elems
440356
                    .into_iter()
1017948
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
440356
                    .collect_vec(),
440356
                domain.into(),
            )),
3217434
            _ => None,
        }
11931913
    }
    /// For a Root expression, extends the inner vec with the given vec.
    ///
    /// # Panics
    /// Panics if the expression is not Root.
26420
    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
26420
        match self {
26420
            Expression::Root(meta, mut children) => {
26420
                children.extend(exprs);
26420
                Expression::Root(meta, children)
            }
            _ => panic!("extend_root called on a non-Root expression"),
        }
26420
    }
    /// Converts the expression to a literal, if possible.
149138
    pub fn into_literal(self) -> Option<Literal> {
139588
        match self {
129834
            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
            Expression::AbstractLiteral(_, abslit) => {
                Some(Literal::AbstractLiteral(abslit.into_literals()?))
            }
6122
            Expression::Neg(_, e) => {
6122
                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
                    bug!("negated literal should be an int");
                };
6120
                Some(Literal::Int(-i))
            }
13182
            _ => None,
        }
149138
    }
    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
15888662
    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
15888662
        TryFrom::try_from(self).ok()
15888662
    }
    /// Returns the categories of all sub-expressions of self.
87994
    pub fn universe_categories(&self) -> HashSet<Category> {
87994
        self.universe()
87994
            .into_iter()
1027604
            .map(|x| x.category_of())
87994
            .collect()
87994
    }
}
pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
    let function_domain = function.domain_of()?;
    match function_domain.resolve().as_ref() {
        Some(d) => {
            match d.as_ref() {
                GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
        None => {
            match function_domain.as_unresolved()? {
                UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
    }
}
pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
    let function_domain = function.domain_of()?;
    match function_domain.resolve().as_ref() {
        Some(d) => {
            match d.as_ref() {
                GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
        None => {
            match function_domain.as_unresolved()? {
                UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
    }
}
impl TryFrom<&Expression> for i32 {
    type Error = ();
957154
    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
957154
        let Expression::Atomic(_, atom) = value else {
740038
            return Err(());
        };
217116
        let Atom::Literal(lit) = atom else {
216456
            return Err(());
        };
660
        let Literal::Int(i) = lit else {
            return Err(());
        };
660
        Ok(*i)
957154
    }
}
impl TryFrom<Expression> for i32 {
    type Error = ();

    /// Owned counterpart of the `&Expression` conversion; it simply borrows and delegates.
    fn try_from(value: Expression) -> Result<Self, Self::Error> {
        let borrowed: &Expression = &value;
        i32::try_from(borrowed)
    }
}
impl From<i32> for Expression {
69816
    fn from(i: i32) -> Self {
69816
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
69816
    }
}
impl From<bool> for Expression {
35382
    fn from(b: bool) -> Self {
35382
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
35382
    }
}
impl From<Atom> for Expression {
5494
    fn from(value: Atom) -> Self {
5494
        Expression::Atomic(Metadata::new(), value)
5494
    }
}
impl From<Literal> for Expression {
7040
    fn from(value: Literal) -> Self {
7040
        Expression::Atomic(Metadata::new(), value.into())
7040
    }
}
impl From<Moo<Expression>> for Expression {
82512
    fn from(val: Moo<Expression>) -> Self {
82512
        val.as_ref().clone()
82512
    }
}
impl CategoryOf for Expression {
1323070
    fn category_of(&self) -> Category {
        // take highest category of all the expressions children
4958140
        let category = self.cata(&move |x,children| {
4958140
            if let Some(max_category) = children.iter().max() {
                // if this expression contains subexpressions, return the maximum category of the
                // subexpressions
1420438
                *max_category
            } else {
                // this expression has no children
3537702
                let mut max_category = Category::Bottom;
                // calculate the category by looking at all atoms, submodels, comprehensions, and
                // declarationptrs inside this expression
                // this should generically cover all leaf types we currently have in oxide.
                // if x contains submodels (including comprehensions)
3537702
                if !Biplate::<Model>::universe_bi(&x).is_empty() {
                    // assume that the category is decision
                    return Category::Decision;
3537702
                }
                // if x contains atoms
3542102
                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
                // and those atoms have a higher category than we already know about
3537682
                && max_atom_category > max_category{
                    // update category
3537682
                    max_category = max_atom_category;
3537682
                }
                // if x contains declarationPtrs
3537702
                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
                // and those pointers have a higher category than we already know about
2728008
                && max_declaration_category > max_category{
                    // update category
598
                    max_category = max_declaration_category;
3537104
                }
3537702
                max_category
            }
4958140
        });
1323070
        if cfg!(debug_assertions) {
1323070
            trace!(
                category= %category,
                expression= %self,
                "Called Expression::category_of()"
            );
        };
1323070
        category
1323070
    }
}
impl Display for Expression {
72902230
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
72902230
        match &self {
404
            Expression::Union(_, box1, box2) => {
404
                write!(f, "({} union {})", box1.clone(), box2.clone())
            }
3044
            Expression::In(_, e1, e2) => {
3044
                write!(f, "{e1} in {e2}")
            }
380
            Expression::Intersect(_, box1, box2) => {
380
                write!(f, "({} intersect {})", box1.clone(), box2.clone())
            }
480
            Expression::Supset(_, box1, box2) => {
480
                write!(f, "({} supset {})", box1.clone(), box2.clone())
            }
480
            Expression::SupsetEq(_, box1, box2) => {
480
                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
            }
600
            Expression::Subset(_, box1, box2) => {
600
                write!(f, "({} subset {})", box1.clone(), box2.clone())
            }
510
            Expression::SubsetEq(_, box1, box2) => {
510
                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
            }
2210788
            Expression::AbstractLiteral(_, l) => l.fmt(f),
30252
            Expression::Comprehension(_, c) => c.fmt(f),
            Expression::AbstractComprehension(_, c) => c.fmt(f),
271394
            Expression::UnsafeIndex(_, e1, e2) => write!(f, "{e1}{}", pretty_vec(e2)),
647904
            Expression::SafeIndex(_, e1, e2) => write!(f, "SafeIndex({e1},{})", pretty_vec(e2)),
20320
            Expression::UnsafeSlice(_, e1, es) => {
20320
                let args = es
20320
                    .iter()
40400
                    .map(|x| match x {
20080
                        Some(x) => format!("{x}"),
20320
                        None => "..".into(),
40400
                    })
20320
                    .join(",");
20320
                write!(f, "{e1}[{args}]")
            }
13760
            Expression::SafeSlice(_, e1, es) => {
13760
                let args = es
13760
                    .iter()
26960
                    .map(|x| match x {
13200
                        Some(x) => format!("{x}"),
13760
                        None => "..".into(),
26960
                    })
13760
                    .join(",");
13760
                write!(f, "SafeSlice({e1},[{args}])")
            }
3480
            Expression::InDomain(_, e, domain) => {
3480
                write!(f, "__inDomain({e},{domain})")
            }
203784
            Expression::Root(_, exprs) => {
203784
                write!(f, "{}", pretty_expressions_as_top_level(exprs))
            }
            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
            Expression::Metavar(_, name) => write!(f, "&{name}"),
64814400
            Expression::Atomic(_, atom) => atom.fmt(f),
5760
            Expression::Abs(_, a) => write!(f, "|{a}|"),
630146
            Expression::Sum(_, e) => {
630146
                write!(f, "sum({e})")
            }
101810
            Expression::Product(_, e) => {
101810
                write!(f, "product({e})")
            }
14840
            Expression::Min(_, e) => {
14840
                write!(f, "min({e})")
            }
14964
            Expression::Max(_, e) => {
14964
                write!(f, "max({e})")
            }
40704
            Expression::Not(_, expr_box) => {
40704
                write!(f, "!({})", expr_box.clone())
            }
249864
            Expression::Or(_, e) => {
249864
                write!(f, "or({e})")
            }
244304
            Expression::And(_, e) => {
244304
                write!(f, "and({e})")
            }
45600
            Expression::Imply(_, box1, box2) => {
45600
                write!(f, "({box1}) -> ({box2})")
            }
1680
            Expression::Iff(_, box1, box2) => {
1680
                write!(f, "({box1}) <-> ({box2})")
            }
350870
            Expression::Eq(_, box1, box2) => {
350870
                write!(f, "({} = {})", box1.clone(), box2.clone())
            }
458418
            Expression::Neq(_, box1, box2) => {
458418
                write!(f, "({} != {})", box1.clone(), box2.clone())
            }
138244
            Expression::Geq(_, box1, box2) => {
138244
                write!(f, "({} >= {})", box1.clone(), box2.clone())
            }
313496
            Expression::Leq(_, box1, box2) => {
313496
                write!(f, "({} <= {})", box1.clone(), box2.clone())
            }
8004
            Expression::Gt(_, box1, box2) => {
8004
                write!(f, "({} > {})", box1.clone(), box2.clone())
            }
45924
            Expression::Lt(_, box1, box2) => {
45924
                write!(f, "({} < {})", box1.clone(), box2.clone())
            }
159220
            Expression::FlatSumGeq(_, box1, box2) => {
159220
                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
            }
155780
            Expression::FlatSumLeq(_, box1, box2) => {
155780
                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
            }
56312
            Expression::FlatIneq(_, box1, box2, box3) => write!(
56312
                f,
                "Ineq({}, {}, {})",
56312
                box1.clone(),
56312
                box2.clone(),
56312
                box3.clone()
            ),
1994
            Expression::Flatten(_, n, m) => {
1994
                if let Some(n) = n {
                    write!(f, "flatten({n}, {m})")
                } else {
1994
                    write!(f, "flatten({m})")
                }
            }
14876
            Expression::AllDiff(_, e) => {
14876
                write!(f, "allDiff({e})")
            }
1200
            Expression::Table(_, tuple_expr, rows_expr) => {
1200
                write!(f, "table({tuple_expr}, {rows_expr})")
            }
200
            Expression::NegativeTable(_, tuple_expr, rows_expr) => {
200
                write!(f, "negativeTable({tuple_expr}, {rows_expr})")
            }
13784
            Expression::Bubble(_, box1, box2) => {
13784
                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
            }
19960
            Expression::SafeDiv(_, box1, box2) => {
19960
                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
            }
21560
            Expression::UnsafeDiv(_, box1, box2) => {
21560
                write!(f, "({} / {})", box1.clone(), box2.clone())
            }
26220
            Expression::UnsafePow(_, box1, box2) => {
26220
                write!(f, "({} ** {})", box1.clone(), box2.clone())
            }
279722
            Expression::SafePow(_, box1, box2) => {
279722
                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
            }
4220
            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
4220
                write!(
4220
                    f,
                    "DivEq({}, {}, {})",
4220
                    box1.clone(),
4220
                    box2.clone(),
4220
                    box3.clone()
                )
            }
1320
            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1320
                write!(
1320
                    f,
                    "ModEq({}, {}, {})",
1320
                    box1.clone(),
1320
                    box2.clone(),
1320
                    box3.clone()
                )
            }
4344
            Expression::FlatWatchedLiteral(_, x, l) => {
4344
                write!(f, "WatchedLiteral({x},{l})")
            }
66776
            Expression::MinionReify(_, box1, box2) => {
66776
                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
            }
39488
            Expression::MinionReifyImply(_, box1, box2) => {
39488
                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
            }
760
            Expression::MinionWInIntervalSet(_, atom, intervals) => {
760
                let intervals = intervals.iter().join(",");
760
                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
            }
480
            Expression::MinionWInSet(_, atom, values) => {
480
                let values = values.iter().join(",");
480
                write!(f, "__minion_w_inset({atom},[{values}])")
            }
87042
            Expression::AuxDeclaration(_, reference, e) => {
87042
                write!(f, "{} =aux {}", reference, e.clone())
            }
3880
            Expression::UnsafeMod(_, a, b) => {
3880
                write!(f, "{} % {}", a.clone(), b.clone())
            }
6720
            Expression::SafeMod(_, a, b) => {
6720
                write!(f, "SafeMod({},{})", a.clone(), b.clone())
            }
48992
            Expression::Neg(_, a) => {
48992
                write!(f, "-({})", a.clone())
            }
            Expression::Factorial(_, a) => {
                write!(f, "({})!", a.clone())
            }
192814
            Expression::Minus(_, a, b) => {
192814
                write!(f, "({} - {})", a.clone(), b.clone())
            }
10164
            Expression::FlatAllDiff(_, es) => {
10164
                write!(f, "__flat_alldiff({})", pretty_vec(es))
            }
1280
            Expression::FlatAbsEq(_, a, b) => {
1280
                write!(f, "AbsEq({},{})", a.clone(), b.clone())
            }
720
            Expression::FlatMinusEq(_, a, b) => {
720
                write!(f, "MinusEq({},{})", a.clone(), b.clone())
            }
2300
            Expression::FlatProductEq(_, a, b, c) => {
2300
                write!(
2300
                    f,
                    "FlatProductEq({},{},{})",
2300
                    a.clone(),
2300
                    b.clone(),
2300
                    c.clone()
                )
            }
18700
            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
18700
                write!(
18700
                    f,
                    "FlatWeightedSumLeq({},{},{})",
18700
                    pretty_vec(cs),
18700
                    pretty_vec(vs),
18700
                    total.clone()
                )
            }
18900
            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
18900
                write!(
18900
                    f,
                    "FlatWeightedSumGeq({},{},{})",
18900
                    pretty_vec(cs),
18900
                    pretty_vec(vs),
18900
                    total.clone()
                )
            }
9160
            Expression::MinionPow(_, atom, atom1, atom2) => {
9160
                write!(f, "MinionPow({atom},{atom1},{atom2})")
            }
93568
            Expression::MinionElementOne(_, atoms, atom, atom1) => {
93568
                let atoms = atoms.iter().join(",");
93568
                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
            }
1126
            Expression::ToInt(_, expr) => {
1126
                write!(f, "toInt({expr})")
            }
636640
            Expression::SATInt(_, encoding, bits, (min, max)) => {
636640
                write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
            }
            Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
            Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
40
            Expression::Defined(_, function) => write!(f, "defined({function})"),
40
            Expression::Range(_, function) => write!(f, "range({function})"),
40
            Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
40
            Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
40
            Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
40
            Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
40
            Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1360
            Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
12820
            Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
120
            Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
180
            Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
240
            Expression::FlatLexLt(_, a, b) => {
240
                write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
            }
400
            Expression::FlatLexLeq(_, a, b) => {
400
                write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
            }
        }
72902230
    }
}
impl Typeable for Expression {
2214426
    fn return_type(&self) -> ReturnType {
2214426
        match self {
            Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
            Expression::Intersect(_, subject, _) => {
                ReturnType::Set(Box::new(subject.return_type()))
            }
600
            Expression::In(_, _, _) => ReturnType::Bool,
            Expression::Supset(_, _, _) => ReturnType::Bool,
            Expression::SupsetEq(_, _, _) => ReturnType::Bool,
            Expression::Subset(_, _, _) => ReturnType::Bool,
            Expression::SubsetEq(_, _, _) => ReturnType::Bool,
133590
            Expression::AbstractLiteral(_, lit) => lit.return_type(),
229078
            Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
301998
                let subject_ty = subject.return_type();
301998
                match subject_ty {
                    ReturnType::Matrix(_) => {
                        // For n-dimensional matrices, unwrap the element type until
                        // we either get to the innermost element type or the last index
292158
                        let mut elem_typ = subject_ty;
292158
                        let mut idx_len = idx.len();
584630
                        while idx_len > 0
342946
                            && let ReturnType::Matrix(new_elem_typ) = &elem_typ
292472
                        {
292472
                            elem_typ = *new_elem_typ.clone();
292472
                            idx_len -= 1;
292472
                        }
292158
                        elem_typ
                    }
                    // TODO: We can implement indexing for these eventually
9840
                    ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
                    _ => bug!(
                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
                    ),
                }
            }
160
            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
160
                ReturnType::Matrix(Box::new(subject.return_type()))
            }
            Expression::InDomain(_, _, _) => ReturnType::Bool,
4
            Expression::Comprehension(_, comp) => comp.return_type(),
            Expression::AbstractComprehension(_, comp) => comp.return_type(),
            Expression::Root(_, _) => ReturnType::Bool,
            Expression::DominanceRelation(_, _) => ReturnType::Bool,
            Expression::FromSolution(_, expr) => expr.return_type(),
            Expression::Metavar(_, _) => ReturnType::Unknown,
1572338
            Expression::Atomic(_, atom) => atom.return_type(),
960
            Expression::Abs(_, _) => ReturnType::Int,
105200
            Expression::Sum(_, _) => ReturnType::Int,
10762
            Expression::Product(_, _) => ReturnType::Int,
1920
            Expression::Min(_, _) => ReturnType::Int,
1866
            Expression::Max(_, _) => ReturnType::Int,
160
            Expression::Not(_, _) => ReturnType::Bool,
414
            Expression::Or(_, _) => ReturnType::Bool,
1292
            Expression::Imply(_, _, _) => ReturnType::Bool,
            Expression::Iff(_, _, _) => ReturnType::Bool,
4096
            Expression::And(_, _) => ReturnType::Bool,
8662
            Expression::Eq(_, _, _) => ReturnType::Bool,
1080
            Expression::Neq(_, _, _) => ReturnType::Bool,
            Expression::Geq(_, _, _) => ReturnType::Bool,
2200
            Expression::Leq(_, _, _) => ReturnType::Bool,
            Expression::Gt(_, _, _) => ReturnType::Bool,
            Expression::Lt(_, _, _) => ReturnType::Bool,
13200
            Expression::SafeDiv(_, _, _) => ReturnType::Int,
11160
            Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
            Expression::FlatAllDiff(_, _) => ReturnType::Bool,
80
            Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
            Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
            Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
680
            Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1920
            Expression::Flatten(_, _, matrix) => {
1920
                let matrix_type = matrix.return_type();
1920
                match matrix_type {
                    ReturnType::Matrix(_) => {
                        // unwrap until we get to innermost element
1920
                        let mut elem_type = matrix_type;
3840
                        while let ReturnType::Matrix(new_elem_type) = &elem_type {
1920
                            elem_type = *new_elem_type.clone();
1920
                        }
1920
                        ReturnType::Matrix(Box::new(elem_type))
                    }
                    _ => bug!(
                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
                    ),
                }
            }
160
            Expression::AllDiff(_, _) => ReturnType::Bool,
            Expression::Table(_, _, _) => ReturnType::Bool,
            Expression::NegativeTable(_, _, _) => ReturnType::Bool,
3120
            Expression::Bubble(_, inner, _) => inner.return_type(),
            Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
            Expression::MinionReify(_, _, _) => ReturnType::Bool,
240
            Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
            Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
            Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
            Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
            Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1520
            Expression::UnsafeMod(_, _, _) => ReturnType::Int,
4320
            Expression::SafeMod(_, _, _) => ReturnType::Int,
            Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
2802
            Expression::Neg(_, _) => ReturnType::Int,
            Expression::Factorial(_, _) => ReturnType::Int,
960
            Expression::UnsafePow(_, _, _) => ReturnType::Int,
2784
            Expression::SafePow(_, _, _) => ReturnType::Int,
1478
            Expression::Minus(_, _, _) => ReturnType::Int,
            Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
            Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
            Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
            Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
160
            Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
            Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1340
            Expression::ToInt(_, _) => ReturnType::Int,
19840
            Expression::SATInt(..) => ReturnType::Int,
            Expression::PairwiseSum(_, _, _) => ReturnType::Int,
            Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
            Expression::Defined(_, function) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(domain, _) => *domain,
                    _ => bug!(
                        "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Range(_, function) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Image(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::ImageSet(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::PreImage(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(domain, _) => *domain,
                    _ => bug!(
                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Restrict(_, function, new_domain) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => {
                        ReturnType::Function(Box::new(new_domain.return_type()), codomain)
                    }
                    _ => bug!(
                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Inverse(..) => ReturnType::Bool,
            Expression::LexLt(..) => ReturnType::Bool,
            Expression::LexGt(..) => ReturnType::Bool,
1360
            Expression::LexLeq(..) => ReturnType::Bool,
            Expression::LexGeq(..) => ReturnType::Bool,
            Expression::FlatLexLt(..) => ReturnType::Bool,
            Expression::FlatLexLeq(..) => ReturnType::Bool,
        }
2214426
    }
}
impl Expression {
    /// Invokes `visit` on each `Expression` stored directly inside `self`.
    ///
    /// Children are passed by reference (nothing is cloned), in field order from
    /// left to right. The walk is only one level deep; grandchildren are reached
    /// only if `visit` itself recurses. The following are treated as leaves:
    /// comprehensions, atoms, metavariables, `FromSolution`, and the
    /// flattened/solver-level constraint forms.
    fn for_each_expr_child(&self, visit: &mut impl FnMut(&Expression)) {
        match self {
            // Leaves: no `Expression`-typed operands are visited.
            Expression::Comprehension(..)
            | Expression::AbstractComprehension(..)
            | Expression::Atomic(..)
            | Expression::FromSolution(..)
            | Expression::Metavar(..)
            | Expression::FlatAbsEq(..)
            | Expression::FlatMinusEq(..)
            | Expression::FlatProductEq(..)
            | Expression::MinionDivEqUndefZero(..)
            | Expression::MinionModuloEqUndefZero(..)
            | Expression::MinionPow(..)
            | Expression::FlatAllDiff(..)
            | Expression::FlatSumGeq(..)
            | Expression::FlatSumLeq(..)
            | Expression::FlatIneq(..)
            | Expression::FlatWatchedLiteral(..)
            | Expression::FlatWeightedSumLeq(..)
            | Expression::FlatWeightedSumGeq(..)
            | Expression::MinionWInIntervalSet(..)
            | Expression::MinionWInSet(..)
            | Expression::MinionElementOne(..)
            | Expression::FlatLexLt(..)
            | Expression::FlatLexLeq(..) => {}

            // Container literals: every element is a child. For function literals,
            // both halves of each mapping are children.
            Expression::AbstractLiteral(_, literal) => match literal {
                AbstractLiteral::Set(items)
                | AbstractLiteral::MSet(items)
                | AbstractLiteral::Tuple(items) => {
                    for item in items {
                        visit(item);
                    }
                }
                AbstractLiteral::Matrix(items, _) => {
                    for item in items {
                        visit(item);
                    }
                }
                AbstractLiteral::Record(fields) => {
                    for field in fields {
                        visit(&field.value);
                    }
                }
                AbstractLiteral::Function(mappings) => {
                    for (from, to) in mappings {
                        visit(from);
                        visit(to);
                    }
                }
            },

            // The top-level list of constraints.
            Expression::Root(_, constraints) => {
                for constraint in constraints {
                    visit(constraint);
                }
            }

            // Exactly one expression operand.
            Expression::DominanceRelation(_, operand)
            | Expression::ToInt(_, operand)
            | Expression::Abs(_, operand)
            | Expression::Sum(_, operand)
            | Expression::Product(_, operand)
            | Expression::Min(_, operand)
            | Expression::Max(_, operand)
            | Expression::Not(_, operand)
            | Expression::Or(_, operand)
            | Expression::And(_, operand)
            | Expression::Neg(_, operand)
            | Expression::Defined(_, operand)
            | Expression::AllDiff(_, operand)
            | Expression::Factorial(_, operand)
            | Expression::Range(_, operand) => visit(operand),

            // One expression operand alongside non-expression data.
            Expression::InDomain(_, operand, _) => visit(operand),
            Expression::MinionReify(_, operand, _) => visit(operand),
            Expression::MinionReifyImply(_, operand, _) => visit(operand),
            Expression::AuxDeclaration(_, _, operand) => visit(operand),
            Expression::SATInt(_, _, operand, _) => visit(operand),

            // Two expression operands, visited left then right.
            Expression::Table(_, lhs, rhs)
            | Expression::NegativeTable(_, lhs, rhs)
            | Expression::Bubble(_, lhs, rhs)
            | Expression::Imply(_, lhs, rhs)
            | Expression::Iff(_, lhs, rhs)
            | Expression::Union(_, lhs, rhs)
            | Expression::In(_, lhs, rhs)
            | Expression::Intersect(_, lhs, rhs)
            | Expression::Supset(_, lhs, rhs)
            | Expression::SupsetEq(_, lhs, rhs)
            | Expression::Subset(_, lhs, rhs)
            | Expression::SubsetEq(_, lhs, rhs)
            | Expression::Eq(_, lhs, rhs)
            | Expression::Neq(_, lhs, rhs)
            | Expression::Geq(_, lhs, rhs)
            | Expression::Leq(_, lhs, rhs)
            | Expression::Gt(_, lhs, rhs)
            | Expression::Lt(_, lhs, rhs)
            | Expression::SafeDiv(_, lhs, rhs)
            | Expression::UnsafeDiv(_, lhs, rhs)
            | Expression::SafeMod(_, lhs, rhs)
            | Expression::UnsafeMod(_, lhs, rhs)
            | Expression::UnsafePow(_, lhs, rhs)
            | Expression::SafePow(_, lhs, rhs)
            | Expression::Minus(_, lhs, rhs)
            | Expression::PairwiseSum(_, lhs, rhs)
            | Expression::PairwiseProduct(_, lhs, rhs)
            | Expression::Image(_, lhs, rhs)
            | Expression::ImageSet(_, lhs, rhs)
            | Expression::PreImage(_, lhs, rhs)
            | Expression::Inverse(_, lhs, rhs)
            | Expression::Restrict(_, lhs, rhs)
            | Expression::LexLt(_, lhs, rhs)
            | Expression::LexLeq(_, lhs, rhs)
            | Expression::LexGt(_, lhs, rhs)
            | Expression::LexGeq(_, lhs, rhs) => {
                visit(lhs);
                visit(rhs);
            }

            // Indexing: the subject first, then every index.
            Expression::UnsafeIndex(_, subject, indices)
            | Expression::SafeIndex(_, subject, indices) => {
                visit(subject);
                for index in indices {
                    visit(index);
                }
            }

            // Slicing: the subject first, then every concrete index. A `None` entry
            // is an open (`..`) position and carries no expression.
            Expression::UnsafeSlice(_, subject, indices)
            | Expression::SafeSlice(_, subject, indices) => {
                visit(subject);
                for index in indices.iter().filter_map(Option::as_ref) {
                    visit(index);
                }
            }

            // Optional depth (when present), then the matrix being flattened.
            Expression::Flatten(_, depth, matrix) => {
                if let Some(depth) = depth {
                    visit(depth);
                }
                visit(matrix);
            }
        }
    }
}
/// Cached structural hashing for [`Expression`].
///
/// Each node keeps its hash in the `stored_hash` slot of its metadata, where `NO_HASH` means
/// "not computed yet". Child expressions contribute their *cached* hash rather than being
/// re-hashed in full, so the hash of a tree is built bottom-up and reused until invalidated.
impl CacheHashable for Expression {
    /// Clears this node's cached hash.
    ///
    /// Descendants keep their cached hashes; use `invalidate_cache_recursive` to clear them too.
    fn invalidate_cache(&self) {
        self.meta_ref()
            .stored_hash
            .store(NO_HASH, Ordering::Relaxed);
    }

    /// Clears the cached hash of this node and of every descendant expression visited by
    /// `for_each_expr_child`.
    fn invalidate_cache_recursive(&self) {
        self.invalidate_cache();
        self.for_each_expr_child(&mut |child| {
            child.invalidate_cache_recursive();
        });
    }

    /// Returns the cached hash if one is stored; otherwise computes it with `calculate_hash`
    /// (which also caches it).
    ///
    /// Increments the global `HASH_HITS` / `HASH_MISSES` counters reported by
    /// `print_hash_stats`. A computed hash that happens to equal `NO_HASH` is indistinguishable
    /// from "not cached" and is simply recomputed on each call.
    fn get_cached_hash(&self) -> u64 {
        let stored = self.meta_ref().stored_hash.load(Ordering::Relaxed);
        if stored != NO_HASH {
            HASH_HITS.fetch_add(1, Ordering::Relaxed);
            return stored;
        }
        HASH_MISSES.fetch_add(1, Ordering::Relaxed);
        self.calculate_hash()
    }

    /// Computes this node's structural hash from scratch and stores it in the cache.
    ///
    /// The variant discriminant is hashed first. Expression children contribute their cached
    /// hash (via `get_cached_hash`); all other fields use their own `Hash` impl.
    ///
    /// Every variable-length sequence is prefixed with its length, and every optional child
    /// with its presence. Without this, different values could feed identical byte streams to
    /// the hasher and always collide, e.g. `FlatLexLt([a, b], [c])` vs `FlatLexLt([a], [b, c])`.
    fn calculate_hash(&self) -> u64 {
        // Feeds a length-prefixed sequence of child hashes into `hasher`.
        fn hash_children(hashes: impl ExactSizeIterator<Item = u64>, hasher: &mut DefaultHasher) {
            hashes.len().hash(hasher);
            for h in hashes {
                h.hash(hasher);
            }
        }

        let mut hasher = DefaultHasher::new();
        std::mem::discriminant(self).hash(&mut hasher);
        match self {
            // Special case: abstract literals hold expression children inside another enum.
            Expression::AbstractLiteral(_, alit) => {
                // Distinguish e.g. a set from a tuple holding the same elements.
                std::mem::discriminant(alit).hash(&mut hasher);
                match alit {
                    AbstractLiteral::Set(v)
                    | AbstractLiteral::MSet(v)
                    | AbstractLiteral::Tuple(v) => {
                        hash_children(v.iter().map(|e| e.get_cached_hash()), &mut hasher);
                    }
                    AbstractLiteral::Matrix(v, domain) => {
                        domain.hash(&mut hasher);
                        hash_children(v.iter().map(|e| e.get_cached_hash()), &mut hasher);
                    }
                    AbstractLiteral::Record(rs) => {
                        rs.len().hash(&mut hasher);
                        for r in rs {
                            r.name.hash(&mut hasher);
                            r.value.get_cached_hash().hash(&mut hasher);
                        }
                    }
                    AbstractLiteral::Function(vs) => {
                        vs.len().hash(&mut hasher);
                        for (a, b) in vs {
                            a.get_cached_hash().hash(&mut hasher);
                            b.get_cached_hash().hash(&mut hasher);
                        }
                    }
                }
            }
            Expression::Root(_, vs) => {
                hash_children(vs.iter().map(|e| e.get_cached_hash()), &mut hasher);
            }
            // Moo<Expression>
            Expression::DominanceRelation(_, m1)
            | Expression::ToInt(_, m1)
            | Expression::Abs(_, m1)
            | Expression::Sum(_, m1)
            | Expression::Product(_, m1)
            | Expression::Min(_, m1)
            | Expression::Max(_, m1)
            | Expression::Not(_, m1)
            | Expression::Or(_, m1)
            | Expression::And(_, m1)
            | Expression::Neg(_, m1)
            | Expression::Defined(_, m1)
            | Expression::AllDiff(_, m1)
            | Expression::Factorial(_, m1)
            | Expression::Range(_, m1) => {
                m1.get_cached_hash().hash(&mut hasher);
            }
            // Moo<Expression> + Moo<Expression>
            Expression::Table(_, m1, m2)
            | Expression::NegativeTable(_, m1, m2)
            | Expression::Bubble(_, m1, m2)
            | Expression::Imply(_, m1, m2)
            | Expression::Iff(_, m1, m2)
            | Expression::Union(_, m1, m2)
            | Expression::In(_, m1, m2)
            | Expression::Intersect(_, m1, m2)
            | Expression::Supset(_, m1, m2)
            | Expression::SupsetEq(_, m1, m2)
            | Expression::Subset(_, m1, m2)
            | Expression::SubsetEq(_, m1, m2)
            | Expression::Eq(_, m1, m2)
            | Expression::Neq(_, m1, m2)
            | Expression::Geq(_, m1, m2)
            | Expression::Leq(_, m1, m2)
            | Expression::Gt(_, m1, m2)
            | Expression::Lt(_, m1, m2)
            | Expression::SafeDiv(_, m1, m2)
            | Expression::UnsafeDiv(_, m1, m2)
            | Expression::SafeMod(_, m1, m2)
            | Expression::UnsafeMod(_, m1, m2)
            | Expression::UnsafePow(_, m1, m2)
            | Expression::SafePow(_, m1, m2)
            | Expression::Minus(_, m1, m2)
            | Expression::PairwiseSum(_, m1, m2)
            | Expression::PairwiseProduct(_, m1, m2)
            | Expression::Image(_, m1, m2)
            | Expression::ImageSet(_, m1, m2)
            | Expression::PreImage(_, m1, m2)
            | Expression::Inverse(_, m1, m2)
            | Expression::Restrict(_, m1, m2)
            | Expression::LexLt(_, m1, m2)
            | Expression::LexLeq(_, m1, m2)
            | Expression::LexGt(_, m1, m2)
            | Expression::LexGeq(_, m1, m2) => {
                m1.get_cached_hash().hash(&mut hasher);
                m2.get_cached_hash().hash(&mut hasher);
            }
            // Moo<Expression> + Vec<Expression>
            Expression::UnsafeIndex(_, m, vs) | Expression::SafeIndex(_, m, vs) => {
                m.get_cached_hash().hash(&mut hasher);
                hash_children(vs.iter().map(|e| e.get_cached_hash()), &mut hasher);
            }
            // Moo<Expression> + Vec<Option<Expression>>
            Expression::UnsafeSlice(_, m, vs) | Expression::SafeSlice(_, m, vs) => {
                m.get_cached_hash().hash(&mut hasher);
                vs.len().hash(&mut hasher);
                for v in vs {
                    // `Option<u64>` hashes its own discriminant, so `None` cannot collide
                    // with a child whose hash happens to be 0.
                    v.as_ref().map(|e| e.get_cached_hash()).hash(&mut hasher);
                }
            }
            // Moo<Expression> + DomainPtr
            Expression::InDomain(_, m, d) => {
                m.get_cached_hash().hash(&mut hasher);
                d.hash(&mut hasher);
            }
            // Option<Moo<Expression>> + Moo<Expression>
            Expression::Flatten(_, opt, m) => {
                // Hash presence as well as value so `None` contributes to the stream.
                opt.as_ref().map(|e| e.get_cached_hash()).hash(&mut hasher);
                m.get_cached_hash().hash(&mut hasher);
            }
            // Moo<Expression> + Atom
            Expression::MinionReify(_, m, a) | Expression::MinionReifyImply(_, m, a) => {
                m.get_cached_hash().hash(&mut hasher);
                a.hash(&mut hasher);
            }
            // Reference + Moo<Expression>
            Expression::AuxDeclaration(_, r, m) => {
                r.hash(&mut hasher);
                m.get_cached_hash().hash(&mut hasher);
            }
            // SATIntEncoding + Moo<Expression> + (i32, i32)
            Expression::SATInt(_, enc, m, bounds) => {
                enc.hash(&mut hasher);
                m.get_cached_hash().hash(&mut hasher);
                bounds.hash(&mut hasher);
            }
            // Non-Expression Moo types - hash normally
            Expression::Comprehension(_, c) => c.hash(&mut hasher),
            Expression::AbstractComprehension(_, c) => c.hash(&mut hasher),
            // Leaf types - no Expression children
            Expression::Atomic(_, a) => a.hash(&mut hasher),
            Expression::FromSolution(_, a) => a.hash(&mut hasher),
            Expression::Metavar(_, u) => u.hash(&mut hasher),
            // Two Moo<Atom>
            Expression::FlatAbsEq(_, a1, a2) | Expression::FlatMinusEq(_, a1, a2) => {
                a1.hash(&mut hasher);
                a2.hash(&mut hasher);
            }
            // Three Moo<Atom>
            Expression::FlatProductEq(_, a1, a2, a3)
            | Expression::MinionDivEqUndefZero(_, a1, a2, a3)
            | Expression::MinionModuloEqUndefZero(_, a1, a2, a3)
            | Expression::MinionPow(_, a1, a2, a3) => {
                a1.hash(&mut hasher);
                a2.hash(&mut hasher);
                a3.hash(&mut hasher);
            }
            // Vec<Atom>; `Vec: Hash` writes a length prefix before the elements.
            Expression::FlatAllDiff(_, vs) => vs.hash(&mut hasher),
            // Vec<Atom> + Atom
            Expression::FlatSumGeq(_, vs, a) | Expression::FlatSumLeq(_, vs, a) => {
                vs.hash(&mut hasher);
                a.hash(&mut hasher);
            }
            // Moo<Atom> + Moo<Atom> + Box<Literal>
            Expression::FlatIneq(_, a1, a2, lit) => {
                a1.hash(&mut hasher);
                a2.hash(&mut hasher);
                lit.hash(&mut hasher);
            }
            // Reference + Literal
            Expression::FlatWatchedLiteral(_, r, l) => {
                r.hash(&mut hasher);
                l.hash(&mut hasher);
            }
            // Vec<Literal> + Vec<Atom> + Moo<Atom>
            Expression::FlatWeightedSumLeq(_, lits, atoms, a)
            | Expression::FlatWeightedSumGeq(_, lits, atoms, a) => {
                lits.hash(&mut hasher);
                atoms.hash(&mut hasher);
                a.hash(&mut hasher);
            }
            // Atom + Vec<i32>
            Expression::MinionWInIntervalSet(_, a, vs) | Expression::MinionWInSet(_, a, vs) => {
                a.hash(&mut hasher);
                vs.hash(&mut hasher);
            }
            // Vec<Atom> + Moo<Atom> + Moo<Atom>
            Expression::MinionElementOne(_, vs, a1, a2) => {
                vs.hash(&mut hasher);
                a1.hash(&mut hasher);
                a2.hash(&mut hasher);
            }
            // Vec<Atom> + Vec<Atom>
            Expression::FlatLexLt(_, v1, v2) | Expression::FlatLexLeq(_, v1, v2) => {
                v1.hash(&mut hasher);
                v2.hash(&mut hasher);
            }
        }
        let result = hasher.finish();
        self.meta_ref().stored_hash.store(result, Ordering::Relaxed);
        result
    }
}
#[cfg(test)]
mod tests {
    use crate::matrix_expr;
    use super::*;

    /// A sum of integer constants has the single-value domain of their total.
    #[test]
    fn test_domain_of_constant_sum() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(2)));
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![c1, c2]));
        assert_eq!(sum.domain_of(), Some(Domain::int(vec![Range::Single(3)])));
    }

    /// Summing an integer with a boolean yields no domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(1)));
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![c1, c2]));
        assert_eq!(sum.domain_of(), None);
    }

    /// An empty sum yields no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));
        assert_eq!(sum.domain_of(), None);
    }
}