// conjure_core/ast/expressions.rs

1use crate::ast::declaration::serde::DeclarationPtrAsId;
2use serde_with::serde_as;
3use std::collections::{HashSet, VecDeque};
4use std::fmt::{Display, Formatter};
5use tracing::trace;
6
7use crate::ast::Atom;
8use crate::ast::Moo;
9use crate::ast::Name;
10use crate::ast::ReturnType;
11use crate::ast::SetAttr;
12use crate::ast::literals::AbstractLiteral;
13use crate::ast::literals::Literal;
14use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
15use crate::bug;
16use crate::metadata::Metadata;
17use enum_compatability_macro::document_compatibility;
18use itertools::Itertools;
19use serde::{Deserialize, Serialize};
20
21use uniplate::{Biplate, Uniplate};
22
23use super::ac_operators::ACOperatorKind;
24use super::categories::{Category, CategoryOf};
25use super::comprehension::Comprehension;
26use super::domains::HasDomain as _;
27use super::records::RecordValue;
28use super::{DeclarationPtr, Domain, Range, SubModel, Typeable};
29
// Ensure that this type doesn't get too big
//
// If you triggered this assertion, you either made a variant of this enum that is too big, or you
// made Name,Literal,AbstractLiteral,Atom bigger, which made this bigger! To fix this, put some
// stuff in boxes.
//
// Enums take the size of their largest variant, so an enum with mostly small variants and a few
// large ones wastes memory... A larger Expression type also slows down Oxide.
//
// For more information, and more details on type sizes and how to measure them, see the commit
// message for 6012de809 (perf: reduce size of AST types, 2025-06-18).
//
// You can also see type sizes in the rustdoc documentation, generated by ./tools/gen_docs.sh
//
// https://github.com/conjure-cp/conjure-oxide/commit/6012de8096ca491ded91ecec61352fdf4e994f2e

// TODO: box all usages of Metadata to bring this down a bit more - I have added variants to
// ReturnType, and Metadata contains ReturnType, so Metadata has got bigger. Metadata will get a
// lot bigger still when we start using it for memoisation, so it should really be
// boxed ~niklasdewally

// Expect the size of Expression to be 104 bytes (keep this comment in sync with the assertion).
static_assertions::assert_eq_size!([u8; 104], Expression);

/// Represents different types of expressions used to define rules and constraints in the model.
///
/// The `Expression` enum includes operations, constants, and variable references
/// used to build rules and conditions for the model.
///
/// Every variant stores its [`Metadata`] as its first field.
#[document_compatibility]
#[serde_as]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[biplate(to=Metadata)]
#[biplate(to=Atom)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=Name)]
#[biplate(to=Vec<Expression>)]
#[biplate(to=Option<Expression>)]
#[biplate(to=SubModel)]
#[biplate(to=Comprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Literal)]
pub enum Expression {
    /// An abstract literal (for example, a matrix literal) whose elements are expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// The top of the model
    Root(Metadata, Vec<Expression>),

    /// An expression representing "A is valid as long as B is true"
    /// Turns into a conjunction when it reaches a boolean context
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    /// A comprehension.
    ///
    /// The inside of the comprehension opens a new scope.
    Comprehension(Metadata, Moo<Comprehension>),

    /// Defines dominance ("Solution A is preferred over Solution B")
    DominanceRelation(Metadata, Moo<Expression>),
    /// `fromSolution(name)` - Used in dominance relation definitions
    FromSolution(Metadata, Moo<Expression>),

    /// A single [`Atom`], such as a literal or a reference to a declaration.
    Atomic(Metadata, Atom),

    /// A matrix index.
    ///
    /// Defined iff the indices are within their respective index domains.
    #[compatible(JsonInput)]
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// A safe matrix index.
    ///
    /// See [`Expression::UnsafeIndex`]
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// A matrix slice: `a[indices]`.
    ///
    /// One of the indices may be `None`, representing the dimension of the matrix we want to take
    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
    /// `Some(1),None,Some(2)`.
    ///
    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
    ///
    /// Defined iff the defined indices are within their respective index domains.
    #[compatible(JsonInput)]
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// A safe matrix slice: `a[indices]`.
    ///
    /// See [`Expression::UnsafeSlice`].
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
    ///
    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
    /// `SafeSlice` respectively.
    InDomain(Metadata, Moo<Expression>, Domain),

    /// `toInt(b)` casts boolean expression b to an integer.
    ///
    /// - If b is false, then `toInt(b) == 0`
    ///
    /// - If b is true, then `toInt(b) == 1`
    ToInt(Metadata, Moo<Expression>),

    /// A nested [`SubModel`].
    Scope(Metadata, Moo<SubModel>),

    /// `|x|` - absolute value of `x`
    #[compatible(JsonInput)]
    Abs(Metadata, Moo<Expression>),

    /// `sum(<vec_expr>)`
    #[compatible(JsonInput)]
    Sum(Metadata, Moo<Expression>),

    /// `a * b * c * ...`
    #[compatible(JsonInput)]
    Product(Metadata, Moo<Expression>),

    /// `min(<vec_expr>)`
    #[compatible(JsonInput)]
    Min(Metadata, Moo<Expression>),

