1
use std::collections::{HashSet, VecDeque};
2
use std::fmt::{Display, Formatter};
3
use tracing::trace;
4

            
5
use crate::ast::Name;
6
use crate::ast::ReturnType;
7
use crate::ast::SetAttr;
8
use crate::ast::literals::AbstractLiteral;
9
use crate::ast::literals::Literal;
10
use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
11
use crate::ast::{Atom, DomainPtr};
12
use crate::ast::{GroundDomain, Metadata, UnresolvedDomain};
13
use crate::ast::{IntVal, Moo};
14
use crate::bug;
15
use conjure_cp_enum_compatibility_macro::document_compatibility;
16
use itertools::Itertools;
17
use serde::{Deserialize, Serialize};
18
use ustr::Ustr;
19

            
20
use polyquine::Quine;
21
use uniplate::{Biplate, Uniplate};
22

            
23
use super::ac_operators::ACOperatorKind;
24
use super::categories::{Category, CategoryOf};
25
use super::comprehension::Comprehension;
26
use super::domains::HasDomain as _;
27
use super::records::RecordValue;
28
use super::{DeclarationPtr, Domain, Range, Reference, SubModel, Typeable};
29

            
30
// Ensure that this type doesn't get too big
31
//
32
// If you triggered this assertion, you either made a variant of this enum that is too big, or you
33
// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
34
// stuff in boxes.
35
//
36
// Enums take the size of their largest variant, so an enum with mostly small variants and a few
37
// large ones wastes memory... A larger Expression type also slows down Oxide.
38
//
39
// For more information, and more details on type sizes and how to measure them, see the commit
40
// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
41
//
42
// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
43
//
44
// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e
45

            
46
// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
47
// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
48
// lot bigger still when we start using it for memoisation, so it should really be
49
// boxed ~niklasdewally
50

            
51
// Expression must be exactly 104 bytes: `assert_eq_size!` fails to compile if its size changes.
52
static_assertions::assert_eq_size!([u8; 104], Expression);
53

            
54
/// Represents different types of expressions used to define rules and constraints in the model.
55
///
56
/// The `Expression` enum includes operations, constants, and variable references
57
/// used to build rules and conditions for the model.
///
/// The first field of every variant is the expression's [`Metadata`].
58
#[document_compatibility]
59
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
60
#[biplate(to=Metadata)]
61
#[biplate(to=Atom)]
62
#[biplate(to=DeclarationPtr)]
63
#[biplate(to=Name)]
64
#[biplate(to=Reference)]
65
#[biplate(to=Vec<Expression>)]
66
#[biplate(to=Option<Expression>)]
67
#[biplate(to=SubModel)]
68
#[biplate(to=Comprehension)]
69
#[biplate(to=AbstractLiteral<Expression>)]
70
#[biplate(to=AbstractLiteral<Literal>)]
71
#[biplate(to=RecordValue<Expression>)]
72
#[biplate(to=RecordValue<Literal>)]
73
#[biplate(to=Literal)]
74
#[biplate(to=DomainPtr)]
75
#[path_prefix(conjure_cp::ast)]
76
pub enum Expression {
77
    /// An abstract (compound) literal whose elements are themselves expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
78
    /// The root of the model, holding its top-level expressions.
79
    Root(Metadata, Vec<Expression>),
80

            
81
    /// An expression representing "A is valid as long as B is true"
82
    /// Turns into a conjunction when it reaches a boolean context
83
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),
84

            
85
    /// A comprehension.
86
    ///
87
    /// The inside of the comprehension opens a new scope.
88
    // todo (gskorokhod): Comprehension contains a SubModel which contains a bunch of Rc pointers.
89
    // This makes implementing Quine tricky (it doesn't support Rc, by design). Skip it for now.
90
    #[polyquine_skip]
91
    Comprehension(Metadata, Moo<Comprehension>),
92

            
93
    /// Defines dominance ("Solution A is preferred over Solution B")
94
    DominanceRelation(Metadata, Moo<Expression>),
95
    /// `fromSolution(name)` - Used in dominance relation definitions
96
    FromSolution(Metadata, Moo<Atom>),
97

            
98
    /// A metavariable, identified by name.
    ///
    /// When this expression is quoted via `Quine`, a `Metavar` named `x` is emitted as
    /// `x.clone().into()`, i.e. it refers to a Rust variable called `x` in the quoting context.
    #[polyquine_with(arm = (_, name) => {
99
        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
100
        quote::quote! { #ident.clone().into() }
101
    })]
102
    Metavar(Metadata, Ustr),
103

            
104
    /// An expression consisting of a single [`Atom`].
    Atomic(Metadata, Atom),
105

            
106
    /// A matrix index.
107
    ///
108
    /// Defined iff the indices are within their respective index domains.
109
    #[compatible(JsonInput)]
110
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
111

            
112
    /// A safe matrix index.
113
    ///
114
    /// See [`Expression::UnsafeIndex`]
115
    #[compatible(SMT)]
116
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),
117

            
118
    /// A matrix slice: `a[indices]`.
119
    ///
120
    /// One of the indices may be `None`, representing the dimension of the matrix we want to take
121
    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
122
    /// `Some(1),None,Some(2)`.
123
    ///
124
    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
125
    ///
126
    /// Defined iff the defined indices are within their respective index domains.
127
    #[compatible(JsonInput)]
128
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
129

            
130
    /// A safe matrix slice: `a[indices]`.
131
    ///
132
    /// See [`Expression::UnsafeSlice`].
133
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),
134

            
135
    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
136
    ///
137
    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
138
    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
139
    /// `SafeSlice` respectively.
140
    InDomain(Metadata, Moo<Expression>, DomainPtr),
141

            
142
    /// `toInt(b)` casts boolean expression b to an integer.
143
    ///
144
    /// - If b is false, then `toInt(b) == 0`
145
    ///
146
    /// - If b is true, then `toInt(b) == 1`
147
    #[compatible(SMT)]
148
    ToInt(Metadata, Moo<Expression>),
149

            
150
    /// A nested scope, given as a [`SubModel`].
    // todo (gskorokhod): Same reason as for Comprehension
151
    #[polyquine_skip]
152
    Scope(Metadata, Moo<SubModel>),
153

            
154
    /// `|x|` - absolute value of `x`
155
    #[compatible(JsonInput, SMT)]
156
    Abs(Metadata, Moo<Expression>),
157

            
158
    /// `sum(<vec_expr>)`
159
    #[compatible(JsonInput, SMT)]
160
    Sum(Metadata, Moo<Expression>),
161

            
162
    /// `a * b * c * ...`
163
    #[compatible(JsonInput, SMT)]
164
    Product(Metadata, Moo<Expression>),
165

            
166
    /// `min(<vec_expr>)`
167
    #[compatible(JsonInput, SMT)]
168
    Min(Metadata, Moo<Expression>),
169

            
170
    /// `max(<vec_expr>)`
171
    #[compatible(JsonInput, SMT)]
172
    Max(Metadata, Moo<Expression>),
173

            
174
    /// `not(a)`
175
    #[compatible(JsonInput, SAT, SMT)]
