// conjure_core/ast/expressions.rs

1use std::collections::VecDeque;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7
8use crate::ast::literals::AbstractLiteral;
9use crate::ast::literals::Literal;
10use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
11use crate::ast::symbol_table::SymbolTable;
12use crate::ast::Atom;
13use crate::ast::Name;
14use crate::ast::ReturnType;
15use crate::bug;
16use crate::metadata::Metadata;
17use enum_compatability_macro::document_compatibility;
18use uniplate::derive::Uniplate;
19use uniplate::{Biplate, Uniplate as _};
20
21use super::{Domain, Range, SubModel, Typeable};
22
23/// Represents different types of expressions used to define rules and constraints in the model.
24///
25/// The `Expression` enum includes operations, constants, and variable references
26/// used to build rules and conditions for the model.
27#[document_compatibility]
28#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
29#[uniplate(walk_into=[Atom,SubModel,AbstractLiteral<Expression>])]
30#[biplate(to=Metadata)]
31#[biplate(to=Atom)]
32#[biplate(to=Name)]
33#[biplate(to=Vec<Expression>)]
34#[biplate(to=Option<Expression>)]
35#[biplate(to=SubModel)]
36#[biplate(to=AbstractLiteral<Expression>)]
37#[biplate(to=AbstractLiteral<Literal>,walk_into=[Atom])]
38#[biplate(to=Literal,walk_into=[Atom])]
39pub enum Expression {
40    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
41    /// The top of the model
42    Root(Metadata, Vec<Expression>),
43
44    /// An expression representing "A is valid as long as B is true"
45    /// Turns into a conjunction when it reaches a boolean context
46    Bubble(Metadata, Box<Expression>, Box<Expression>),
47
48    /// Defines dominance ("Solution A is preferred over Solution B")
49    DominanceRelation(Metadata, Box<Expression>),
50    /// `fromSolution(name)` - Used in dominance relation definitions
51    FromSolution(Metadata, Box<Expression>),
52
53    Atomic(Metadata, Atom),
54
55    /// A matrix index.
56    ///
57    /// Defined iff the indices are within their respective index domains.
58    #[compatible(JsonInput)]
59    UnsafeIndex(Metadata, Box<Expression>, Vec<Expression>),
60
61    /// A safe matrix index.
62    ///
63    /// See [`Expression::UnsafeIndex`]
64    SafeIndex(Metadata, Box<Expression>, Vec<Expression>),
65
66    /// A matrix slice: `a[indices]`.
67    ///
68    /// One of the indicies may be `None`, representing the dimension of the matrix we want to take
69    /// a slice of. For example, for some 3d matrix a, `a[1,..,2]` has the indices
70    /// `Some(1),None,Some(2)`.
71    ///
72    /// It is assumed that the slice only has one "wild-card" dimension and thus is 1 dimensional.
73    ///
74    /// Defined iff the defined indices are within their respective index domains.
75    #[compatible(JsonInput)]
76    UnsafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
77
78    /// A safe matrix slice: `a[indices]`.
79    ///
80    /// See [`Expression::UnsafeSlice`].
81    SafeSlice(Metadata, Box<Expression>, Vec<Option<Expression>>),
82
83    /// `inDomain(x,domain)` iff `x` is in the domain `domain`.
84    ///
85    /// This cannot be constructed from Essence input, nor passed to a solver: this expression is
86    /// mainly used during the conversion of `UnsafeIndex` and `UnsafeSlice` to `SafeIndex` and
87    /// `SafeSlice` respectively.
88    InDomain(Metadata, Box<Expression>, Domain),
89
90    Scope(Metadata, Box<SubModel>),
91
92    /// `|x|` - absolute value of `x`
93    #[compatible(JsonInput)]
94    Abs(Metadata, Box<Expression>),
95
96    /// `a + b + c + ...`
97    #[compatible(JsonInput)]
98    Sum(Metadata, Vec<Expression>),
99
100    /// `a * b * c * ...`
101    #[compatible(JsonInput)]
102    Product(Metadata, Vec<Expression>),
103
104    /// `min(<vec_expr>)`
105    #[compatible(JsonInput)]
106    Min(Metadata, Box<Expression>),
107
108    /// `max(<vec_expr>)`
109    #[compatible(JsonInput)]
110    Max(Metadata, Box<Expression>),
111
112    /// `not(a)`
113    #[compatible(JsonInput, SAT)]
114    Not(Metadata, Box<Expression>),
115
116    /// `or(<vec_expr>)`
117    #[compatible(JsonInput, SAT)]
118    Or(Metadata, Box<Expression>),
119
120    /// `and(<vec_expr>)`
121    #[compatible(JsonInput, SAT)]
122    And(Metadata, Box<Expression>),
123
124    /// Ensures that `a->b` (material implication).
125    #[compatible(JsonInput)]
126    Imply(Metadata, Box<Expression>, Box<Expression>),
127
128    #[compatible(JsonInput)]
129    Eq(Metadata, Box<Expression>, Box<Expression>),
130
131    #[compatible(JsonInput)]
132    Neq(Metadata, Box<Expression>, Box<Expression>),
133
134    #[compatible(JsonInput)]
135    Geq(Metadata, Box<Expression>, Box<Expression>),
136
137    #[compatible(JsonInput)]
138    Leq(Metadata, Box<Expression>, Box<Expression>),
139
140    #[compatible(JsonInput)]
141    Gt(Metadata, Box<Expression>, Box<Expression>),
142
143    #[compatible(JsonInput)]
144    Lt(Metadata, Box<Expression>, Box<Expression>),
145
146    /// Division after preventing division by zero, usually with a bubble
147    SafeDiv(Metadata, Box<Expression>, Box<Expression>),
148
149    /// Division with a possibly undefined value (division by 0)
150    #[compatible(JsonInput)]
151    UnsafeDiv(Metadata, Box<Expression>, Box<Expression>),
152
153    /// Modulo after preventing mod 0, usually with a bubble
154    SafeMod(Metadata, Box<Expression>, Box<Expression>),
155
156    /// Modulo with a possibly undefined value (mod 0)
157    #[compatible(JsonInput)]
158    UnsafeMod(Metadata, Box<Expression>, Box<Expression>),
159
160    /// Negation: `-x`
161    #[compatible(JsonInput)]
162    Neg(Metadata, Box<Expression>),
163
164    /// Unsafe power`x**y` (possibly undefined)
165    ///
166    /// Defined when (X!=0 \\/ Y!=0) /\ Y>=0
167    #[compatible(JsonInput)]
168    UnsafePow(Metadata, Box<Expression>, Box<Expression>),
169
170    /// `UnsafePow` after preventing undefinedness
171    SafePow(Metadata, Box<Expression>, Box<Expression>),
172
173    /// `allDiff(<vec_expr>)`
174    #[compatible(JsonInput)]
175    AllDiff(Metadata, Box<Expression>),
176
177    /// Binary subtraction operator
178    ///
179    /// This is a parser-level construct, and is immediately normalised to `Sum([a,-b])`.
180    #[compatible(JsonInput)]
181    Minus(Metadata, Box<Expression>, Box<Expression>),
182
183    /// Ensures that x=|y| i.e. x is the absolute value of y.
184    ///
185    /// Low-level Minion constraint.
186    ///
187    /// # See also
188    ///
189    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#abs)
190    #[compatible(Minion)]
191    FlatAbsEq(Metadata, Atom, Atom),
192
193    /// Ensures that sum(vec) >= x.
194    ///
195    /// Low-level Minion constraint.
196    ///
197    /// # See also
198    ///
199    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumgeq)
200    #[compatible(Minion)]
201    FlatSumGeq(Metadata, Vec<Atom>, Atom),
202
203    /// Ensures that sum(vec) <= x.
204    ///
205    /// Low-level Minion constraint.
206    ///
207    /// # See also
208    ///
209    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#sumleq)
210    #[compatible(Minion)]
211    FlatSumLeq(Metadata, Vec<Atom>, Atom),
212
213    /// `ineq(x,y,k)` ensures that x <= y + k.
214    ///
215    /// Low-level Minion constraint.
216    ///
217    /// # See also
218    ///
219    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#ineq)
220    #[compatible(Minion)]
221    FlatIneq(Metadata, Atom, Atom, Literal),
222
223    /// `w-literal(x,k)` ensures that x == k, where x is a variable and k a constant.
224    ///
225    /// Low-level Minion constraint.
226    ///
227    /// This is a low-level Minion constraint and you should probably use Eq instead. The main use
228    /// of w-literal is to convert boolean variables to constraints so that they can be used inside
229    /// watched-and and watched-or.
230    ///
231    /// # See also
232    ///
233    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
234    /// + `rules::minion::boolean_literal_to_wliteral`.
235    #[compatible(Minion)]
236    FlatWatchedLiteral(Metadata, Name, Literal),
237
238    /// `weightedsumleq(cs,xs,total)` ensures that cs.xs <= total, where cs.xs is the scalar dot
239    /// product of cs and xs.
240    ///
241    /// Low-level Minion constraint.
242    ///
243    /// Represents a weighted sum of the form `ax + by + cz + ...`
244    ///
245    /// # See also
246    ///
247    /// + [Minion
248    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
249    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
250
251    /// `weightedsumgeq(cs,xs,total)` ensures that cs.xs >= total, where cs.xs is the scalar dot
252    /// product of cs and xs.
253    ///
254    /// Low-level Minion constraint.
255    ///
256    /// Represents a weighted sum of the form `ax + by + cz + ...`
257    ///
258    /// # See also
259    ///
260    /// + [Minion
261    /// documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#weightedsumleq)
262    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Atom),
263
264    /// Ensures that x =-y, where x and y are atoms.
265    ///
266    /// Low-level Minion constraint.
267    ///
268    /// # See also
269    ///
270    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#minuseq)
271    #[compatible(Minion)]
272    FlatMinusEq(Metadata, Atom, Atom),
273
274    /// Ensures that x*y=z.
275    ///
276    /// Low-level Minion constraint.
277    ///
278    /// # See also
279    ///
280    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#product)
281    #[compatible(Minion)]
282    FlatProductEq(Metadata, Atom, Atom, Atom),
283
284    /// Ensures that floor(x/y)=z. Always true when y=0.
285    ///
286    /// Low-level Minion constraint.
287    ///
288    /// # See also
289    ///
290    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#div_undefzero)
291    #[compatible(Minion)]
292    MinionDivEqUndefZero(Metadata, Atom, Atom, Atom),
293
294    /// Ensures that x%y=z. Always true when y=0.
295    ///
296    /// Low-level Minion constraint.
297    ///
298    /// # See also
299    ///
300    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#mod_undefzero)
301    #[compatible(Minion)]
302    MinionModuloEqUndefZero(Metadata, Atom, Atom, Atom),
303
304    /// Ensures that `x**y = z`.
305    ///
306    /// Low-level Minion constraint.
307    ///
308    /// This constraint is false when `y<0` except for `1**y=1` and `(-1)**y=z` (where z is 1 if y
309    /// is odd and z is -1 if y is even).
310    ///
311    /// # See also
312    ///
313    /// + [Github comment about `pow` semantics](https://github.com/minion/minion/issues/40#issuecomment-2595914891)
314    /// + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#pow)
315    MinionPow(Metadata, Atom, Atom, Atom),
316
317    /// `reify(constraint,r)` ensures that r=1 iff `constraint` is satisfied, where r is a 0/1
318    /// variable.
319    ///
320    /// Low-level Minion constraint.
321    ///
322    /// # See also
323    ///
324    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reify)
325    #[compatible(Minion)]
326    MinionReify(Metadata, Box<Expression>, Atom),
327
328    /// `reifyimply(constraint,r)` ensures that `r->constraint`, where r is a 0/1 variable.
329    /// variable.
330    ///
331    /// Low-level Minion constraint.
332    ///
333    /// # See also
334    ///
335    ///  + [Minion documentation](https://minion-solver.readthedocs.io/en/stable/usage/constraints.html#reifyimply)
336    #[compatible(Minion)]
337    MinionReifyImply(Metadata, Box<Expression>, Atom),
338
339    /// Declaration of an auxiliary variable.
340    ///
341    /// As with Savile Row, we semantically distinguish this from `Eq`.
342    #[compatible(Minion)]
343    AuxDeclaration(Metadata, Name, Box<Expression>),
344}
345
346fn expr_vec_to_domain_i32(
347    exprs: &[Expression],
348    op: fn(i32, i32) -> Option<i32>,
349    vars: &SymbolTable,
350) -> Option<Domain> {
351    let domains: Vec<Option<_>> = exprs.iter().map(|e| e.domain_of(vars)).collect();
352    domains
353        .into_iter()
354        .reduce(|a, b| a.and_then(|x| b.and_then(|y| x.apply_i32(op, &y))))
355        .flatten()
356}
357fn expr_vec_lit_to_domain_i32(
358    e: &Expression,
359    op: fn(i32, i32) -> Option<i32>,
360    vars: &SymbolTable,
361) -> Option<Domain> {
362    let exprs = e.clone().unwrap_list()?;
363    expr_vec_to_domain_i32(&exprs, op, vars)
364}
365
366// Returns none if unbounded
367fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
368    let mut min = i32::MAX;
369    let mut max = i32::MIN;
370    for r in ranges {
371        match r {
372            Range::Single(i) => {
373                if *i < min {
374                    min = *i;
375                }
376                if *i > max {
377                    max = *i;
378                }
379            }
380            Range::Bounded(i, j) => {
381                if *i < min {
382                    min = *i;
383                }
384                if *j > max {
385                    max = *j;
386                }
387            }
388            Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
389        }
390    }
391    Some((min, max))
392}
393
394impl Expression {
395    /// Returns the possible values of the expression, recursing to leaf expressions
396    pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
397        let ret = match self {
398            //todo
399            Expression::AbstractLiteral(_, _) => None,
400            Expression::DominanceRelation(_, _) => Some(Domain::BoolDomain),
401            Expression::FromSolution(_, expr) => expr.domain_of(syms),
402            Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
403                let Domain::DomainMatrix(elem_domain, _) = matrix.domain_of(syms)? else {
404                    bug!("subject of an index operation should be a matrix");
405                };
406
407                Some(*elem_domain)
408            }
409            Expression::UnsafeSlice(_, matrix, indices)
410            | Expression::SafeSlice(_, matrix, indices) => {
411                let sliced_dimension = indices.iter().position(Option::is_none);
412
413                let Domain::DomainMatrix(elem_domain, index_domains) = matrix.domain_of(syms)?
414                else {
415                    bug!("subject of an index operation should be a matrix");
416                };
417
418                match sliced_dimension {
419                    Some(dimension) => Some(Domain::DomainMatrix(
420                        elem_domain,
421                        vec![index_domains[dimension].clone()],
422                    )),
423
424                    // same as index
425                    None => Some(*elem_domain),
426                }
427            }
428            Expression::InDomain(_, _, _) => Some(Domain::BoolDomain),
429            Expression::Atomic(_, Atom::Reference(name)) => Some(syms.resolve_domain(name)?),
430            Expression::Atomic(_, Atom::Literal(Literal::Int(n))) => {
431                Some(Domain::IntDomain(vec![Range::Single(*n)]))
432            }
433            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(Domain::BoolDomain),
434            Expression::Atomic(_, Atom::Literal(Literal::AbstractLiteral(_))) => None,
435            Expression::Scope(_, _) => Some(Domain::BoolDomain),
436            Expression::Sum(_, exprs) => expr_vec_to_domain_i32(exprs, |x, y| Some(x + y), syms),
437            Expression::Product(_, exprs) => {
438                expr_vec_to_domain_i32(exprs, |x, y| Some(x * y), syms)
439            }
440            Expression::Min(_, e) => {
441                expr_vec_lit_to_domain_i32(e, |x, y| Some(if x < y { x } else { y }), syms)
442            }
443            Expression::Max(_, e) => {
444                expr_vec_lit_to_domain_i32(e, |x, y| Some(if x > y { x } else { y }), syms)
445            }
446            Expression::UnsafeDiv(_, a, b) => a.domain_of(syms)?.apply_i32(
447                // rust integer division is truncating; however, we want to always round down,
448                // including for negative numbers.
449                |x, y| {
450                    if y != 0 {
451                        Some((x as f32 / y as f32).floor() as i32)
452                    } else {
453                        None
454                    }
455                },
456                &b.domain_of(syms)?,
457            ),
458            Expression::SafeDiv(_, a, b) => {
459                // rust integer division is truncating; however, we want to always round down
460                // including for negative numbers.
461                let domain = a.domain_of(syms)?.apply_i32(
462                    |x, y| {
463                        if y != 0 {
464                            Some((x as f32 / y as f32).floor() as i32)
465                        } else {
466                            None
467                        }
468                    },
469                    &b.domain_of(syms)?,
470                );
471
472                match domain {
473                    Some(Domain::IntDomain(ranges)) => {
474                        let mut ranges = ranges;
475                        ranges.push(Range::Single(0));
476                        Some(Domain::IntDomain(ranges))
477                    }
478                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
479                    _ => None,
480                }
481            }
482            Expression::UnsafeMod(_, a, b) => a.domain_of(syms)?.apply_i32(
483                |x, y| if y != 0 { Some(x % y) } else { None },
484                &b.domain_of(syms)?,
485            ),
486
487            Expression::SafeMod(_, a, b) => {
488                let domain = a.domain_of(syms)?.apply_i32(
489                    |x, y| if y != 0 { Some(x % y) } else { None },
490                    &b.domain_of(syms)?,
491                );
492
493                match domain {
494                    Some(Domain::IntDomain(ranges)) => {
495                        let mut ranges = ranges;
496                        ranges.push(Range::Single(0));
497                        Some(Domain::IntDomain(ranges))
498                    }
499                    None => Some(Domain::IntDomain(vec![Range::Single(0)])),
500                    _ => None,
501                }
502            }
503
504            Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => {
505                a.domain_of(syms)?.apply_i32(
506                    |x, y| {
507                        if (x != 0 || y != 0) && y >= 0 {
508                            Some(x ^ y)
509                        } else {
510                            None
511                        }
512                    },
513                    &b.domain_of(syms)?,
514                )
515            }
516
517            Expression::Root(_, _) => None,
518            Expression::Bubble(_, _, _) => None,
519            Expression::AuxDeclaration(_, _, _) => Some(Domain::BoolDomain),
520            Expression::And(_, _) => Some(Domain::BoolDomain),
521            Expression::Not(_, _) => Some(Domain::BoolDomain),
522            Expression::Or(_, _) => Some(Domain::BoolDomain),
523            Expression::Imply(_, _, _) => Some(Domain::BoolDomain),
524            Expression::Eq(_, _, _) => Some(Domain::BoolDomain),
525            Expression::Neq(_, _, _) => Some(Domain::BoolDomain),
526            Expression::Geq(_, _, _) => Some(Domain::BoolDomain),
527            Expression::Leq(_, _, _) => Some(Domain::BoolDomain),
528            Expression::Gt(_, _, _) => Some(Domain::BoolDomain),
529            Expression::Lt(_, _, _) => Some(Domain::BoolDomain),
530            Expression::FlatAbsEq(_, _, _) => Some(Domain::BoolDomain),
531            Expression::FlatSumGeq(_, _, _) => Some(Domain::BoolDomain),
532            Expression::FlatSumLeq(_, _, _) => Some(Domain::BoolDomain),
533            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
534            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::BoolDomain),
535            Expression::FlatIneq(_, _, _, _) => Some(Domain::BoolDomain),
536            Expression::AllDiff(_, _) => Some(Domain::BoolDomain),
537            Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::BoolDomain),
538            Expression::MinionReify(_, _, _) => Some(Domain::BoolDomain),
539            Expression::MinionReifyImply(_, _, _) => Some(Domain::BoolDomain),
540            Expression::Neg(_, x) => {
541                let Some(Domain::IntDomain(mut ranges)) = x.domain_of(syms) else {
542                    return None;
543                };
544
545                for range in ranges.iter_mut() {
546                    *range = match range {
547                        Range::Single(x) => Range::Single(-*x),
548                        Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
549                        Range::UnboundedR(i) => Range::UnboundedL(-*i),
550                        Range::UnboundedL(i) => Range::UnboundedR(-*i),
551                    };
552                }
553
554                Some(Domain::IntDomain(ranges))
555            }
556            Expression::Minus(_, a, b) => a
557                .domain_of(syms)?
558                .apply_i32(|x, y| Some(x - y), &b.domain_of(syms)?),
559
560            Expression::FlatMinusEq(_, _, _) => Some(Domain::BoolDomain),
561            Expression::FlatProductEq(_, _, _, _) => Some(Domain::BoolDomain),
562            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::BoolDomain),
563            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::BoolDomain),
564            Expression::Abs(_, a) => a
565                .domain_of(syms)?
566                .apply_i32(|a, _| Some(a.abs()), &a.domain_of(syms)?),
567            Expression::MinionPow(_, _, _, _) => Some(Domain::BoolDomain),
568        };
569        match ret {
570            // TODO: (flm8) the Minion bindings currently only support single ranges for domains, so we use the min/max bounds
571            // Once they support a full domain as we define it, we can remove this conversion
572            Some(Domain::IntDomain(ranges)) if ranges.len() > 1 => {
573                let (min, max) = range_vec_bounds_i32(&ranges)?;
574                Some(Domain::IntDomain(vec![Range::Bounded(min, max)]))
575            }
576            _ => ret,
577        }
578    }
579
580    pub fn get_meta(&self) -> Metadata {
581        let metas: VecDeque<Metadata> = self.children_bi();
582        metas[0].clone()
583    }
584
585    pub fn set_meta(&self, meta: Metadata) {
586        self.transform_bi(Arc::new(move |_| meta.clone()));
587    }
588
589    /// Checks whether this expression is safe.
590    ///
591    /// An expression is unsafe if can be undefined, or if any of its children can be undefined.
592    ///
593    /// Unsafe expressions are (typically) prefixed with Unsafe in our AST, and can be made
594    /// safe through the use of bubble rules.
595    pub fn is_safe(&self) -> bool {
596        // TODO: memoise in Metadata
597        for expr in self.universe() {
598            match expr {
599                Expression::UnsafeDiv(_, _, _)
600                | Expression::UnsafeMod(_, _, _)
601                | Expression::UnsafePow(_, _, _)
602                | Expression::UnsafeIndex(_, _, _)
603                | Expression::UnsafeSlice(_, _, _) => {
604                    return false;
605                }
606                _ => {}
607            }
608        }
609        true
610    }
611
612    pub fn return_type(&self) -> Option<ReturnType> {
613        match self {
614            Expression::AbstractLiteral(_, _) => None,
615            Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
616                Some(subject.return_type()?)
617            }
618            Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
619                Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
620            }
621            Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
622            Expression::Root(_, _) => Some(ReturnType::Bool),
623            Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
624            Expression::FromSolution(_, expr) => expr.return_type(),
625            Expression::Atomic(_, Atom::Literal(Literal::Int(_))) => Some(ReturnType::Int),
626            Expression::Atomic(_, Atom::Literal(Literal::Bool(_))) => Some(ReturnType::Bool),
627            Expression::Atomic(_, Atom::Literal(Literal::AbstractLiteral(_))) => None,
628            Expression::Atomic(_, Atom::Reference(_)) => None,
629            Expression::Scope(_, scope) => scope.return_type(),
630            Expression::Abs(_, _) => Some(ReturnType::Int),
631            Expression::Sum(_, _) => Some(ReturnType::Int),
632            Expression::Product(_, _) => Some(ReturnType::Int),
633            Expression::Min(_, _) => Some(ReturnType::Int),
634            Expression::Max(_, _) => Some(ReturnType::Int),
635            Expression::Not(_, _) => Some(ReturnType::Bool),
636            Expression::Or(_, _) => Some(ReturnType::Bool),
637            Expression::Imply(_, _, _) => Some(ReturnType::Bool),
638            Expression::And(_, _) => Some(ReturnType::Bool),
639            Expression::Eq(_, _, _) => Some(ReturnType::Bool),
640            Expression::Neq(_, _, _) => Some(ReturnType::Bool),
641            Expression::Geq(_, _, _) => Some(ReturnType::Bool),
642            Expression::Leq(_, _, _) => Some(ReturnType::Bool),
643            Expression::Gt(_, _, _) => Some(ReturnType::Bool),
644            Expression::Lt(_, _, _) => Some(ReturnType::Bool),
645            Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
646            Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
647            Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
648            Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
649            Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
650            Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
651            Expression::AllDiff(_, _) => Some(ReturnType::Bool),
652            Expression::Bubble(_, _, _) => None, // TODO: (flm8) should this be a bool?
653            Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
654            Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
655            Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
656            Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
657            Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
658            Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
659            Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
660            Expression::Neg(_, _) => Some(ReturnType::Int),
661            Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
662            Expression::SafePow(_, _, _) => Some(ReturnType::Int),
663            Expression::Minus(_, _, _) => Some(ReturnType::Int),
664            Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
665            Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
666            Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
667            Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
668            Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
669            Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
670        }
671    }
672
673    pub fn is_clean(&self) -> bool {
674        let metadata = self.get_meta();
675        metadata.clean
676    }
677
678    pub fn set_clean(&mut self, bool_value: bool) {
679        let mut metadata = self.get_meta();
680        metadata.clean = bool_value;
681        self.set_meta(metadata);
682    }
683
684    /// True if the expression is an associative and commutative operator
685    pub fn is_associative_commutative_operator(&self) -> bool {
686        matches!(
687            self,
688            Expression::Sum(_, _)
689                | Expression::Or(_, _)
690                | Expression::And(_, _)
691                | Expression::Product(_, _)
692        )
693    }
694
695    /// True iff self and other are both atomic and identical.
696    ///
697    /// This method is useful to cheaply check equivalence. Assuming CSE is enabled, any unifiable
698    /// expressions will be rewritten to a common variable. This is much cheaper than checking the
699    /// entire subtrees of `self` and `other`.
700    pub fn identical_atom_to(&self, other: &Expression) -> bool {
701        let atom1: Result<&Atom, _> = self.try_into();
702        let atom2: Result<&Atom, _> = other.try_into();
703
704        if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
705            atom2 == atom1
706        } else {
707            false
708        }
709    }
710
711    /// If the expression is a list, returns the inner expressions.
712    ///
713    /// A list is any a matrix with the domain `int(1..)`. This includes matrix literals without
714    /// any explicitly specified domain.
715    pub fn unwrap_list(self) -> Option<Vec<Expression>> {
716        match self {
717            Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
718                matrix.unwrap_list().cloned()
719            }
720            Expression::Atomic(
721                _,
722                Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
723            ) => matrix.unwrap_list().map(|elems| {
724                elems
725                    .clone()
726                    .into_iter()
727                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
728                    .collect_vec()
729            }),
730            _ => None,
731        }
732    }
733
734    /// If the expression is a matrix, gets it elements and index domain.
735    ///
736    /// **Consider using the safer [`Expression::unwrap_list`] instead.**
737    ///
738    /// It is generally undefined to edit the length of a matrix unless it is a list (as defined by
739    /// [`Expression::unwrap_list`]). Users of this function should ensure that, if the matrix is
740    /// reconstructed, the index domain and the number of elements in the matrix remain the same.
741    pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
742        match self {
743            Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
744                Some((elems.clone(), domain))
745            }
746            Expression::Atomic(
747                _,
748                Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
749            ) => Some((
750                elems
751                    .clone()
752                    .into_iter()
753                    .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
754                    .collect_vec(),
755                domain,
756            )),
757
758            _ => None,
759        }
760    }
761}
762
763impl From<i32> for Expression {
764    fn from(i: i32) -> Self {
765        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
766    }
767}
768
769impl From<bool> for Expression {
770    fn from(b: bool) -> Self {
771        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
772    }
773}
774
775impl From<Atom> for Expression {
776    fn from(value: Atom) -> Self {
777        Expression::Atomic(Metadata::new(), value)
778    }
779}
780impl Display for Expression {
781    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
782        match &self {
783            Expression::AbstractLiteral(_, l) => l.fmt(f),
784            Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
785                write!(f, "{e1}{}", pretty_vec(e2))
786            }
787            Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
788                let args = es
789                    .iter()
790                    .map(|x| match x {
791                        Some(x) => format!("{}", x),
792                        None => "..".into(),
793                    })
794                    .join(",");
795
796                write!(f, "{e1}[{args}]")
797            }
798
799            Expression::InDomain(_, e, domain) => {
800                write!(f, "__inDomain({e},{domain})")
801            }
802            Expression::Root(_, exprs) => {
803                write!(f, "{}", pretty_expressions_as_top_level(exprs))
804            }
805            Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({})", expr),
806            Expression::FromSolution(_, expr) => write!(f, "FromSolution({})", expr),
807            Expression::Atomic(_, atom) => atom.fmt(f),
808            Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
809            Expression::Abs(_, a) => write!(f, "|{}|", a),
810            Expression::Sum(_, expressions) => {
811                write!(f, "Sum({})", pretty_vec(expressions))
812            }
813            Expression::Product(_, expressions) => {
814                write!(f, "Product({})", pretty_vec(expressions))
815            }
816            Expression::Min(_, e) => {
817                write!(f, "min({e})")
818            }
819            Expression::Max(_, e) => {
820                write!(f, "max({e})")
821            }
822            Expression::Not(_, expr_box) => {
823                write!(f, "Not({})", expr_box.clone())
824            }
825            Expression::Or(_, e) => {
826                write!(f, "or({e})")
827            }
828            Expression::And(_, e) => {
829                write!(f, "and({e})")
830            }
831            Expression::Imply(_, box1, box2) => {
832                write!(f, "({}) -> ({})", box1, box2)
833            }
834            Expression::Eq(_, box1, box2) => {
835                write!(f, "({} = {})", box1.clone(), box2.clone())
836            }
837            Expression::Neq(_, box1, box2) => {
838                write!(f, "({} != {})", box1.clone(), box2.clone())
839            }
840            Expression::Geq(_, box1, box2) => {
841                write!(f, "({} >= {})", box1.clone(), box2.clone())
842            }
843            Expression::Leq(_, box1, box2) => {
844                write!(f, "({} <= {})", box1.clone(), box2.clone())
845            }
846            Expression::Gt(_, box1, box2) => {
847                write!(f, "({} > {})", box1.clone(), box2.clone())
848            }
849            Expression::Lt(_, box1, box2) => {
850                write!(f, "({} < {})", box1.clone(), box2.clone())
851            }
852            Expression::FlatSumGeq(_, box1, box2) => {
853                write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
854            }
855            Expression::FlatSumLeq(_, box1, box2) => {
856                write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
857            }
858            Expression::FlatIneq(_, box1, box2, box3) => write!(
859                f,
860                "Ineq({}, {}, {})",
861                box1.clone(),
862                box2.clone(),
863                box3.clone()
864            ),
865            Expression::AllDiff(_, e) => {
866                write!(f, "allDiff({e})")
867            }
868            Expression::Bubble(_, box1, box2) => {
869                write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
870            }
871            Expression::SafeDiv(_, box1, box2) => {
872                write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
873            }
874            Expression::UnsafeDiv(_, box1, box2) => {
875                write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
876            }
877            Expression::UnsafePow(_, box1, box2) => {
878                write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
879            }
880            Expression::SafePow(_, box1, box2) => {
881                write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
882            }
883            Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
884                write!(
885                    f,
886                    "DivEq({}, {}, {})",
887                    box1.clone(),
888                    box2.clone(),
889                    box3.clone()
890                )
891            }
892            Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
893                write!(
894                    f,
895                    "ModEq({}, {}, {})",
896                    box1.clone(),
897                    box2.clone(),
898                    box3.clone()
899                )
900            }
901
902            Expression::FlatWatchedLiteral(_, x, l) => {
903                write!(f, "WatchedLiteral({},{})", x, l)
904            }
905            Expression::MinionReify(_, box1, box2) => {
906                write!(f, "Reify({}, {})", box1.clone(), box2.clone())
907            }
908            Expression::MinionReifyImply(_, box1, box2) => {
909                write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
910            }
911            Expression::AuxDeclaration(_, n, e) => {
912                write!(f, "{} =aux {}", n, e.clone())
913            }
914            Expression::UnsafeMod(_, a, b) => {
915                write!(f, "{} % {}", a.clone(), b.clone())
916            }
917            Expression::SafeMod(_, a, b) => {
918                write!(f, "SafeMod({},{})", a.clone(), b.clone())
919            }
920            Expression::Neg(_, a) => {
921                write!(f, "-({})", a.clone())
922            }
923            Expression::Minus(_, a, b) => {
924                write!(f, "({} - {})", a.clone(), b.clone())
925            }
926            Expression::FlatAbsEq(_, a, b) => {
927                write!(f, "AbsEq({},{})", a.clone(), b.clone())
928            }
929            Expression::FlatMinusEq(_, a, b) => {
930                write!(f, "MinusEq({},{})", a.clone(), b.clone())
931            }
932            Expression::FlatProductEq(_, a, b, c) => {
933                write!(
934                    f,
935                    "FlatProductEq({},{},{})",
936                    a.clone(),
937                    b.clone(),
938                    c.clone()
939                )
940            }
941            Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
942                write!(
943                    f,
944                    "FlatWeightedSumLeq({},{},{})",
945                    pretty_vec(cs),
946                    pretty_vec(vs),
947                    total.clone()
948                )
949            }
950
951            Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
952                write!(
953                    f,
954                    "FlatWeightedSumGeq({},{},{})",
955                    pretty_vec(cs),
956                    pretty_vec(vs),
957                    total.clone()
958                )
959            }
960            Expression::MinionPow(_, atom, atom1, atom2) => {
961                write!(f, "MinionPow({},{},{})", atom, atom1, atom2)
962            }
963        }
964    }
965}
966
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::ast::declaration::Declaration;

    use super::*;

    /// Builds a constant integer expression.
    fn int_literal(value: i32) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(value)))
    }

    /// Builds a reference to the variable named `MachineName(0)`.
    fn reference_to_var_0() -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Reference(Name::MachineName(0)))
    }

    /// Returns a fresh symbol table declaring `MachineName(0)` as a variable with `domain`.
    ///
    /// The insertion result is always unwrapped, so a setup failure panics here instead of
    /// surfacing later as a confusing `domain_of` mismatch.
    fn table_with_var_0(domain: Domain) -> SymbolTable {
        let mut vars = SymbolTable::new();
        vars.insert(Rc::new(Declaration::new_var(Name::MachineName(0), domain)))
            .unwrap();
        vars
    }

    #[test]
    fn test_domain_of_constant_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![int_literal(1), int_literal(2)]);
        assert_eq!(
            sum.domain_of(&SymbolTable::new()),
            Some(Domain::IntDomain(vec![Range::Single(3)]))
        );
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        let c1 = int_literal(1);
        let c2 = Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true)));
        let sum = Expression::Sum(Metadata::new(), vec![c1, c2]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_empty_sum() {
        let sum = Expression::Sum(Metadata::new(), vec![]);
        assert_eq!(sum.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference() {
        let reference = reference_to_var_0();
        let vars = table_with_var_0(Domain::IntDomain(vec![Range::Single(1)]));
        assert_eq!(
            reference.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(1)]))
        );
    }

    #[test]
    fn test_domain_of_reference_not_found() {
        let reference = reference_to_var_0();
        assert_eq!(reference.domain_of(&SymbolTable::new()), None);
    }

    #[test]
    fn test_domain_of_reference_sum_single() {
        let reference = reference_to_var_0();
        let vars = table_with_var_0(Domain::IntDomain(vec![Range::Single(1)]));
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Single(2)]))
        );
    }

    #[test]
    fn test_domain_of_reference_sum_bounded() {
        let reference = reference_to_var_0();
        // Previously the insertion result was silently discarded here; the helper unwraps it.
        let vars = table_with_var_0(Domain::IntDomain(vec![Range::Bounded(1, 2)]));
        let sum = Expression::Sum(Metadata::new(), vec![reference.clone(), reference]);
        assert_eq!(
            sum.domain_of(&vars),
            Some(Domain::IntDomain(vec![Range::Bounded(2, 4)]))
        );
    }
}