    /// `max(<vec_expr>)`
    #[compatible(JsonInput)]
    Max(Metadata, Moo<Expression>),

    /// `not(a)`
    #[compatible(JsonInput, SAT)]
    Not(Metadata, Moo<Expression>),

    /// `or(<vec_expr>)`
    #[compatible(JsonInput, SAT)]
    Or(Metadata, Moo<Expression>),

    /// `and(<vec_expr>)`
    #[compatible(JsonInput, SAT)]
    And(Metadata, Moo<Expression>),

    /// Ensures that `a->b` (material implication).
    #[compatible(JsonInput)]
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    /// `iff(a, b)` a <-> b
    #[compatible(JsonInput)]
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    In(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Division after preventing division by zero, usually with a bubble
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// Division with a possibly undefined value (division by 0)
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// Modulo after preventing mod 0, usually with a bubble
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// Modulo with a possibly undefined value (mod 0)
    #[compatible(JsonInput)]
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// Negation: `-x`
    #[compatible(JsonInput)]
    Neg(Metadata, Moo<Expression>),

    /// Unsafe power `x**y` (possibly undefined)
    ///
    /// Defined when `(x != 0 \/ y != 0) /\ y >= 0`.
    #[compatible(JsonInput)]
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// `UnsafePow` after preventing undefinedness
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// `allDiff(<vec_expr>)`
    #[compatible(JsonInput)]
    AllDiff(Metadata, Moo<Expression>),

    /// Binary subtraction operator
    ///
    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
    /// TODO: make this compatible with set difference: this needs a different return type and
    /// domain for this expression, and a set comprehension rule.
    /// `minus_to_sum` has already been edited so that it does not apply to sets.
    #[compatible(JsonInput)]
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    /// Ensures that x=|y| i.e. x is the absolute value of y.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
    #[compatible(Minion)]
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Ensures that `alldiff([a,b,...])`.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#alldiff)
    #[compatible(Minion)]
    FlatAllDiff(Metadata, Vec<Atom>),

    /// Ensures that sum(vec) >= x.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
    #[compatible(Minion)]
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    /// Ensures that sum(vec) <= x.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
    #[compatible(Minion)]
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    /// `ineq(x,y,k)` ensures that x <= y + k.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
    #[compatible(Minion)]
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
    ///
    /// Low-level Minion constraint.
    ///
    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
    /// watched-and and watched-or.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-literal)
    /// + `rules::minion::boolean_literal_to_wliteral`.
    #[compatible(Minion)]
    FlatWatchedLiteral(
        Metadata,
        #[serde_as(as = "DeclarationPtrAsId")] DeclarationPtr,
        Literal,
    ),

    // NOTE(review): unlike the neighbouring Minion constraints, this variant has no
    // #[compatible(Minion)] attribute; confirm whether that is intentional.
    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
    /// product of cs and xs.
    ///
    /// Low-level Minion constraint.
    ///
    /// Represents a weighted sum of the form `ax + by + cz + ...`
    ///
    /// # See also
    ///
    /// + [Minion
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    // NOTE(review): unlike the neighbouring Minion constraints, this variant has no
    // #[compatible(Minion)] attribute; confirm whether that is intentional.
    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
    /// product of cs and xs.
    ///
    /// Low-level Minion constraint.
    ///
    /// Represents a weighted sum of the form `ax + by + cz + ...`
    ///
    /// # See also
    ///
    /// + [Minion
    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumgeq)
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Ensures that x =-y, where x and y are atoms.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
    #[compatible(Minion)]
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Ensures that x*y=z.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
    #[compatible(Minion)]
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Ensures that floor(x/y)=z. Always true when y=0.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
    #[compatible(Minion)]
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Ensures that x%y=z. Always true when y=0.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
    #[compatible(Minion)]
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    // NOTE(review): unlike the neighbouring Minion constraints, this variant has no
    // #[compatible(Minion)] attribute; confirm whether that is intentional.
    /// Ensures that `x**y = z`.
    ///
    /// Low-level Minion constraint.
    ///
    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is -1 if y
    /// is odd and z is 1 if y is even).
    ///
    /// # See also
    ///
    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
    /// variable.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
    #[compatible(Minion)]
    MinionReify(Metadata, Moo<Expression>, Atom),

    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
    #[compatible(Minion)]
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    /// `w-inintervalset(x, [a1,a2, b1,b2, … ])` ensures that the value of x belongs to one of the
    /// intervals {a1,…,a2}, {b1,…,b2} etc.
    ///
    /// The list of intervals must be given in numerical order.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inintervalset)
    #[compatible(Minion)]
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    /// `w-inset(x, [v1, v2, … ])` ensures that the value of `x` is one of the explicitly given values `v1`, `v2`, etc.
    ///
    /// This constraint enforces membership in a specific set of discrete values rather than intervals.
    ///
    /// The list of values must be given in numerical order.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#w-inset)
    #[compatible(Minion)]
    MinionWInSet(Metadata, Atom, Vec<i32>),