176
    Not(Metadata, Moo<Expression>),
177

            
178
    /// `or(<vec_expr>)`
179
    #[compatible(JsonInput, SAT, SMT)]
180
    Or(Metadata, Moo<Expression>),
181

            
182
    /// `and(<vec_expr>)`
183
    #[compatible(JsonInput, SAT, SMT)]
184
    And(Metadata, Moo<Expression>),
185

            
186
    /// Ensures that `a->b` (material implication).
187
    #[compatible(JsonInput, SMT)]
188
    Imply(Metadata, Moo<Expression>, Moo<Expression>),
189

            
190
    /// `iff(a, b)`: `a <-> b`
191
    #[compatible(JsonInput, SMT)]
192
    Iff(Metadata, Moo<Expression>, Moo<Expression>),
193

            
194
    /// `a union b`
    #[compatible(JsonInput)]
195
    Union(Metadata, Moo<Expression>, Moo<Expression>),
196

            
197
    /// `a in b` (boolean-valued)
    #[compatible(JsonInput)]
198
    In(Metadata, Moo<Expression>, Moo<Expression>),
199

            
200
    /// `a intersect b`
    #[compatible(JsonInput)]
201
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),
202

            
203
    /// `a supset b` (boolean-valued)
    #[compatible(JsonInput)]
204
    Supset(Metadata, Moo<Expression>, Moo<Expression>),
205

            
206
    /// `a supsetEq b` (boolean-valued)
    #[compatible(JsonInput)]
207
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),
208

            
209
    /// `a subset b` (boolean-valued)
    #[compatible(JsonInput)]
210
    Subset(Metadata, Moo<Expression>, Moo<Expression>),
211

            
212
    /// `a subsetEq b` (boolean-valued)
    #[compatible(JsonInput)]
213
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),
214

            
215
    /// `a = b`
    #[compatible(JsonInput, SMT)]
216
    Eq(Metadata, Moo<Expression>, Moo<Expression>),
217

            
218
    /// `a != b`
    #[compatible(JsonInput, SMT)]
219
    Neq(Metadata, Moo<Expression>, Moo<Expression>),
220

            
221
    /// `a >= b`
    #[compatible(JsonInput, SMT)]
222
    Geq(Metadata, Moo<Expression>, Moo<Expression>),
223

            
224
    /// `a <= b`
    #[compatible(JsonInput, SMT)]
225
    Leq(Metadata, Moo<Expression>, Moo<Expression>),
226

            
227
    /// `a > b`
    #[compatible(JsonInput, SMT)]
228
    Gt(Metadata, Moo<Expression>, Moo<Expression>),
229

            
230
    /// `a < b`
    #[compatible(JsonInput, SMT)]
231
    Lt(Metadata, Moo<Expression>, Moo<Expression>),
232

            
233
    /// Division after preventing division by zero, usually with a bubble
234
    #[compatible(SMT)]
235
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
236

            
237
    /// Division with a possibly undefined value (division by 0)
238
    #[compatible(JsonInput)]
239
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),
240

            
241
    /// Modulo after preventing mod 0, usually with a bubble
242
    #[compatible(SMT)]
243
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),
244

            
245
    /// Modulo with a possibly undefined value (mod 0)
246
    #[compatible(JsonInput)]
247
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),
248

            
249
    /// Negation: `-x`
250
    #[compatible(JsonInput, SMT)]
251
    Neg(Metadata, Moo<Expression>),
252

            
253
    /// `defined(f)`: the set of domain values for which the function `f` is defined.
254
    #[compatible(JsonInput)]
255
    Defined(Metadata, Moo<Expression>),
256

            
257
    /// `range(f)`: the set of codomain values that the function `f` maps to.
258
    #[compatible(JsonInput)]
259
    Range(Metadata, Moo<Expression>),
260

            
261
    /// Unsafe power `x**y` (possibly undefined)
262
    ///
263
    /// Defined when `(x != 0 \/ y != 0) /\ y >= 0`
264
    #[compatible(JsonInput)]
265
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),
266

            
267
    /// `UnsafePow` after preventing undefinedness
268
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),
269

            
270
    /// Flatten matrix operator
271
    /// `flatten(M)` or `flatten(n, M)`
272
    /// where M is a matrix and n is an optional integer argument indicating depth of flattening
273
    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),
274

            
275
    /// `allDiff(<vec_expr>)`
276
    #[compatible(JsonInput)]
277
    AllDiff(Metadata, Moo<Expression>),
278

            
279
    /// Binary subtraction operator
280
    ///
281
    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
282
    /// TODO: make this compatible with set difference - this needs a new return type and domain for this expression, and a set comprehension rule.
283
    /// `minus_to_sum` has already been changed so that it does not apply to sets.
284
    #[compatible(JsonInput)]
285
    Minus(Metadata, Moo<Expression>, Moo<Expression>),
286

            
287
    /// Ensures that x=|y| i.e. x is the absolute value of y.
288
    ///
289
    /// Low-level Minion constraint.
290
    ///
291
    /// # See also
292
    ///
293
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
294
    #[compatible(Minion)]
295
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),
296

            
297
    /// Ensures that `alldiff([a,b,...])`.
298
    ///
299
    /// Low-level Minion constraint.
300
    ///
301
    /// # See also
302
    ///
303
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
304
    #[compatible(Minion)]
305
    FlatAllDiff(Metadata, Vec<Atom>),
306

            
307
    /// Ensures that sum(vec) >= x.
308
    ///
309
    /// Low-level Minion constraint.
310
    ///
311
    /// # See also
312
    ///
313
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
314
    #[compatible(Minion)]
315
    FlatSumGeq(Metadata, Vec<Atom>, Atom),
316

            
317
    /// Ensures that sum(vec) <= x.
318
    ///
319
    /// Low-level Minion constraint.
320
    ///
321
    /// # See also
322
    ///
323
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
324
    #[compatible(Minion)]
325
    FlatSumLeq(Metadata, Vec<Atom>, Atom),
326

            
327
    /// `ineq(x,y,k)` ensures that x <= y + k.
328
    ///
329
    /// Low-level Minion constraint.
330
    ///
331
    /// # See also
332
    ///
333
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
334
    #[compatible(Minion)]
335
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),
336

            
337
    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
338
    ///
339
    /// Low-level Minion constraint.
340
    ///
341
    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
342
    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
343
    /// watched-and and watched-or.
344
    ///
345
    /// # See also
346
    ///
347
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-literal)
348
    /// + `rules::minion::boolean_literal_to_wliteral`.
349
    #[compatible(Minion)]
350
    #[polyquine_skip]
351
    FlatWatchedLiteral(Metadata, Reference, Literal),
352

            
353
    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
354
    /// product of cs and xs.
355
    ///
356
    /// Low-level Minion constraint.
357
    ///
358
    /// Represents a weighted sum of the form `ax + by + cz + ...`
359
    ///
360
    /// # See also
361
    ///
362
    /// + [Minion
363
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
364
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
365

            
366
    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
367
    /// product of cs and xs.
368
    ///
369
    /// Low-level Minion constraint.
370
    ///
371
    /// Represents a weighted sum of the form `ax + by + cz + ...`
372
    ///
373
    /// # See also
374
    ///
375
    /// + [Minion
376
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumgeq)
377
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),
378

            
379
    /// Ensures that x = -y, where x and y are atoms.
380
    ///
381
    /// Low-level Minion constraint.
382
    ///
383
    /// # See also
384
    ///
385
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
386
    #[compatible(Minion)]
387
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),
388

            
389
    /// Ensures that x*y=z.
390
    ///
391
    /// Low-level Minion constraint.
392
    ///
393
    /// # See also
394
    ///
395
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
396
    #[compatible(Minion)]
397
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
398

            
399
    /// Ensures that floor(x/y)=z. Always true when y=0.
400
    ///
401
    /// Low-level Minion constraint.
402
    ///
403
    /// # See also
404
    ///
405
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
406
    #[compatible(Minion)]
407
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
408

            
409
    /// Ensures that x%y=z. Always true when y=0.
410
    ///
411
    /// Low-level Minion constraint.
412
    ///
413
    /// # See also
414
    ///
415
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
416
    #[compatible(Minion)]
417
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
418

            
419
    /// Ensures that `x**y = z`.
420
    ///
421
    /// Low-level Minion constraint.
422
    ///
423
    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is -1 if y
424
    /// is odd and z is 1 if y is even).
425
    ///
426
    /// # See also
427
    ///
428
    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
429
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
430
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),
431

            
432
    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
433
    /// variable.
434
    ///
435
    /// Low-level Minion constraint.
436
    ///
437
    /// # See also
438
    ///
439
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
440
    #[compatible(Minion)]
441
    MinionReify(Metadata, Moo<Expression>, Atom),
442

            
443
    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1
444
    /// variable.
445
    ///
446
    /// Low-level Minion constraint.
447
    ///
448
    /// # See also
449
    ///
450
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
451
    #[compatible(Minion)]
452
    MinionReifyImply(Metadata, Moo<Expression>, Atom),
453

            
454
    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
455
    /// intervals {a1,…,a2}, {b1,…,b2} etc.
456
    ///
457
    /// The list of intervals must be given in numerical order.
458
    ///
459
    /// Low-level Minion constraint.
460
    ///
461
    /// # See also
462
    ///
463
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
464
    #[compatible(Minion)]
465
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),
466

            
467
    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
468
    ///
469
    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
470
    ///
471
    /// The list of values must be given in numerical order.
472
    ///
473
    /// Low-level Minion constraint.
474
    ///
475
    /// # See also
476
    ///
477
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
478
    #[compatible(Minion)]
479
    MinionWInSet(Metadata, Atom, Vec<i32>),
480

            
481
    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
482
    /// in the range `[1..len(vec)]`.
483
    ///
484
    /// Low-level Minion constraint.
485
    ///
486
    /// # See also
487
    ///
488
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
489
    #[compatible(Minion)]
490
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),
491

            
492
    /// Declaration of an auxiliary variable.
493
    ///
494
    /// As with Savile Row, we semantically distinguish this from `Eq`.
495
    #[compatible(Minion)]
496
    #[polyquine_skip]
497
    AuxDeclaration(Metadata, Reference, Moo<Expression>),
498

            
499
    /// Encodes an i32 integer as a vector of boolean expressions for CNF, using two's complement.
500
    #[compatible(SAT)]
501
    SATInt(Metadata, Moo<Expression>),
502

            
503
    /// Addition over a pair of expressions (i.e. a + b) rather than a vec-expr like Expression::Sum.
504
    /// This is for compatibility with backends that do not support addition over vectors.
505
    #[compatible(SMT)]
506
    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),
507

            
508
    /// Multiplication over a pair of expressions (i.e. a * b) rather than a vec-expr like Expression::Product.
509
    /// This is for compatibility with backends that do not support multiplication over vectors.
510
    #[compatible(SMT)]
511
    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),
512

            
513
    #[compatible(JsonInput)]
514
    Image(Metadata, Moo<Expression>, Moo<Expression>),
515

            
516
    #[compatible(JsonInput)]
517
    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),
518

            
519
    #[compatible(JsonInput)]
520
    PreImage(Metadata, Moo<Expression>, Moo<Expression>),
521

            
522
    #[compatible(JsonInput)]
523
    Inverse(Metadata, Moo<Expression>, Moo<Expression>),
524

            
525
    #[compatible(JsonInput)]
526
    Restrict(Metadata, Moo<Expression>, Moo<Expression>),
527

            
528
    /// Lexicographical < between two matrices.
529
    ///
530
    /// A <lex B iff: A[i] < B[i] for some i /\ (A[j] > B[j] for some j -> i < j)
531
    /// I.e. A must be less than B at some index i, and if it is greater than B at another index j,
532
    /// then j comes after i.
533
    /// I.e. A must be less than B at the first index where they differ.
534
    ///
535
    /// E.g. [1, 1] <lex [2, 1] and [1, 1] <lex [1, 2]
536
    LexLt(Metadata, Moo<Expression>, Moo<Expression>),
537

            
538
    /// Lexicographical <= between two matrices
539
    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),
540

            
541
    /// Lexicographical > between two matrices
542
    /// This is a parser-level construct, and is immediately normalised to LexLt(b, a)
543
    LexGt(Metadata, Moo<Expression>, Moo<Expression>),
544

            
545
    /// Lexicographical >= between two matrices
546
    /// This is a parser-level construct, and is immediately normalised to LexLeq(b, a)
547
    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),
548

            
549
    /// Low-level Minion constraint. See [`Expression::LexLt`].
550
    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),
551

            
552
    /// Low-level Minion constraint. See [`Expression::LexLeq`].
553
    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
554
}
555

            
556
// for the given matrix literal, return a bounded domain from the min to max of applying op to each
557
// child expression.
558
//
559
// Op must be monotonic.
560
//
561
// Returns none if unbounded
562
fn bounded_i32_domain_for_matrix_literal_monotonic(
563
    e: &Expression,
564
    op: fn(i32, i32) -> Option<i32>,
565
) -> Option<DomainPtr> {
566
    // only care about the elements, not the indices
567
    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
568

            
569
    // fold each element's domain into one using op.
570
    //
571
    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
572
    // the ranges [a1,a2], [b1,b2] will be
573
    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
574
    //
575
    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
576
    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
577
    // memory (on the hakank_eprime_xkcd test)...
578
    //Int
579
    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
580
    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
581
    //
582
    // +,-,/,* are all monotone, so this assumption should be fine for now...
583

            
584
    let expr = exprs.pop()?;
585
    let dom = expr.domain_of()?;
586
    let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
587
        return None;
588
    };
589

            
590
    let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
591

            
592
    for expr in exprs {
593
        let dom = expr.domain_of()?;
594
        let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
595
            return None;
596
        };
597

            
598
        let (min, max) = range_vec_bounds_i32(ranges)?;
599

            
600
        // all the possible new values for current_min / current_max
601
        let minmax = op(min, current_max)?;
602
        let minmin = op(min, current_min)?;
603
        let maxmin = op(max, current_min)?;
604
        let maxmax = op(max, current_max)?;