    /// `element_one(vec, i, e)` specifies that `vec[i] = e`. This implies that i is
    /// in the range `[1..len(vec)]`.
    ///
    /// Low-level Minion constraint.
    ///
    /// # See also
    ///
    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#element_one)
    #[compatible(Minion)]
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    /// Declaration of an auxiliary variable.
    ///
    /// As with Savile Row, we semantically distinguish this from `Eq`.
    #[compatible(Minion)]
    AuxDeclaration(
        Metadata,
        #[serde_as(as = "DeclarationPtrAsId")] DeclarationPtr,
        Moo<Expression>,
    ),
}
475
476// for the given matrix literal, return a bounded domain from the min to max of applying op to each
477// child expression.
478//
479// Op must be monotonic.
480//
481// Returns none if unbounded
482fn bounded_i32_domain_for_matrix_literal_monotonic(
483    e: &Expression,
484    op: fn(i32, i32) -> Option<i32>,
485) -> Option<Domain> {
486    // only care about the elements, not the indices
487    let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
488
489    // fold each element's domain into one using op.
490    //
491    // here, I assume that op is monotone. This means that the bounds of op([a1,a2],[b1,b2])  for
492    // the ranges [a1,a2], [b1,b2] will be
493    // [min(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2)),max(op(a1,b1),op(a2,b1),op(a1,b2),op(a2,b2))].
494    //
495    // We used to not assume this, and work out the bounds by applying op on the Cartesian product
496    // of A and B; however, this caused a combinatorial explosion and my computer to run out of
497    // memory (on the hakank_eprime_xkcd test)...
498    //
499    // For example, to find the bounds of the intervals [1,4], [1,5] combined using op, we used to do
500    //  [min(op(1,1), op(1,2),op(1,3),op(1,4),op(1,5),op(2,1)..
501    //
502    // +,-,/,* are all monotone, so this assumption should be fine for now...
503
504    let expr = exprs.pop()?;
505    let Some(Domain::Int(ranges)) = expr.domain_of() else {
506        return None;
507    };
508
509    let (mut current_min, mut current_max) = range_vec_bounds_i32(&ranges)?;
510
511    for expr in exprs {
512        let Some(Domain::Int(ranges)) = expr.domain_of() else {
513            return None;
514        };
515
516        let (min, max) = range_vec_bounds_i32(&ranges)?;
517
518        // all the possible new values for current_min / current_max
519        let minmax = op(min, current_max)?;
520        let minmin = op(min, current_min)?;
521        let maxmin = op(max, current_min)?;
522        let maxmax = op(max, current_max)?;
523        let vals = [minmax, minmin, maxmin, maxmax];
524
525        current_min = *vals
526            .iter()
527            .min()
528            .expect("vals iterator should not be empty, and should have a minimum.");
529        current_max = *vals
530            .iter()
531            .max()
532            .expect("vals iterator should not be empty, and should have a maximum.");
533    }
534
535    if current_min == current_max {
536        Some(Domain::Int(vec![Range::Single(current_min)]))
537    } else {
538        Some(Domain::Int(vec![Range::Bounded(current_min, current_max)]))
539    }
540}
541
542// Returns none if unbounded
543fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
544    let mut min = i32::MAX;
545    let mut max = i32::MIN;
546    for r in ranges {
547        match r {
548            Range::Single(i) => {
549                if *i < min {
550                    min = *i;
551                }
552                if *i > max {
553                    max = *i;
554                }
555            }
556            Range::Bounded(i, j) => {
557                if *i < min {
558                    min = *i;
559                }
560                if *j > max {
561                    max = *j;
562                }
563            }
564            Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
565        }
566    }
567    Some((min, max))
568}
569
570impl Expression {
571    /// Returns the possible values of the expression, recursing to leaf expressions
572    pub fn domain_of(&self) -> Option<Domain> {
573        let ret = match self {
574            Expression::Union(_, a, b) => Some(Domain::Set(
575                SetAttr::None,
576                Box::new(a.domain_of()?.union(&b.domain_of()?).ok()?),
577            )),
578            Expression::Intersect(_, a, b) => Some(Domain::Set(
579                SetAttr::None,
580                Box::new(a.domain_of()?.intersect(&b.domain_of()?).ok()?),
581            )),
582            Expression::In(_, _, _) => Some(Domain::Bool),
583            Expression::Supset(_, _, _) => Some(Domain::Bool),
584            Expression::SupsetEq(_, _, _) => Some(Domain::Bool),
585            Expression::Subset(_, _, _) => Some(Domain::Bool),
586            Expression::SubsetEq(_, _, _) => Some(Domain::Bool),
587            Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
588            Expression::DominanceRelation(_, _) => Some(Domain::Bool),
589            Expression::FromSolution(_, expr) => expr.domain_of(),
590            Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
591            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
592                match matrix.domain_of()? {
593                    Domain::Matrix(elem_domain, _) => Some(*elem_domain),
594                    Domain::Tuple(_) => None,
595                    Domain::Record(_) => None,
596                    _ => {
597                        bug!("subject of an index operation should support indexing")
598                    }
599                }
600            }
601            Expression::UnsafeSlice(_, matrix, indices)
602            | Expression::SafeSlice(_, matrix, indices) => {
603                let sliced_dimension = indices.iter().position(Option::is_none);
604
605                let Domain::Matrix(elem_domain, index_domains) = matrix.domain_of()? else {
606                    bug!("subject of an index operation should be a matrix");
607                };
608
609                match sliced_dimension {
610                    Some(dimension) => Some(Domain::Matrix(
611                        elem_domain,
612                        vec![index_domains[dimension].clone()],
613                    )),
614
615                    // same as index
616                    None => Some(*elem_domain),
617                }
618            }
619            Expression::InDomain(_, _, _) => Some(Domain::Bool),
620            Expression::Atomic(_, Atom::Reference(ptr)) => ptr.domain(),
621            Expression::Atomic(_, atom) => Some(atom.domain_of()),
622            Expression::Scope(_, _) => Some(Domain::Bool),
623            Expression::Sum(_, e) => {
624                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
625            }
626            Expression::Product(_, e) => {
627                bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
628            }
629            Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
630                Some(if x < y { x } else { y })
631            }),
632            Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
633                Some(if x > y { x } else { y })
634            }),
635            Expression::UnsafeDiv(_, a, b) => a
636                .domain_of()?
637                .apply_i32(
638                    // rust integer division is truncating; however, we want to always round down,
639                    // including for negative numbers.
640                    |x, y| {
641                        if y != 0 {
642                            Some((x as f32 / y as f32).floor() as i32)
643                        } else {
644                            None
645                        }
646                    },
647                    &b.domain_of()?,
648                )
649                .ok(),
650            Expression::SafeDiv(_, a, b) => {
651                // rust integer division is truncating; however, we want to always round down
652                // including for negative numbers.
653                let domain = a.domain_of()?.apply_i32(
654                    |x, y| {
655                        if y != 0 {
656                            Some((x as f32 / y as f32).floor() as i32)
657                        } else {
658                            None
659                        }
660                    },
661                    &b.domain_of()?,
662                );
663
664                match domain {
665                    Ok(Domain::Int(ranges)) => {
666                        let mut ranges = ranges;
667                        ranges.push(Range::Single(0));
668                        Some(Domain::Int(ranges))
669                    }
670                    Err(_) => todo!(),
671                    _ => unreachable!(),
672                }
673            }
674            Expression::UnsafeMod(_, a, b) => a
675                .domain_of()?
676                .apply_i32(
677                    |x, y| if y != 0 { Some(x % y) } else { None },
678                    &b.domain_of()?,
679                )
680                .ok(),
681            Expression::SafeMod(_, a, b) => {
682                let domain = a.domain_of()?.apply_i32(
683                    |x, y| if y != 0 { Some(x % y) } else { None },
684                    &b.domain_of()?,
685                );
686
687                match domain {
688                    Ok(Domain::Int(ranges)) => {
689                        let mut ranges = ranges;
690                        ranges.push(Range::Single(0));
691                        Some(Domain::Int(ranges))
692                    }
693                    Err(_) => todo!(),
694                    _ => unreachable!(),
695                }
696            }
697            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
698                .domain_of()?
699                .apply_i32(
700                    |x, y| {
701                        if (x != 0 || y != 0) && y >= 0 {
702                            Some(x.pow(y as u32))
703                        } else {
704                            None
705                        }
706                    },
707                    &b.domain_of()?,
708                )
709                .ok(),
710            Expression::Root(_, _) => None,
711            Expression::Bubble(_, inner, _) => inner.domain_of(),
712            Expression::AuxDeclaration(_, _, _) => Some(Domain::Bool),
713            Expression::And(_, _) => Some(Domain::Bool),
714            Expression::Not(_, _) => Some(Domain::Bool),
715            Expression::Or(_, _) => Some(Domain::Bool),
716            Expression::Imply(_, _, _) => Some(Domain::Bool),
717            Expression::Iff(_, _, _) => Some(Domain::Bool),
718            Expression::Eq(_, _, _) => Some(Domain::Bool),
719            Expression::Neq(_, _, _) => Some(Domain::Bool),
720            Expression::Geq(_, _, _) => Some(Domain::Bool),
721            Expression::Leq(_, _, _) => Some(Domain::Bool),
722            Expression::Gt(_, _, _) => Some(Domain::Bool),
723            Expression::Lt(_, _, _) => Some(Domain::Bool),
724            Expression::FlatAbsEq(_, _, _) => Some(Domain::Bool),
725            Expression::FlatSumGeq(_, _, _) => Some(Domain::Bool),
726            Expression::FlatSumLeq(_, _, _) => Some(Domain::Bool),
727            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::Bool),
728            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::Bool),
729            Expression::FlatIneq(_, _, _, _) => Some(Domain::Bool),
730            Expression::AllDiff(_, _) => Some(Domain::Bool),
731            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::Bool),
732            Expression::MinionReify(_, _, _) => Some(Domain::Bool),
733            Expression::MinionReifyImply(_, _, _) => Some(Domain::Bool),
734            Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::Bool),
735            Expression::MinionWInSet(_, _, _) => Some(Domain::Bool),
736            Expression::MinionElementOne(_, _, _, _) => Some(Domain::Bool),
737            Expression::Neg(_, x) => {
738                let Some(Domain::Int(mut ranges)) = x.domain_of() else {
739                    return None;
740                };
741
742                for range in ranges.iter_mut() {
743                    *range = match range {
744                        Range::Single(x) => Range::Single(-*x),
745                        Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
746                        Range::UnboundedR(i) => Range::UnboundedL(-*i),
747                        Range::UnboundedL(i) => Range::UnboundedR(-*i),
748                    };
749                }
750
751                Some(Domain::Int(ranges))
752            }
753            Expression::Minus(_, a, b) => a
754                .domain_of()?
755                .apply_i32(|x, y| Some(x - y), &b.domain_of()?)
756                .ok(),
757            Expression::FlatAllDiff(_, _) => Some(Domain::Bool),
758            Expression::FlatMinusEq(_, _, _) => Some(Domain::Bool),
759            Expression::FlatProductEq(_, _, _, _) => Some(Domain::Bool),
760            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::Bool),
761            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::Bool),
762            Expression::Abs(_, a) => a
763                .domain_of()?
764                .apply_i32(|a, _| Some(a.abs()), &a.domain_of()?)
765                .ok(),
766            Expression::MinionPow(_, _, _, _) => Some(Domain::Bool),
767            Expression::ToInt(_, _) => Some(Domain::Int(vec![Range::Bounded(0, 1)])),
768        };
769        match ret {
770            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
771            // Once they support a full domain as we define it, we can remove this conversion
772            Some(Domain::Int(ranges)) if ranges.len() > 1 => {
773                let (min, max) = range_vec_bounds_i32(&ranges)?;
774                Some(Domain::Int(vec![Range::Bounded(min, max)]))
775            }
776            _ => ret,
777        }
778    }
779
780    pub fn get_meta(&self) -> Metadata {
781        let metas: VecDeque<Metadata> = self.children_bi();
782        metas[0].clone()
783    }
784
785    pub fn set_meta(&self, meta: Metadata) {
786        self.transform_bi(&|_| meta.clone());
787    }
788
789    /// Checks whether this expression is safe.
790    ///
791    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
792    ///
793    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
794    /// safe through the use of bubble rules.
795    pub fn is_safe(&self) -> bool {
796        // TODO: memoise in Metadata
797        for expr in self.universe() {
798            match expr {
799                Expression::UnsafeDiv(_, _, _)
800                | Expression::UnsafeMod(_, _, _)
801                | Expression::UnsafePow(_, _, _)
802                | Expression::UnsafeIndex(_, _, _)
803                | Expression::Bubble(_, _, _)
804                | Expression::UnsafeSlice(_, _, _) => {
805                    return false;
806                }
807                _ => {}
808            }
809        }
810        true
811    }
812
813    pub fn is_clean(&self) -> bool {
814        let metadata = self.get_meta();
815        metadata.clean
816    }
817
818    pub fn set_clean(&mut self, bool_value: bool) {
819        let mut metadata = self.get_meta();
820        metadata.clean = bool_value;
821        self.set_meta(metadata);
822    }
823
824    /// True if the expression is an associative and commutative operator
825    pub fn is_associative_commutative_operator(&self) -> bool {
826        TryInto::<ACOperatorKind>::try_into(self).is_ok()
827    }
828
829    /// True if the expression is a matrix literal.
830    ///
831    /// This is true for both forms of matrix literals: those with elements of type [`Literal`] and
832    /// [`Expression`].
833    pub fn is_matrix_literal(&self) -> bool {
834        matches!(
835            self,
836            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
837                | Expression::Atomic(
838                    _,
839                    Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
840                )
841        )
842    }
843
844    /// True iff self and other are both atomic and identical.
845    ///
846    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
847    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
848    /// entire subtrees of `self` and `other`.
849    pub fn identical_atom_to(&self, other: &Expression) -> bool {
850        let atom1: Result<&Atom, _> = self.try_into();
851        let atom2: Result<&Atom, _> = other.try_into();
852
853        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
854            atom2 == atom1
855        } else {
856            false
857        }
858    }
859
860    /// If the expression is a list, returns the inner expressions.
861    ///
862    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
863    /// any explicitly specified domain.
864    pub fn unwrap_list(self) -> Option<Vec<Expression>> {
865        match self {
866            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
867                matrix.unwrap_list().cloned()
868            }
869            Expression::Atomic(
870                _,
871                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
872            ) => matrix.unwrap_list().map(|elems| {
873                elems
874                    .clone()
875                    .into_iter()
876                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
877                    .collect_vec()
878            }),
879            _ => None,
880        }
881    }
882
883    /// If the expression is a matrix, gets it elements and index domain.
884    ///
885    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
886    ///
887    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
888    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
889    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
890    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
891        match self {
892            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
893                Some((elems.clone(), *domain))
894            }
895            Expression::Atomic(
896                _,
897                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
898            ) => Some((
899                elems
900                    .clone()
901                    .into_iter()
902                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
903                    .collect_vec(),
904                *domain,
905            )),
906
907            _ => None,
908        }
909    }
910
911    /// For a Root expression, extends the inner vec with the given vec.
912    ///
913    /// # Panics
914    /// Panics if the expression is not Root.
915    pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
916        match self {
917            Expression::Root(meta, mut children) => {
918                children.extend(exprs);
919                Expression::Root(meta, children)
920            }
921            _ => panic!("extend_root called on a non-Root expression"),
922        }
923    }
924
925    /// Converts the expression to a literal, if possible.
926    pub fn into_literal(self) -> Option<Literal> {
927        match self {
928            Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
929            Expression::AbstractLiteral(_, abslit) => {
930                Some(Literal::AbstractLiteral(abslit.clone().into_literals()?))
931            }
932            Expression::Neg(_, e) => {
933                let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
934                    bug!("negated literal should be an int");
935                };
936
937                Some(Literal::Int(-i))
938            }
939
940            _ => None,
941        }
942    }
943
944    /// If this expression is an associative-commutative operator, return its [ACOperatorKind].
945    pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
946        TryFrom::try_from(self).ok()
947    }
948
949    /// Returns the categories of all sub-expressions of self.
950    pub fn universe_categories(&self) -> HashSet<Category> {
951        self.universe()
952            .into_iter()
953            .map(|x| x.category_of())
954            .collect()
955    }
956}
957
958impl TryFrom<&Expression> for i32 {
959    type Error = ();
960
961    fn try_from(value: &Expression) -> Result<Self, Self::Error> {
962        let Expression::Atomic(_, atom) = value else {
963            return Err(());
964        };
965
966        let Atom::Literal(lit) = atom else {
967            return Err(());
968        };
969
970        let Literal::Int(i) = lit else {
971            return Err(());
972        };
973
974        Ok(*i)
975    }
976}
977
978impl TryFrom<Expression> for i32 {
979    type Error = ();
980
981    fn try_from(value: Expression) -> Result<Self, Self::Error> {
982        TryFrom::<&Expression>::try_from(&value)
983    }
984}
985impl From<i32> for Expression {
986    fn from(i: i32) -> Self {
987        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
988    }
989}
990
991impl From<bool> for Expression {
992    fn from(b: bool) -> Self {
993        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
994    }
995}
996
997impl From<Atom> for Expression {
998    fn from(value: Atom) -> Self {
999        Expression::Atomic(Metadata::new(), value)
1000    }
1001}
1002
1003impl From<Moo<Expression>> for Expression {
1004    fn from(val: Moo<Expression>) -> Self {
1005        val.as_ref().clone()
1006    }
1007}
1008
1009impl CategoryOf for Expression {
1010    fn category_of(&self) -> Category {
1011        // take highest category of all the expressions children
1012        let category = self.cata(&move |x,children| {
1013
1014            if let Some(max_category) = children.iter().max() {
1015                // if this expression contains subexpressions, return the maximum category of the
1016                // subexpressions
1017                *max_category
1018            } else {
1019                // this expression has no children
1020                let mut max_category = Category::Bottom;
1021
1022                // calculate the category by looking at all atoms, submodels, comprehensions, and
1023                // declarationptrs inside this expression
1024
1025                // this should generically cover all leaf types we currently have in oxide.
1026
1027                // if x contains submodels (including comprehensions)
1028                if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
1029                    // assume that the category is decision
1030                    return Category::Decision;
1031                }
1032
1033                // if x contains atoms
1034                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1035                // and those atoms have a higher category than we already know about
1036                && max_atom_category > max_category{
1037                    // update category 
1038                    max_category = max_atom_category;
1039                }
1040
1041                // if x contains declarationPtrs
1042                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1043                // and those pointers have a higher category than we already know about
1044                && max_declaration_category > max_category{
1045                    // update category 
1046                    max_category = max_declaration_category;
1047                }
1048                max_category
1049
1050            }
1051        });
1052
1053        if cfg!(debug_assertions) {
1054            trace!(
1055                category= %category,
1056                expression= %self,
1057                "Called Expression::category_of()"
1058            );
1059        };
1060        category
1061    }
1062}
1063
1064impl Display for Expression {
1065    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1066        match &self {
1067            Expression::Union(_, box1, box2) => {
1068                write!(f, "({} union {})", box1.clone(), box2.clone())
1069            }
1070            Expression::In(_, e1, e2) => {
1071                write!(f, "{e1} in {e2}")
1072            }
1073            Expression::Intersect(_, box1, box2) => {
1074                write!(f, "({} intersect {})", box1.clone(), box2.clone())
1075            }
1076            Expression::Supset(_, box1, box2) => {
1077                write!(f, "({} supset {})", box1.clone(), box2.clone())
1078            }
1079            Expression::SupsetEq(_, box1, box2) => {
1080                write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1081            }
1082            Expression::Subset(_, box1, box2) => {
1083                write!(f, "({} subset {})", box1.clone(), box2.clone())
1084            }
1085            Expression::SubsetEq(_, box1, box2) => {
1086                write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1087            }
1088
1089            Expression::AbstractLiteral(_, l) => l.fmt(f),
1090            Expression::Comprehension(_, c) => c.fmt(f),
1091            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1092                write!(f, "{e1}{}", pretty_vec(e2))
1093            }
1094            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1095                let args = es
1096                    .iter()
1097                    .map(|x| match x {
1098                        Some(x) => format!("{x}"),
1099                        None => "..".into(),
1100                    })
1101                    .join(",");
1102
1103                write!(f, "{e1}[{args}]")
1104            }
1105            Expression::InDomain(_, e, domain) => {
1106                write!(f, "__inDomain({e},{domain})")
1107            }
1108            Expression::Root(_, exprs) => {
1109                write!(f, "{}", pretty_expressions_as_top_level(exprs))
1110            }
1111            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1112            Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1113            Expression::Atomic(_, atom) => atom.fmt(f),
1114            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1115            Expression::Abs(_, a) => write!(f, "|{a}|"),
1116            Expression::Sum(_, e) => {
1117                write!(f, "sum({e})")
1118            }
1119            Expression::Product(_, e) => {
1120                write!(f, "product({e})")
1121            }
1122            Expression::Min(_, e) => {
1123                write!(f, "min({e})")
1124            }
1125            Expression::Max(_, e) => {
1126                write!(f, "max({e})")
1127            }
1128            Expression::Not(_, expr_box) => {
1129                write!(f, "!({})", expr_box.clone())
1130            }
1131            Expression::Or(_, e) => {
1132                write!(f, "or({e})")
1133            }
1134            Expression::And(_, e) => {
1135                write!(f, "and({e})")
1136            }
1137            Expression::Imply(_, box1, box2) => {
1138                write!(f, "({box1}) -> ({box2})")
1139            }
1140            Expression::Iff(_, box1, box2) => {
1141                write!(f, "({box1}) <-> ({box2})")
1142            }
1143            Expression::Eq(_, box1, box2) => {
1144                write!(f, "({} = {})", box1.clone(), box2.clone())
1145            }
1146            Expression::Neq(_, box1, box2) => {
1147                write!(f, "({} != {})", box1.clone(), box2.clone())
1148            }
1149            Expression::Geq(_, box1, box2) => {
1150                write!(f, "({} >= {})", box1.clone(), box2.clone())
1151            }
1152            Expression::Leq(_, box1, box2) => {
1153                write!(f, "({} <= {})", box1.clone(), box2.clone())
1154            }
1155            Expression::Gt(_, box1, box2) => {
1156                write!(f, "({} > {})", box1.clone(), box2.clone())
1157            }
1158            Expression::Lt(_, box1, box2) => {
1159                write!(f, "({} < {})", box1.clone(), box2.clone())
1160            }
1161            Expression::FlatSumGeq(_, box1, box2) => {
1162                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1163            }
1164            Expression::FlatSumLeq(_, box1, box2) => {
1165                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1166            }
1167            Expression::FlatIneq(_, box1, box2, box3) => write!(
1168                f,
1169                "Ineq({}, {}, {})",
1170                box1.clone(),
1171                box2.clone(),
1172                box3.clone()
1173            ),
1174            Expression::AllDiff(_, e) => {
1175                write!(f, "allDiff({e})")
1176            }
1177            Expression::Bubble(_, box1, box2) => {
1178                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1179            }
1180            Expression::SafeDiv(_, box1, box2) => {
1181                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1182            }
1183            Expression::UnsafeDiv(_, box1, box2) => {
1184                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1185            }
1186            Expression::UnsafePow(_, box1, box2) => {
1187                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1188            }
1189            Expression::SafePow(_, box1, box2) => {
1190                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1191            }
1192            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1193                write!(
1194                    f,
1195                    "DivEq({}, {}, {})",
1196                    box1.clone(),
1197                    box2.clone(),
1198                    box3.clone()
1199                )
1200            }
1201            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1202                write!(
1203                    f,
1204                    "ModEq({}, {}, {})",
1205                    box1.clone(),
1206                    box2.clone(),
1207                    box3.clone()
1208                )
1209            }
1210            Expression::FlatWatchedLiteral(_, x, l) => {
1211                write!(f, "WatchedLiteral({x},{l})", x = &x.name() as &Name)
1212            }
1213            Expression::MinionReify(_, box1, box2) => {
1214                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1215            }
1216            Expression::MinionReifyImply(_, box1, box2) => {
1217                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1218            }
1219            Expression::MinionWInIntervalSet(_, atom, intervals) => {
1220                let intervals = intervals.iter().join(",");
1221                write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1222            }
1223            Expression::MinionWInSet(_, atom, values) => {
1224                let values = values.iter().join(",");
1225                write!(f, "__minion_w_inset({atom},{values})")
1226            }
1227            Expression::AuxDeclaration(_, decl, e) => {
1228                write!(f, "{} =aux {}", &decl.name() as &Name, e.clone())
1229            }
1230            Expression::UnsafeMod(_, a, b) => {
1231                write!(f, "{} % {}", a.clone(), b.clone())
1232            }
1233            Expression::SafeMod(_, a, b) => {
1234                write!(f, "SafeMod({},{})", a.clone(), b.clone())
1235            }
1236            Expression::Neg(_, a) => {
1237                write!(f, "-({})", a.clone())
1238            }
1239            Expression::Minus(_, a, b) => {
1240                write!(f, "({} - {})", a.clone(), b.clone())
1241            }
1242            Expression::FlatAllDiff(_, es) => {
1243                write!(f, "__flat_alldiff({})", pretty_vec(es))
1244            }
1245            Expression::FlatAbsEq(_, a, b) => {
1246                write!(f, "AbsEq({},{})", a.clone(), b.clone())
1247            }
1248            Expression::FlatMinusEq(_, a, b) => {
1249                write!(f, "MinusEq({},{})", a.clone(), b.clone())
1250            }
1251            Expression::FlatProductEq(_, a, b, c) => {
1252                write!(
1253                    f,
1254                    "FlatProductEq({},{},{})",
1255                    a.clone(),
1256                    b.clone(),
1257                    c.clone()
1258                )
1259            }
1260            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1261                write!(
1262                    f,
1263                    "FlatWeightedSumLeq({},{},{})",
1264                    pretty_vec(cs),
1265                    pretty_vec(vs),
1266                    total.clone()
1267                )
1268            }
1269            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1270                write!(
1271                    f,
1272                    "FlatWeightedSumGeq({},{},{})",
1273                    pretty_vec(cs),
1274                    pretty_vec(vs),
1275                    total.clone()
1276                )
1277            }
1278            Expression::MinionPow(_, atom, atom1, atom2) => {
1279                write!(f, "MinionPow({atom},{atom1},{atom2})")
1280            }
1281            Expression::MinionElementOne(_, atoms, atom, atom1) => {
1282                let atoms = atoms.iter().join(",");
1283                write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1284            }
1285
1286            Expression::ToInt(_, expr) => {
1287                write!(f, "toInt({expr})")
1288            }
1289        }
1290    }
1291}
1292
1293impl Typeable for Expression {
1294    fn return_type(&self) -> Option<ReturnType> {
1295        match self {
1296            Expression::Union(_, subject, _) => {
1297                Some(ReturnType::Set(Box::new(subject.return_type()?)))
1298            }
1299            Expression::Intersect(_, subject, _) => {
1300                Some(ReturnType::Set(Box::new(subject.return_type()?)))
1301            }
1302            Expression::In(_, _, _) => Some(ReturnType::Bool),
1303            Expression::Supset(_, _, _) => Some(ReturnType::Bool),
1304            Expression::SupsetEq(_, _, _) => Some(ReturnType::Bool),
1305            Expression::Subset(_, _, _) => Some(ReturnType::Bool),
1306            Expression::SubsetEq(_, _, _) => Some(ReturnType::Bool),
1307            Expression::AbstractLiteral(_, lit) => lit.return_type(),
1308            Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
1309                let mut elem_typ = subject.return_type()?;
1310                let ReturnType::Matrix(_) = elem_typ else {
1311                    return None;
1312                };
1313
1314                // unwrap the return types of n-d matrices to get to the real element typetype.
1315                while let ReturnType::Matrix(new_elem_typ) = elem_typ {
1316                    elem_typ = *new_elem_typ;
1317                }
1318
1319                Some(elem_typ)
1320            }
1321            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1322                Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
1323            }
1324            Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
1325            Expression::Comprehension(_, _) => None,
1326            Expression::Root(_, _) => Some(ReturnType::Bool),
1327            Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
1328            Expression::FromSolution(_, expr) => expr.return_type(),
1329            Expression::Atomic(_, atom) => atom.return_type(),
1330            Expression::Scope(_, scope) => scope.return_type(),
1331            Expression::Abs(_, _) => Some(ReturnType::Int),
1332            Expression::Sum(_, _) => Some(ReturnType::Int),
1333            Expression::Product(_, _) => Some(ReturnType::Int),
1334            Expression::Min(_, _) => Some(ReturnType::Int),
1335            Expression::Max(_, _) => Some(ReturnType::Int),
1336            Expression::Not(_, _) => Some(ReturnType::Bool),
1337            Expression::Or(_, _) => Some(ReturnType::Bool),
1338            Expression::Imply(_, _, _) => Some(ReturnType::Bool),
1339            Expression::Iff(_, _, _) => Some(ReturnType::Bool),
1340            Expression::And(_, _) => Some(ReturnType::Bool),
1341            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
1342            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
1343            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
1344            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
1345            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
1346            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
1347            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
1348            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
1349            Expression::FlatAllDiff(_, _) => Some(ReturnType::Bool),
1350            Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
1351            Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
1352            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1353            Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
1354            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
1355            Expression::Bubble(_, inner, _) => inner.return_type(),
1356            Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
1357            Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
1358            Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
1359            Expression::MinionWInIntervalSet(_, _, _) => Some(ReturnType::Bool),
1360            Expression::MinionWInSet(_, _, _) => Some(ReturnType::Bool),
1361            Expression::MinionElementOne(_, _, _, _) => Some(ReturnType::Bool),
1362            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
1363            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
1364            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
1365            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1366            Expression::Neg(_, _) => Some(ReturnType::Int),
1367            Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
1368            Expression::SafePow(_, _, _) => Some(ReturnType::Int),
1369            Expression::Minus(_, _, _) => Some(ReturnType::Int),
1370            Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
1371            Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
1372            Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
1373            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
1374            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
1375            Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
1376            Expression::ToInt(_, _) => Some(ReturnType::Int),
1377        }
1378    }
1379}
1380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    /// Shorthand for an atomic integer-literal expression with fresh metadata.
    fn int_lit(value: i32) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(value)))
    }

    #[test]
    fn test_domain_of_constant_sum() {
        // sum([1, 2]) can only ever take the value 3.
        let sum = Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_lit(1), int_lit(2)]),
        );
        assert_eq!(sum.domain_of(), Some(Domain::Int(vec![Range::Single(3)])));
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        // Summing an int with a bool is expected to yield no domain.
        let bool_lit = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_lit(1), bool_lit]),
        );
        assert_eq!(sum.domain_of(), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        // An empty sum is expected to yield no domain.
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));
        assert_eq!(sum.domain_of(), None);
    }
}