605
        let vals = [minmax, minmin, maxmin, maxmax];
606

            
607
        current_min = *vals
608
            .iter()
609
            .min()
610
            .expect("vals iterator should not be empty, and should have a minimum.");
611
        current_max = *vals
612
            .iter()
613
            .max()
614
            .expect("vals iterator should not be empty, and should have a maximum.");
615
    }
616

            
617
    if current_min == current_max {
618
        Some(Domain::int(vec![Range::Single(current_min)]))
619
    } else {
620
        Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
621
    }
622
}
623

            
624
// Returns none if unbounded
625
fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
626
    let mut min = i32::MAX;
627
    let mut max = i32::MIN;
628
    for r in ranges {
629
        match r {
630
            Range::Single(i) => {
631
                if *i < min {
632
                    min = *i;
633
                }
634
                if *i > max {
635
                    max = *i;
636
                }
637
            }
638
            Range::Bounded(i, j) => {
639
                if *i < min {
640
                    min = *i;
641
                }
642
                if *j > max {
643
                    max = *j;
644
                }
645
            }
646
            Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
647
        }
648
    }
649
    Some((min, max))
650
}
651

            
652
impl Expression {
653
    /// Returns the possible values of the expression, recursing to leaf expressions
654
    pub fn domain_of(&self) -> Option<DomainPtr> {
655
        //println!("domain_of {self}");
656
        let ret = match self {
657
            Expression::Union(_, a, b) => Some(Domain::set(
658
                SetAttr::<IntVal>::default(),
659
                a.domain_of()?.union(&b.domain_of()?).ok()?,
660
            )),
661
            Expression::Intersect(_, a, b) => Some(Domain::set(
662
                SetAttr::<IntVal>::default(),
663
                a.domain_of()?.intersect(&b.domain_of()?).ok()?,
664
            )),
665
            Expression::In(_, _, _) => Some(Domain::bool()),
666
            Expression::Supset(_, _, _) => Some(Domain::bool()),
667
            Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
668
            Expression::Subset(_, _, _) => Some(Domain::bool()),
669
            Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
670
            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
671
            Expression::DominanceRelation(_, _) => Some(Domain::bool()),
672
            Expression::FromSolution(_, expr) => Some(expr.domain_of()),
673
            Expression::Metavar(_, _) => None,
674
            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
675
            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
676
                let dom = matrix.domain_of()?;
677
                if let Some((elem_domain, _)) = dom.as_matrix() {
678
                    return Some(elem_domain);
679
                }
680

            
681
                // may actually use the value in the future
682
                #[allow(clippy::redundant_pattern_matching)]
683
                if let Some(_) = dom.as_tuple() {
684
                    // TODO: We can implement proper indexing for tuples
685
                    return None;
686
                }
687

            
688
                // may actually use the value in the future
689
                #[allow(clippy::redundant_pattern_matching)]
690
                if let Some(_) = dom.as_record() {
691
                    // TODO: We can implement proper indexing for records
692
                    return None;
693
                }
694

            
695
                bug!("subject of an index operation should support indexing")
696
            }
697
            Expression::UnsafeSlice(_, matrix, indices)
698
            | Expression::SafeSlice(_, matrix, indices) => {
699
                let sliced_dimension = indices.iter().position(Option::is_none);
700

            
701
                let dom = matrix.domain_of()?;
702
                let Some((elem_domain, index_domains)) = dom.as_matrix() else {
703
                    bug!("subject of an index operation should be a matrix");
704
                };
705

            
706
                match sliced_dimension {
707
                    Some(dimension) => Some(Domain::matrix(
708
                        elem_domain,
709
                        vec![index_domains[dimension].clone()],
710
                    )),
711

            
712
                    // same as index
713
                    None => Some(elem_domain),
714
                }
715
            }
716
            Expression::InDomain(_, _, _) => Some(Domain::bool()),
717
            Expression::Atomic(_, atom) => Some(atom.domain_of()),
718
            Expression::Scope(_, _) => Some(Domain::bool()),
719
            Expression::Sum(_, e) => {
720
                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
721
            }
722
            Expression::Product(_, e) => {
723
                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
724
            }
725
            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
726
                Some(if x < y { x } else { y })
727
            }),
728
            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
729
                Some(if x > y { x } else { y })
730
            }),
731
            Expression::UnsafeDiv(_, a, b) => a
732
                .domain_of()?
733
                .resolve()?
734
                .apply_i32(
735
                    // rust integer division is truncating; however, we want to always round down,
736
                    // including for negative numbers.
737
                    |x, y| {
738
                        if y != 0 {
739
                            Some((x as f32 / y as f32).floor() as i32)
740
                        } else {
741
                            None
742
                        }
743
                    },
744
                    b.domain_of()?.resolve()?.as_ref(),
745
                )
746
                .map(DomainPtr::from)
747
                .ok(),
748
            Expression::SafeDiv(_, a, b) => {
749
                // rust integer division is truncating; however, we want to always round down
750
                // including for negative numbers.
751
                let domain = a
752
                    .domain_of()?
753
                    .resolve()?
754
                    .apply_i32(
755
                        |x, y| {
756
                            if y != 0 {
757
                                Some((x as f32 / y as f32).floor() as i32)
758
                            } else {
759
                                None
760
                            }
761
                        },
762
                        b.domain_of()?.resolve()?.as_ref(),
763
                    )
764
                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
765

            
766
                if let GroundDomain::Int(ranges) = domain {
767
                    let mut ranges = ranges;
768
                    ranges.push(Range::Single(0));
769
                    return Some(Domain::int(ranges));
770
                } else {
771
                    bug!("Domain of {self} was not integer")
772
                }
773
            }
774
            Expression::UnsafeMod(_, a, b) => a
775
                .domain_of()?
776
                .resolve()?
777
                .apply_i32(
778
                    |x, y| if y != 0 { Some(x % y) } else { None },
779
                    b.domain_of()?.resolve()?.as_ref(),
780
                )
781
                .map(DomainPtr::from)
782
                .ok(),
783
            Expression::SafeMod(_, a, b) => {
784
                let domain = a
785
                    .domain_of()?
786
                    .resolve()?
787
                    .apply_i32(
788
                        |x, y| if y != 0 { Some(x % y) } else { None },
789
                        b.domain_of()?.resolve()?.as_ref(),
790
                    )
791
                    .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
792

            
793
                if let GroundDomain::Int(ranges) = domain {
794
                    let mut ranges = ranges;
795
                    ranges.push(Range::Single(0));
796
                    return Some(Domain::int(ranges));
797
                } else {
798
                    bug!("Domain of {self} was not integer")
799
                }
800
            }
801
            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
802
                .domain_of()?
803
                .resolve()?
804
                .apply_i32(
805
                    |x, y| {
806
                        if (x != 0 || y != 0) && y >= 0 {
807
                            Some(x.pow(y as u32))
808
                        } else {
809
                            None
810
                        }
811
                    },
812
                    b.domain_of()?.resolve()?.as_ref(),
813
                )
814
                .map(DomainPtr::from)
815
                .ok(),
816
            Expression::Root(_, _) => None,
817
            Expression::Bubble(_, inner, _) => inner.domain_of(),
818
            Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
819
            Expression::And(_, _) => Some(Domain::bool()),
820
            Expression::Not(_, _) => Some(Domain::bool()),
821
            Expression::Or(_, _) => Some(Domain::bool()),
822
            Expression::Imply(_, _, _) => Some(Domain::bool()),
823
            Expression::Iff(_, _, _) => Some(Domain::bool()),
824
            Expression::Eq(_, _, _) => Some(Domain::bool()),
825
            Expression::Neq(_, _, _) => Some(Domain::bool()),
826
            Expression::Geq(_, _, _) => Some(Domain::bool()),
827
            Expression::Leq(_, _, _) => Some(Domain::bool()),
828
            Expression::Gt(_, _, _) => Some(Domain::bool()),
829
            Expression::Lt(_, _, _) => Some(Domain::bool()),
830
            Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
831
            Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
832
            Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
833
            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
834
            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
835
            Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
836
            Expression::Flatten(_, n, m) => {
837
                if let Some(expr) = n {
838
                    if expr.return_type() == ReturnType::Int {
839
                        // TODO: handle flatten with depth argument
840
                        return None;
841
                    }
842
                } else {
843
                    let domain = m.domain_of()?;
844
                    let mut total_size = 1;
845
                    let index_domains: Vec<Domain> = Vec::new();
846

            
847
                    // calculate total flattened size
848
                    for i in &index_domains {
849
                        total_size *= i.length().ok()?;
850
                    }
851
                    let new_index_domain =
852
                        Domain::int(vec![Range::Bounded(1, total_size.try_into().unwrap())]);
853
                    return Some(Domain::matrix(domain, vec![new_index_domain]));
854
                }
855
                None
856
            }
857
            Expression::AllDiff(_, _) => Some(Domain::bool()),
858
            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
859
            Expression::MinionReify(_, _, _) => Some(Domain::bool()),
860
            Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
861
            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
862
            Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
863
            Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
864
            Expression::Neg(_, x) => {
865
                let dom = x.domain_of()?;
866
                let mut ranges = dom.as_int()?;
867

            
868
                ranges = ranges
869
                    .into_iter()
870
                    .map(|r| match r {
871
                        Range::Single(x) => Range::Single(-x),
872
                        Range::Bounded(x, y) => Range::Bounded(-y, -x),
873
                        Range::UnboundedR(i) => Range::UnboundedL(-i),
874
                        Range::UnboundedL(i) => Range::UnboundedR(-i),
875
                        Range::Unbounded => Range::Unbounded,
876
                    })
877
                    .collect();
878

            
879
                Some(Domain::int(ranges))
880
            }
881
            Expression::Minus(_, a, b) => a
882
                .domain_of()?
883
                .resolve()?
884
                .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
885
                .map(DomainPtr::from)
886
                .ok(),
887
            Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
888
            Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
889
            Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
890
            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
891
            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
892
            Expression::Abs(_, a) => a
893
                .domain_of()?
894
                .resolve()?
895
                .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
896
                .map(DomainPtr::from)
897
                .ok(),
898
            Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
899
            Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
900
            Expression::SATInt(_, _) => {
901
                Some(Domain::int_ground(vec![Range::Bounded(
902
                    i8::MIN.into(),
903
                    i8::MAX.into(),
904
                )])) // BITS
905
            } // A CnfInt can represent any i8 integer at the moment
906
            // A CnfInt contains multiple boolean expressions and represents the integer
907
            // formed when these booleans are treated as the bits in an integer encoding.
908
            // So the 'domain of' should be an integer
909
            Expression::PairwiseSum(_, a, b) => a
910
                .domain_of()?
911
                .resolve()?
912
                .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
913
                .map(DomainPtr::from)
914
                .ok(),
915
            Expression::PairwiseProduct(_, a, b) => a
916
                .domain_of()?
917
                .resolve()?
918
                .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
919
                .map(DomainPtr::from)
920
                .ok(),
921
            Expression::Defined(_, function) => get_function_domain(function),
922
            Expression::Range(_, function) => get_function_codomain(function),
923
            Expression::Image(_, function, _) => get_function_codomain(function),
924
            Expression::ImageSet(_, function, _) => get_function_codomain(function),
925
            Expression::PreImage(_, function, _) => get_function_domain(function),
926
            Expression::Restrict(_, function, new_domain) => {
927
                let function_domain = function.domain_of()?;
928
                match function_domain.resolve().as_ref() {
929
                    Some(d) => {
930
                        match d.as_ref() {
931
                            GroundDomain::Function(attrs, _, codomain) => Some(Domain::function(
932
                                attrs.clone(),
933
                                new_domain.domain_of()?,
934
                                codomain.clone().into(),
935
                            )),
936
                            // Not defined for anything other than a function
937
                            _ => None,
938
                        }
939
                    }
940
                    None => {
941
                        match function_domain.as_unresolved()? {
942
                            UnresolvedDomain::Function(attrs, _, codomain) => {
943
                                Some(Domain::function(
944
                                    attrs.clone(),
945
                                    new_domain.domain_of()?,
946
                                    codomain.clone(),
947
                                ))
948
                            }
949
                            // Not defined for anything other than a function
950
                            _ => None,
951
                        }
952
                    }
953
                }
954
            }
955
            Expression::Inverse(..) => Some(Domain::bool()),
956
            Expression::LexLt(..) => Some(Domain::bool()),
957
            Expression::LexLeq(..) => Some(Domain::bool()),
958
            Expression::LexGt(..) => Some(Domain::bool()),
959
            Expression::LexGeq(..) => Some(Domain::bool()),
960
            Expression::FlatLexLt(..) => Some(Domain::bool()),
961
            Expression::FlatLexLeq(..) => Some(Domain::bool()),
962
        };
963
        if let Some(dom) = &ret
964
            && let Some(ranges) = dom.as_int_ground()
965
            && ranges.len() > 1
966
        {
967
            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
968
            // Once they support a full domain as we define it, we can remove this conversion
969
            let (min, max) = range_vec_bounds_i32(ranges)?;
970
            return Some(Domain::int(vec![Range::Bounded(min, max)]));
971
        }
972
        ret
973
    }
974

            
975
    pub fn get_meta(&self) -> Metadata {
976
        let metas: VecDeque<Metadata> = self.children_bi();
977
        metas[0].clone()
978
    }
979

            
980
    pub fn set_meta(&self, meta: Metadata) {
981
        self.transform_bi(&|_| meta.clone());
982
    }
983

            
984
    /// Checks whether this expression is safe.
985
    ///
986
    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
987
    ///
988
    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
989
    /// safe through the use of bubble rules.
990
    pub fn is_safe(&self) -> bool {
991
        // TODO: memoise in Metadata
992
        for expr in self.universe() {
993
            match expr {
994
                Expression::UnsafeDiv(_, _, _)
995
                | Expression::UnsafeMod(_, _, _)
996
                | Expression::UnsafePow(_, _, _)
997
                | Expression::UnsafeIndex(_, _, _)
998
                | Expression::Bubble(_, _, _)
999
                | Expression::UnsafeSlice(_, _, _) => {
                    return false;
                }
                _ => {}
            }
        }
        true
    }
    pub fn is_clean(&self) -> bool {
        let metadata = self.get_meta();
        metadata.clean
    }
    pub fn set_clean(&mut self, bool_value: bool) {
        let mut metadata = self.get_meta();
        metadata.clean = bool_value;
        self.set_meta(metadata);
    }
    /// True if the expression is an associative and commutative operator
    pub fn is_associative_commutative_operator(&self) -> bool {
        TryInto::<ACOperatorKind>::try_into(self).is_ok()
    }
    /// True if the expression is a matrix literal.
    ///
    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
    /// [`Expression`].
    pub fn is_matrix_literal(&self) -> bool {
        matches!(
            self,
            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
                | Expression::Atomic(
                    _,
                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
                )
        )
    }
    /// True iff self and other are both atomic and identical.
    ///
    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
    /// entire subtrees of `self` and `other`.
    pub fn identical_atom_to(&self, other: &Expression) -> bool {
        let atom1: Result<&Atom, _> = self.try_into();
        let atom2: Result<&Atom, _> = other.try_into();
        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
            atom2 == atom1
        } else {
            false
        }
    }
    /// If the expression is a list, returns the inner expressions.
    ///
    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
    /// any explicitly specified domain.
    pub fn unwrap_list(self) -> Option<Vec<Expression>> {
        match self {
            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
                matrix.unwrap_list().cloned()
            }
            Expression::Atomic(
                _,
                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
            ) => matrix.unwrap_list().map(|elems| {
                elems
                    .clone()
                    .into_iter()
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
                    .collect_vec()
            }),
            _ => None,
        }
    }
    /// If the expression is a matrix, gets it elements and index domain.
    ///
    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
    ///
    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
        match self {
            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
                Some((elems, domain))
            }
            Expression::Atomic(
                _,
                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
            ) => Some((
                elems
                    .into_iter()
                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
                    .collect_vec(),
                domain.into(),
            )),
            _ => None,
        }
    }
    /// For a Root expression, extends the inner vec with the given vec.
    ///
    /// # Panics
    /// Panics if the expression is not Root.
    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
        match self {
            Expression::Root(meta, mut children) => {
                children.extend(exprs);
                Expression::Root(meta, children)
            }
            _ => panic!("extend_root called on a non-Root expression"),
        }
    }
    /// Converts the expression to a literal, if possible.
    pub fn into_literal(self) -> Option<Literal> {
        match self {
            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
            Expression::AbstractLiteral(_, abslit) => {
                Some(Literal::AbstractLiteral(abslit.into_literals()?))
            }
            Expression::Neg(_, e) => {
                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
                    bug!("negated literal should be an int");
                };
                Some(Literal::Int(-i))
            }
            _ => None,
        }
    }
    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
        TryFrom::try_from(self).ok()
    }
    /// Returns the categories of all sub-expressions of self.
    pub fn universe_categories(&self) -> HashSet<Category> {
        self.universe()
            .into_iter()
            .map(|x| x.category_of())
            .collect()
    }
}
pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
    let function_domain = function.domain_of()?;
    match function_domain.resolve().as_ref() {
        Some(d) => {
            match d.as_ref() {
                GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
        None => {
            match function_domain.as_unresolved()? {
                UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
    }
}
pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
    let function_domain = function.domain_of()?;
    match function_domain.resolve().as_ref() {
        Some(d) => {
            match d.as_ref() {
                GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
        None => {
            match function_domain.as_unresolved()? {
                UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
                // Not defined for anything other than a function
                _ => None,
            }
        }
    }
}
/// Extracts the value of an integer-literal expression; any other expression is an error.
impl TryFrom<&Expression> for i32 {
    type Error = ();

    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
        match value {
            Expression::Atomic(_, Atom::Literal(Literal::Int(int_value))) => Ok(*int_value),
            _ => Err(()),
        }
    }
}

/// Owned counterpart of the `&Expression` conversion, to which it delegates.
impl TryFrom<Expression> for i32 {
    type Error = ();

    fn try_from(value: Expression) -> Result<Self, Self::Error> {
        i32::try_from(&value)
    }
}
impl From<i32> for Expression {
    /// Builds an integer-literal expression with fresh metadata.
    fn from(value: i32) -> Self {
        let literal = Literal::Int(value);
        Self::Atomic(Metadata::new(), Atom::Literal(literal))
    }
}

impl From<bool> for Expression {
    /// Builds a boolean-literal expression with fresh metadata.
    fn from(value: bool) -> Self {
        let literal = Literal::Bool(value);
        Self::Atomic(Metadata::new(), Atom::Literal(literal))
    }
}

impl From<Atom> for Expression {
    /// Wraps an atom in an [`Expression::Atomic`] with fresh metadata.
    fn from(atom: Atom) -> Self {
        Self::Atomic(Metadata::new(), atom)
    }
}

impl From<Literal> for Expression {
    /// Converts a literal into an [`Atom`] and wraps it in an [`Expression::Atomic`] with
    /// fresh metadata.
    fn from(literal: Literal) -> Self {
        Self::Atomic(Metadata::new(), literal.into())
    }
}
impl From<Moo<Expression>> for Expression {
    fn from(val: Moo<Expression>) -> Self {
        val.as_ref().clone()
    }
}
impl CategoryOf for Expression {
    /// Computes the highest [`Category`] found anywhere in this expression.
    fn category_of(&self) -> Category {
        // Bottom-up fold over the expression tree: an inner node takes the highest category of
        // its sub-expressions, and a leaf node is classified by the values it contains.
        let category = self.cata(&move |node, child_categories| {
            // Inner node: inherit the highest category among the sub-expressions.
            if let Some(highest) = child_categories.iter().max() {
                return *highest;
            }

            // Leaf node containing a submodel (comprehensions included): assume decision.
            if !Biplate::<SubModel>::universe_bi(&node).is_empty() {
                return Category::Decision;
            }

            // Otherwise, take the highest category among the atoms and declaration pointers
            // it contains, starting from Bottom. This generically covers every leaf type
            // currently in oxide.
            let highest_atom = Biplate::<Atom>::universe_bi(&node)
                .iter()
                .map(|atom| atom.category_of())
                .max();
            let highest_declaration = Biplate::<DeclarationPtr>::universe_bi(&node)
                .iter()
                .map(|declaration| declaration.category_of())
                .max();
            [highest_atom, highest_declaration]
                .into_iter()
                .flatten()
                .fold(Category::Bottom, Ord::max)
        });

        if cfg!(debug_assertions) {
            trace!(
                category = %category,
                expression = %self,
                "Called Expression::category_of()"
            );
        }
        category
    }
}
/// Renders an [`Expression`] as human-readable text.
///
/// High-level operators are printed in an Essence-like surface syntax (`(a union b)`,
/// `sum(..)`, `(a <lex b)`, ...), while internal / solver-level constraint forms are printed
/// in a function-call style naming the constraint (`SumGeq(..)`, `Reify(..)`,
/// `__minion_w_inset(..)`, ...).
impl Display for Expression {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Operands are formatted by reference: `{x}` uses the same `Display` impl as the
        // previous `{}` + `x.clone()`, without copying whole subtrees just to print them.
        match self {
            Expression::Union(_, box1, box2) => write!(f, "({box1} union {box2})"),
            Expression::In(_, e1, e2) => write!(f, "{e1} in {e2}"),
            Expression::Intersect(_, box1, box2) => write!(f, "({box1} intersect {box2})"),
            Expression::Supset(_, box1, box2) => write!(f, "({box1} supset {box2})"),
            Expression::SupsetEq(_, box1, box2) => write!(f, "({box1} supsetEq {box2})"),
            Expression::Subset(_, box1, box2) => write!(f, "({box1} subset {box2})"),
            Expression::SubsetEq(_, box1, box2) => write!(f, "({box1} subsetEq {box2})"),
            Expression::AbstractLiteral(_, l) => l.fmt(f),
            Expression::Comprehension(_, c) => c.fmt(f),
            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
                write!(f, "{e1}{}", pretty_vec(e2))
            }
            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
                // A `None` index position is an open (`..`) slice dimension.
                let args = es
                    .iter()
                    .map(|x| match x {
                        Some(x) => format!("{x}"),
                        None => "..".into(),
                    })
                    .join(",");
                write!(f, "{e1}[{args}]")
            }
            Expression::InDomain(_, e, domain) => write!(f, "__inDomain({e},{domain})"),
            Expression::Root(_, exprs) => {
                write!(f, "{}", pretty_expressions_as_top_level(exprs))
            }
            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
            Expression::Metavar(_, name) => write!(f, "&{name}"),
            Expression::Atomic(_, atom) => atom.fmt(f),
            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
            Expression::Abs(_, a) => write!(f, "|{a}|"),
            Expression::Sum(_, e) => write!(f, "sum({e})"),
            Expression::Product(_, e) => write!(f, "product({e})"),
            Expression::Min(_, e) => write!(f, "min({e})"),
            Expression::Max(_, e) => write!(f, "max({e})"),
            Expression::Not(_, expr_box) => write!(f, "!({expr_box})"),
            Expression::Or(_, e) => write!(f, "or({e})"),
            Expression::And(_, e) => write!(f, "and({e})"),
            Expression::Imply(_, box1, box2) => write!(f, "({box1}) -> ({box2})"),
            Expression::Iff(_, box1, box2) => write!(f, "({box1}) <-> ({box2})"),
            Expression::Eq(_, box1, box2) => write!(f, "({box1} = {box2})"),
            Expression::Neq(_, box1, box2) => write!(f, "({box1} != {box2})"),
            Expression::Geq(_, box1, box2) => write!(f, "({box1} >= {box2})"),
            Expression::Leq(_, box1, box2) => write!(f, "({box1} <= {box2})"),
            Expression::Gt(_, box1, box2) => write!(f, "({box1} > {box2})"),
            Expression::Lt(_, box1, box2) => write!(f, "({box1} < {box2})"),
            Expression::FlatSumGeq(_, box1, box2) => {
                write!(f, "SumGeq({}, {box2})", pretty_vec(box1))
            }
            Expression::FlatSumLeq(_, box1, box2) => {
                write!(f, "SumLeq({}, {box2})", pretty_vec(box1))
            }
            Expression::FlatIneq(_, box1, box2, box3) => {
                write!(f, "Ineq({box1}, {box2}, {box3})")
            }
            Expression::Flatten(_, n, m) => {
                // The depth argument is optional and only printed when present.
                if let Some(n) = n {
                    write!(f, "flatten({n}, {m})")
                } else {
                    write!(f, "flatten({m})")
                }
            }
            Expression::AllDiff(_, e) => write!(f, "allDiff({e})"),
            // `{{` / `}}` are escaped literal braces: prints `{inner @ cond}`.
            Expression::Bubble(_, box1, box2) => write!(f, "{{{box1} @ {box2}}}"),
            Expression::SafeDiv(_, box1, box2) => write!(f, "SafeDiv({box1}, {box2})"),
            Expression::UnsafeDiv(_, box1, box2) => write!(f, "UnsafeDiv({box1}, {box2})"),
            Expression::UnsafePow(_, box1, box2) => write!(f, "UnsafePow({box1}, {box2})"),
            Expression::SafePow(_, box1, box2) => write!(f, "SafePow({box1}, {box2})"),
            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
                write!(f, "DivEq({box1}, {box2}, {box3})")
            }
            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
                write!(f, "ModEq({box1}, {box2}, {box3})")
            }
            Expression::FlatWatchedLiteral(_, x, l) => write!(f, "WatchedLiteral({x},{l})"),
            Expression::MinionReify(_, box1, box2) => write!(f, "Reify({box1}, {box2})"),
            Expression::MinionReifyImply(_, box1, box2) => {
                write!(f, "ReifyImply({box1}, {box2})")
            }
            Expression::MinionWInIntervalSet(_, atom, intervals) => {
                let intervals = intervals.iter().join(",");
                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
            }
            Expression::MinionWInSet(_, atom, values) => {
                // NOTE(review): unlike `__minion_w_inintervalset` and `__minion_element_one`,
                // the value list is not wrapped in `[..]` here. Left unchanged so printed
                // output stays stable; confirm whether the brackets are intended.
                let values = values.iter().join(",");
                write!(f, "__minion_w_inset({atom},{values})")
            }
            Expression::AuxDeclaration(_, reference, e) => write!(f, "{reference} =aux {e}"),
            Expression::UnsafeMod(_, a, b) => write!(f, "{a} % {b}"),
            Expression::SafeMod(_, a, b) => write!(f, "SafeMod({a},{b})"),
            Expression::Neg(_, a) => write!(f, "-({a})"),
            Expression::Minus(_, a, b) => write!(f, "({a} - {b})"),
            Expression::FlatAllDiff(_, es) => write!(f, "__flat_alldiff({})", pretty_vec(es)),
            Expression::FlatAbsEq(_, a, b) => write!(f, "AbsEq({a},{b})"),
            Expression::FlatMinusEq(_, a, b) => write!(f, "MinusEq({a},{b})"),
            Expression::FlatProductEq(_, a, b, c) => write!(f, "FlatProductEq({a},{b},{c})"),
            Expression::FlatWeightedSumLeq(_, cs, vs, total) => write!(
                f,
                "FlatWeightedSumLeq({},{},{total})",
                pretty_vec(cs),
                pretty_vec(vs)
            ),
            Expression::FlatWeightedSumGeq(_, cs, vs, total) => write!(
                f,
                "FlatWeightedSumGeq({},{},{total})",
                pretty_vec(cs),
                pretty_vec(vs)
            ),
            Expression::MinionPow(_, atom, atom1, atom2) => {
                write!(f, "MinionPow({atom},{atom1},{atom2})")
            }
            Expression::MinionElementOne(_, atoms, atom, atom1) => {
                let atoms = atoms.iter().join(",");
                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
            }
            Expression::ToInt(_, expr) => write!(f, "toInt({expr})"),
            Expression::SATInt(_, e) => write!(f, "SATInt({e})"),
            Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
            Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
            Expression::Defined(_, function) => write!(f, "defined({function})"),
            Expression::Range(_, function) => write!(f, "range({function})"),
            Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
            Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
            Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
            Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
            Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
            Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
            Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
            Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
            Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
            Expression::FlatLexLt(_, a, b) => {
                write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
            }
            Expression::FlatLexLeq(_, a, b) => {
                write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
            }
        }
    }
}
impl Typeable for Expression {
    fn return_type(&self) -> ReturnType {
        match self {
            Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
            Expression::Intersect(_, subject, _) => {
                ReturnType::Set(Box::new(subject.return_type()))
            }
            Expression::In(_, _, _) => ReturnType::Bool,
            Expression::Supset(_, _, _) => ReturnType::Bool,
            Expression::SupsetEq(_, _, _) => ReturnType::Bool,
            Expression::Subset(_, _, _) => ReturnType::Bool,
            Expression::SubsetEq(_, _, _) => ReturnType::Bool,
            Expression::AbstractLiteral(_, lit) => lit.return_type(),
            Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
                let subject_ty = subject.return_type();
                match subject_ty {
                    ReturnType::Matrix(_) => {
                        // For n-dimensional matrices, unwrap the element type until
                        // we either get to the innermost element type or the last index
                        let mut elem_typ = subject_ty;
                        let mut idx_len = idx.len();
                        while idx_len > 0
                            && let ReturnType::Matrix(new_elem_typ) = &elem_typ
                        {
                            elem_typ = *new_elem_typ.clone();
                            idx_len -= 1;
                        }
                        elem_typ
                    }
                    // TODO: We can implement indexing for these eventually
                    ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
                    _ => bug!(
                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
                    ),
                }
            }
            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
                ReturnType::Matrix(Box::new(subject.return_type()))
            }
            Expression::InDomain(_, _, _) => ReturnType::Bool,
            Expression::Comprehension(_, comp) => comp.return_type(),
            Expression::Root(_, _) => ReturnType::Bool,
            Expression::DominanceRelation(_, _) => ReturnType::Bool,
            Expression::FromSolution(_, expr) => expr.return_type(),
            Expression::Metavar(_, _) => ReturnType::Unknown,
            Expression::Atomic(_, atom) => atom.return_type(),
            Expression::Scope(_, scope) => scope.return_type(),
            Expression::Abs(_, _) => ReturnType::Int,
            Expression::Sum(_, _) => ReturnType::Int,
            Expression::Product(_, _) => ReturnType::Int,
            Expression::Min(_, _) => ReturnType::Int,
            Expression::Max(_, _) => ReturnType::Int,
            Expression::Not(_, _) => ReturnType::Bool,
            Expression::Or(_, _) => ReturnType::Bool,
            Expression::Imply(_, _, _) => ReturnType::Bool,
            Expression::Iff(_, _, _) => ReturnType::Bool,
            Expression::And(_, _) => ReturnType::Bool,
            Expression::Eq(_, _, _) => ReturnType::Bool,
            Expression::Neq(_, _, _) => ReturnType::Bool,
            Expression::Geq(_, _, _) => ReturnType::Bool,
            Expression::Leq(_, _, _) => ReturnType::Bool,
            Expression::Gt(_, _, _) => ReturnType::Bool,
            Expression::Lt(_, _, _) => ReturnType::Bool,
            Expression::SafeDiv(_, _, _) => ReturnType::Int,
            Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
            Expression::FlatAllDiff(_, _) => ReturnType::Bool,
            Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
            Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
            Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
            Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
            Expression::Flatten(_, _, matrix) => {
                let matrix_type = matrix.return_type();
                match matrix_type {
                    ReturnType::Matrix(_) => {
                        // unwrap until we get to innermost element
                        let mut elem_type = matrix_type;
                        while let ReturnType::Matrix(new_elem_type) = &elem_type {
                            elem_type = *new_elem_type.clone();
                        }
                        ReturnType::Matrix(Box::new(elem_type))
                    }
                    _ => bug!(
                        "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
                    ),
                }
            }
            Expression::AllDiff(_, _) => ReturnType::Bool,
            Expression::Bubble(_, inner, _) => inner.return_type(),
            Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
            Expression::MinionReify(_, _, _) => ReturnType::Bool,
            Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
            Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
            Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
            Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
            Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
            Expression::UnsafeMod(_, _, _) => ReturnType::Int,
            Expression::SafeMod(_, _, _) => ReturnType::Int,
            Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
            Expression::Neg(_, _) => ReturnType::Int,
            Expression::UnsafePow(_, _, _) => ReturnType::Int,
            Expression::SafePow(_, _, _) => ReturnType::Int,
            Expression::Minus(_, _, _) => ReturnType::Int,
            Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
            Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
            Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
            Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
            Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
            Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
            Expression::ToInt(_, _) => ReturnType::Int,
            Expression::SATInt(_, _) => ReturnType::Int,
            Expression::PairwiseSum(_, _, _) => ReturnType::Int,
            Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
            Expression::Defined(_, function) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(domain, _) => *domain,
                    _ => bug!(
                        "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Range(_, function) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Image(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::ImageSet(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => *codomain,
                    _ => bug!(
                        "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::PreImage(_, function, _) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(domain, _) => *domain,
                    _ => bug!(
                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Restrict(_, function, new_domain) => {
                let subject = function.return_type();
                match subject {
                    ReturnType::Function(_, codomain) => {
                        ReturnType::Function(Box::new(new_domain.return_type()), codomain)
                    }
                    _ => bug!(
                        "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
                    ),
                }
            }
            Expression::Inverse(..) => ReturnType::Bool,
            Expression::LexLt(..) => ReturnType::Bool,
            Expression::LexGt(..) => ReturnType::Bool,
            Expression::LexLeq(..) => ReturnType::Bool,
            Expression::LexGeq(..) => ReturnType::Bool,
            Expression::FlatLexLt(..) => ReturnType::Bool,
            Expression::FlatLexLeq(..) => ReturnType::Bool,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    /// Builds an atomic expression holding the literal `lit`, with fresh metadata.
    fn literal_expr(lit: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(lit))
    }

    /// Wraps the (matrix) expression `operands` in a `sum(..)`.
    fn sum_of(operands: Expression) -> Expression {
        Expression::Sum(Metadata::new(), Moo::new(operands))
    }

    /// `sum([1, 2])` is the constant 3, so its domain is the singleton `{3}`.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal_expr(Literal::Int(1));
        let two = literal_expr(Literal::Int(2));
        let expected = Domain::int(vec![Range::Single(3)]);
        assert_eq!(sum_of(matrix_expr![one, two]).domain_of(), Some(expected));
    }

    /// Summing an int with a bool is ill-typed, so no domain can be computed.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_operand = literal_expr(Literal::Int(1));
        let bool_operand = literal_expr(Literal::Bool(true));
        assert_eq!(
            sum_of(matrix_expr![int_operand, bool_operand]).domain_of(),
            None
        );
    }

    /// An empty sum has no computable domain.
    #[test]
    fn test_domain_of_empty_sum() {
        assert_eq!(sum_of(matrix_expr![]).domain_of(), None);
    }
}