1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use tracing::trace;
4
5use crate::ast::ReturnType;
6use crate::ast::SetAttr;
7use crate::ast::SymbolTable;
8use crate::ast::literals::AbstractLiteral;
9use crate::ast::literals::Literal;
10use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
11use crate::ast::sat_encoding::SATIntEncoding;
12use crate::ast::{Atom, DomainPtr};
13use crate::ast::{GroundDomain, Metadata, UnresolvedDomain};
14use crate::ast::{IntVal, Moo};
15use crate::ast::{Name, matrix};
16use crate::bug;
17use conjure_cp_enum_compatibility_macro::document_compatibility;
18use itertools::Itertools;
19use serde::{Deserialize, Serialize};
20use ustr::Ustr;
21
22use polyquine::Quine;
23use uniplate::{Biplate, Uniplate};
24
25use super::abstract_comprehension::AbstractComprehension;
26use super::ac_operators::ACOperatorKind;
27use super::categories::{Category, CategoryOf};
28use super::comprehension::Comprehension;
29use super::domains::HasDomain as _;
30use super::records::RecordValue;
31use super::{DeclarationPtr, Domain, Range, Reference, SubModel, Typeable};
32
// Compile-time guard: the build fails if `Expression` does not have the same size as
// `[u8; 104]`, so accidental growth of the enum is caught immediately.
static_assertions::assert_eq_size!([u8; 104], Expression);
56
/// A node of the expression tree.
///
/// Every variant stores a [`Metadata`] value as its first field, followed by the variant's
/// operands. Sub-expressions are held either behind [`Moo`] or in a `Vec`.
///
/// The `#[compatible(...)]` attributes tag a variant with a list of targets (`JsonInput`,
/// `Minion`, `SAT`, `SMT`); they are presumably consumed by `#[document_compatibility]` —
/// confirm against that macro's documentation.
///
/// The "printed as" notes below describe the `Display` implementation of this type.
#[document_compatibility]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
#[biplate(to=AbstractComprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=Atom)]
#[biplate(to=Comprehension)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=DomainPtr)]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Name)]
#[biplate(to=Option<Expression>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Reference)]
#[biplate(to=SubModel)]
#[biplate(to=SymbolTable)]
#[biplate(to=Vec<Expression>)]
#[path_prefix(conjure_cp::ast)]
pub enum Expression {
    /// An abstract literal whose elements are themselves expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// A list of expressions, printed as top-level expressions. `domain_of` is `None`.
    Root(Metadata, Vec<Expression>),

    /// Printed as `{a @ b}`. Its domain is that of the first operand, and `is_safe` treats it
    /// as unsafe.
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    #[polyquine_skip]
    /// A [`Comprehension`]; domain and printing are delegated to it.
    Comprehension(Metadata, Moo<Comprehension>),

    /// An [`AbstractComprehension`]; domain and printing are delegated to it.
    #[polyquine_skip] AbstractComprehension(Metadata, Moo<AbstractComprehension>),

    /// Printed as `DominanceRelation(e)`; boolean-valued.
    DominanceRelation(Metadata, Moo<Expression>),
    /// Printed as `FromSolution(atom)`; its domain is the domain of the atom.
    FromSolution(Metadata, Moo<Atom>),

    /// A named metavariable, printed as `&name`. It has no domain.
    #[polyquine_with(arm = (_, name) => {
        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
        quote::quote! { #ident.clone().into() }
    })]
    Metavar(Metadata, Ustr),

    /// A single [`Atom`]; domain and printing are delegated to it.
    Atomic(Metadata, Atom),

    #[compatible(JsonInput)]
    /// Indexing of the first operand by a list of index expressions. Reported as unsafe by
    /// `is_safe`.
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    #[compatible(SMT)]
    /// Indexing of the first operand by a list of index expressions; not reported as unsafe
    /// by `is_safe`.
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    #[compatible(JsonInput)]
    /// Slicing of the first operand. A `None` index is printed as `..` and marks the sliced
    /// dimension. Reported as unsafe by `is_safe`.
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Slicing of the first operand (see `UnsafeSlice`); not reported as unsafe by `is_safe`.
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Printed as `__inDomain(e,domain)`; boolean-valued.
    InDomain(Metadata, Moo<Expression>, DomainPtr),

    #[compatible(SMT)]
    /// Integer-valued with domain `0..1`; presumably the integer value of a boolean operand —
    /// TODO confirm.
    ToInt(Metadata, Moo<Expression>),

    #[polyquine_skip]
    /// Wraps a [`SubModel`], printed inside braces; boolean-valued.
    Scope(Metadata, Moo<SubModel>),

    #[compatible(JsonInput, SMT)]
    /// Absolute value, printed as `|a|`.
    Abs(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `sum(e)`; the domain is derived when `e` is a matrix literal.
    Sum(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `product(e)`; the domain is derived when `e` is a matrix literal.
    Product(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `min(e)`; the domain is derived when `e` is a matrix literal.
    Min(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `max(e)`; the domain is derived when `e` is a matrix literal.
    Max(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// Printed as `!(e)`; boolean-valued.
    Not(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// Printed as `or(e)`; boolean-valued.
    Or(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// Printed as `and(e)`; boolean-valued.
    And(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a) -> (b)`; boolean-valued.
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a) <-> (b)`; boolean-valued.
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `(a union b)`; set-valued.
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `a in b`; boolean-valued.
    In(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `(a intersect b)`; set-valued.
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `(a supset b)`; boolean-valued.
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `(a supsetEq b)`; boolean-valued.
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `(a subset b)`; boolean-valued.
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `(a subsetEq b)`; boolean-valued.
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a = b)`; boolean-valued.
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a != b)`; boolean-valued.
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a >= b)`; boolean-valued.
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a <= b)`; boolean-valued.
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a > b)`; boolean-valued.
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Printed as `(a < b)`; boolean-valued.
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// Division, printed as `SafeDiv(a, b)`. `domain_of` always adds `0` to the resulting
    /// integer domain.
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Division, printed as `UnsafeDiv(a, b)`. Reported as unsafe by `is_safe`.
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// Modulo, printed as `SafeMod(a,b)`. `domain_of` always adds `0` to the resulting
    /// integer domain.
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Modulo, printed as `a % b`. Reported as unsafe by `is_safe`.
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Negation, printed as `-(a)`.
    Neg(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Operates on a function-typed operand; its domain is the function's domain component.
    Defined(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Operates on a function-typed operand; its domain is the function's codomain component.
    Range(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Exponentiation, printed as `UnsafePow(a, b)`. Reported as unsafe by `is_safe`.
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// Exponentiation, printed as `SafePow(a, b)`.
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// Printed as `flatten(n, m)` when the optional first operand is present, otherwise
    /// `flatten(m)`.
    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Printed as `allDiff(e)`; boolean-valued.
    AllDiff(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Subtraction, printed as `(a - b)`.
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `AbsEq(a,b)`; boolean-valued.
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `__flat_alldiff(...)`; boolean-valued.
    FlatAllDiff(Metadata, Vec<Atom>),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `SumGeq(xs, y)`; boolean-valued.
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `SumLeq(xs, y)`; boolean-valued.
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Flat constraint over two atoms and a literal, printed as `Ineq(a, b, c)`;
    /// boolean-valued.
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    #[compatible(Minion)]
    #[polyquine_skip]
    /// Flat constraint over a reference and a literal, printed as `WatchedLiteral(x,l)`;
    /// boolean-valued.
    FlatWatchedLiteral(Metadata, Reference, Literal),

    /// Flat constraint over a list of literals, a list of atoms and an atom; boolean-valued.
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Flat constraint over a list of literals, a list of atoms and an atom; boolean-valued.
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `MinusEq(a,b)`; boolean-valued.
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `FlatProductEq(a,b,c)`; boolean-valued.
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `DivEq(a, b, c)`; boolean-valued.
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat constraint over atoms, printed as `ModEq(a, b, c)`; boolean-valued.
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Flat constraint over three atoms; boolean-valued.
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Printed as `Reify(e, atom)`; boolean-valued.
    MinionReify(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// Printed as `ReifyImply(e, atom)`; boolean-valued.
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// Printed as `__minion_w_inintervalset(atom,[...])`; boolean-valued.
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Printed as `__minion_w_inset(atom,...)`; boolean-valued.
    MinionWInSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Constraint over a list of atoms and two further atoms; boolean-valued.
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    #[polyquine_skip]
    /// Printed as `reference =aux e`; boolean-valued.
    AuxDeclaration(Metadata, Reference, Moo<Expression>),

    #[compatible(SAT)]
    /// An expression tagged with a [`SATIntEncoding`] and a pair of `i32`s; `domain_of` reports
    /// the integer range bounded by that pair.
    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),

    #[compatible(SMT)]
    /// Binary sum; its domain is computed by adding the operand domains.
    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// Binary product; its domain is computed by multiplying the operand domains.
    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Takes a function-typed first operand; its domain is that function's codomain component.
    Image(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Takes a function-typed first operand; its domain is that function's codomain component.
    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Takes a function-typed first operand; its domain is that function's domain component.
    PreImage(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Binary operator; boolean-valued.
    Inverse(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Takes a function-typed first operand; the result is a function domain whose domain
    /// component is replaced by the domain of the second operand.
    Restrict(Metadata, Moo<Expression>, Moo<Expression>),

    /// Binary operator; boolean-valued. Presumably lexicographic `<` — TODO confirm.
    LexLt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Binary operator; boolean-valued. Presumably lexicographic `<=` — TODO confirm.
    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Binary operator; boolean-valued. Presumably lexicographic `>` — TODO confirm.
    LexGt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Binary operator; boolean-valued. Presumably lexicographic `>=` — TODO confirm.
    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Flat form of `LexLt` over two lists of atoms; boolean-valued.
    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),

    /// Flat form of `LexLeq` over two lists of atoms; boolean-valued.
    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
}
564
565fn bounded_i32_domain_for_matrix_literal_monotonic(
572 e: &Expression,
573 op: fn(i32, i32) -> Option<i32>,
574) -> Option<DomainPtr> {
575 let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
577
578 let expr = exprs.pop()?;
594 let dom = expr.domain_of()?;
595 let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
596 return None;
597 };
598
599 let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
600
601 for expr in exprs {
602 let dom = expr.domain_of()?;
603 let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
604 return None;
605 };
606
607 let (min, max) = range_vec_bounds_i32(ranges)?;
608
609 let minmax = op(min, current_max)?;
611 let minmin = op(min, current_min)?;
612 let maxmin = op(max, current_min)?;
613 let maxmax = op(max, current_max)?;
614 let vals = [minmax, minmin, maxmin, maxmax];
615
616 current_min = *vals
617 .iter()
618 .min()
619 .expect("vals iterator should not be empty, and should have a minimum.");
620 current_max = *vals
621 .iter()
622 .max()
623 .expect("vals iterator should not be empty, and should have a maximum.");
624 }
625
626 if current_min == current_max {
627 Some(Domain::int(vec![Range::Single(current_min)]))
628 } else {
629 Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
630 }
631}
632
633fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
635 let mut min = i32::MAX;
636 let mut max = i32::MIN;
637 for r in ranges {
638 match r {
639 Range::Single(i) => {
640 if *i < min {
641 min = *i;
642 }
643 if *i > max {
644 max = *i;
645 }
646 }
647 Range::Bounded(i, j) => {
648 if *i < min {
649 min = *i;
650 }
651 if *j > max {
652 max = *j;
653 }
654 }
655 Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
656 }
657 }
658 Some((min, max))
659}
660
661impl Expression {
662 pub fn domain_of(&self) -> Option<DomainPtr> {
664 let ret = match self {
666 Expression::Union(_, a, b) => Some(Domain::set(
667 SetAttr::<IntVal>::default(),
668 a.domain_of()?.union(&b.domain_of()?).ok()?,
669 )),
670 Expression::Intersect(_, a, b) => Some(Domain::set(
671 SetAttr::<IntVal>::default(),
672 a.domain_of()?.intersect(&b.domain_of()?).ok()?,
673 )),
674 Expression::In(_, _, _) => Some(Domain::bool()),
675 Expression::Supset(_, _, _) => Some(Domain::bool()),
676 Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
677 Expression::Subset(_, _, _) => Some(Domain::bool()),
678 Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
679 Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
680 Expression::DominanceRelation(_, _) => Some(Domain::bool()),
681 Expression::FromSolution(_, expr) => Some(expr.domain_of()),
682 Expression::Metavar(_, _) => None,
683 Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
684 Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
685 Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
686 let dom = matrix.domain_of()?;
687 if let Some((elem_domain, _)) = dom.as_matrix() {
688 return Some(elem_domain);
689 }
690
691 #[allow(clippy::redundant_pattern_matching)]
693 if let Some(_) = dom.as_tuple() {
694 return None;
696 }
697
698 #[allow(clippy::redundant_pattern_matching)]
700 if let Some(_) = dom.as_record() {
701 return None;
703 }
704
705 bug!("subject of an index operation should support indexing")
706 }
707 Expression::UnsafeSlice(_, matrix, indices)
708 | Expression::SafeSlice(_, matrix, indices) => {
709 let sliced_dimension = indices.iter().position(Option::is_none);
710
711 let dom = matrix.domain_of()?;
712 let Some((elem_domain, index_domains)) = dom.as_matrix() else {
713 bug!("subject of an index operation should be a matrix");
714 };
715
716 match sliced_dimension {
717 Some(dimension) => Some(Domain::matrix(
718 elem_domain,
719 vec![index_domains[dimension].clone()],
720 )),
721
722 None => Some(elem_domain),
724 }
725 }
726 Expression::InDomain(_, _, _) => Some(Domain::bool()),
727 Expression::Atomic(_, atom) => Some(atom.domain_of()),
728 Expression::Scope(_, _) => Some(Domain::bool()),
729 Expression::Sum(_, e) => {
730 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
731 }
732 Expression::Product(_, e) => {
733 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
734 }
735 Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
736 Some(if x < y { x } else { y })
737 }),
738 Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
739 Some(if x > y { x } else { y })
740 }),
741 Expression::UnsafeDiv(_, a, b) => a
742 .domain_of()?
743 .resolve()?
744 .apply_i32(
745 |x, y| {
748 if y != 0 {
749 Some((x as f32 / y as f32).floor() as i32)
750 } else {
751 None
752 }
753 },
754 b.domain_of()?.resolve()?.as_ref(),
755 )
756 .map(DomainPtr::from)
757 .ok(),
758 Expression::SafeDiv(_, a, b) => {
759 let domain = a
762 .domain_of()?
763 .resolve()?
764 .apply_i32(
765 |x, y| {
766 if y != 0 {
767 Some((x as f32 / y as f32).floor() as i32)
768 } else {
769 None
770 }
771 },
772 b.domain_of()?.resolve()?.as_ref(),
773 )
774 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
775
776 if let GroundDomain::Int(ranges) = domain {
777 let mut ranges = ranges;
778 ranges.push(Range::Single(0));
779 return Some(Domain::int(ranges));
780 } else {
781 bug!("Domain of {self} was not integer")
782 }
783 }
784 Expression::UnsafeMod(_, a, b) => a
785 .domain_of()?
786 .resolve()?
787 .apply_i32(
788 |x, y| if y != 0 { Some(x % y) } else { None },
789 b.domain_of()?.resolve()?.as_ref(),
790 )
791 .map(DomainPtr::from)
792 .ok(),
793 Expression::SafeMod(_, a, b) => {
794 let domain = a
795 .domain_of()?
796 .resolve()?
797 .apply_i32(
798 |x, y| if y != 0 { Some(x % y) } else { None },
799 b.domain_of()?.resolve()?.as_ref(),
800 )
801 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
802
803 if let GroundDomain::Int(ranges) = domain {
804 let mut ranges = ranges;
805 ranges.push(Range::Single(0));
806 return Some(Domain::int(ranges));
807 } else {
808 bug!("Domain of {self} was not integer")
809 }
810 }
811 Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
812 .domain_of()?
813 .resolve()?
814 .apply_i32(
815 |x, y| {
816 if (x != 0 || y != 0) && y >= 0 {
817 Some(x.pow(y as u32))
818 } else {
819 None
820 }
821 },
822 b.domain_of()?.resolve()?.as_ref(),
823 )
824 .map(DomainPtr::from)
825 .ok(),
826 Expression::Root(_, _) => None,
827 Expression::Bubble(_, inner, _) => inner.domain_of(),
828 Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
829 Expression::And(_, _) => Some(Domain::bool()),
830 Expression::Not(_, _) => Some(Domain::bool()),
831 Expression::Or(_, _) => Some(Domain::bool()),
832 Expression::Imply(_, _, _) => Some(Domain::bool()),
833 Expression::Iff(_, _, _) => Some(Domain::bool()),
834 Expression::Eq(_, _, _) => Some(Domain::bool()),
835 Expression::Neq(_, _, _) => Some(Domain::bool()),
836 Expression::Geq(_, _, _) => Some(Domain::bool()),
837 Expression::Leq(_, _, _) => Some(Domain::bool()),
838 Expression::Gt(_, _, _) => Some(Domain::bool()),
839 Expression::Lt(_, _, _) => Some(Domain::bool()),
840 Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
841 Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
842 Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
843 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
844 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
845 Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
846 Expression::Flatten(_, n, m) => {
847 if let Some(expr) = n {
848 if expr.return_type() == ReturnType::Int {
849 return None;
851 }
852 } else {
853 let dom = m.domain_of()?.resolve()?;
855 let (val_dom, idx_doms) = match dom.as_ref() {
856 GroundDomain::Matrix(val, idx) => (val, idx),
857 _ => return None,
858 };
859 let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
860
861 let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
862 return Some(Domain::matrix(
863 val_dom.clone().into(),
864 vec![new_index_domain],
865 ));
866 }
867 None
868 }
869 Expression::AllDiff(_, _) => Some(Domain::bool()),
870 Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
871 Expression::MinionReify(_, _, _) => Some(Domain::bool()),
872 Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
873 Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
874 Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
875 Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
876 Expression::Neg(_, x) => {
877 let dom = x.domain_of()?;
878 let mut ranges = dom.as_int()?;
879
880 ranges = ranges
881 .into_iter()
882 .map(|r| match r {
883 Range::Single(x) => Range::Single(-x),
884 Range::Bounded(x, y) => Range::Bounded(-y, -x),
885 Range::UnboundedR(i) => Range::UnboundedL(-i),
886 Range::UnboundedL(i) => Range::UnboundedR(-i),
887 Range::Unbounded => Range::Unbounded,
888 })
889 .collect();
890
891 Some(Domain::int(ranges))
892 }
893 Expression::Minus(_, a, b) => a
894 .domain_of()?
895 .resolve()?
896 .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
897 .map(DomainPtr::from)
898 .ok(),
899 Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
900 Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
901 Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
902 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
903 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
904 Expression::Abs(_, a) => a
905 .domain_of()?
906 .resolve()?
907 .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
908 .map(DomainPtr::from)
909 .ok(),
910 Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
911 Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
912 Expression::SATInt(_, _, _, (low, high)) => {
913 Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
914 }
915 Expression::PairwiseSum(_, a, b) => a
916 .domain_of()?
917 .resolve()?
918 .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
919 .map(DomainPtr::from)
920 .ok(),
921 Expression::PairwiseProduct(_, a, b) => a
922 .domain_of()?
923 .resolve()?
924 .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
925 .map(DomainPtr::from)
926 .ok(),
927 Expression::Defined(_, function) => get_function_domain(function),
928 Expression::Range(_, function) => get_function_codomain(function),
929 Expression::Image(_, function, _) => get_function_codomain(function),
930 Expression::ImageSet(_, function, _) => get_function_codomain(function),
931 Expression::PreImage(_, function, _) => get_function_domain(function),
932 Expression::Restrict(_, function, new_domain) => {
933 let (attrs, _, codom) = function.domain_of()?.as_function()?;
934 let new_dom = new_domain.domain_of()?;
935 Some(Domain::function(attrs, new_dom, codom))
936 }
937 Expression::Inverse(..) => Some(Domain::bool()),
938 Expression::LexLt(..) => Some(Domain::bool()),
939 Expression::LexLeq(..) => Some(Domain::bool()),
940 Expression::LexGt(..) => Some(Domain::bool()),
941 Expression::LexGeq(..) => Some(Domain::bool()),
942 Expression::FlatLexLt(..) => Some(Domain::bool()),
943 Expression::FlatLexLeq(..) => Some(Domain::bool()),
944 };
945 if let Some(dom) = &ret
946 && let Some(ranges) = dom.as_int_ground()
947 && ranges.len() > 1
948 {
949 let (min, max) = range_vec_bounds_i32(ranges)?;
952 return Some(Domain::int(vec![Range::Bounded(min, max)]));
953 }
954 ret
955 }
956
957 pub fn get_meta(&self) -> Metadata {
958 let metas: VecDeque<Metadata> = self.children_bi();
959 metas[0].clone()
960 }
961
    /// Runs `transform_bi` over this expression with a function that replaces every visited
    /// `Metadata` by a clone of `meta`.
    ///
    /// NOTE(review): the value returned by `transform_bi` is discarded and the receiver is
    /// `&self`, so this call does not appear to modify `self` at all — confirm against the
    /// `uniplate` documentation. Fixing it would require a `&mut self` receiver (or returning
    /// the new expression), which changes the public signature.
    pub fn set_meta(&self, meta: Metadata) {
        self.transform_bi(&|_| meta.clone());
    }
965
966 pub fn is_safe(&self) -> bool {
973 for expr in self.universe() {
975 match expr {
976 Expression::UnsafeDiv(_, _, _)
977 | Expression::UnsafeMod(_, _, _)
978 | Expression::UnsafePow(_, _, _)
979 | Expression::UnsafeIndex(_, _, _)
980 | Expression::Bubble(_, _, _)
981 | Expression::UnsafeSlice(_, _, _) => {
982 return false;
983 }
984 _ => {}
985 }
986 }
987 true
988 }
989
990 pub fn is_clean(&self) -> bool {
991 let metadata = self.get_meta();
992 metadata.clean
993 }
994
995 pub fn set_clean(&mut self, bool_value: bool) {
996 let mut metadata = self.get_meta();
997 metadata.clean = bool_value;
998 self.set_meta(metadata);
999 }
1000
1001 pub fn is_associative_commutative_operator(&self) -> bool {
1003 TryInto::<ACOperatorKind>::try_into(self).is_ok()
1004 }
1005
1006 pub fn is_matrix_literal(&self) -> bool {
1011 matches!(
1012 self,
1013 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1014 | Expression::Atomic(
1015 _,
1016 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1017 )
1018 )
1019 }
1020
1021 pub fn identical_atom_to(&self, other: &Expression) -> bool {
1027 let atom1: Result<&Atom, _> = self.try_into();
1028 let atom2: Result<&Atom, _> = other.try_into();
1029
1030 if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1031 atom2 == atom1
1032 } else {
1033 false
1034 }
1035 }
1036
1037 pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
1042 match self {
1043 Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1044 matrix.unwrap_list().cloned()
1045 }
1046 Expression::Atomic(
1047 _,
1048 Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1049 ) => matrix.unwrap_list().map(|elems| {
1050 elems
1051 .clone()
1052 .into_iter()
1053 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1054 .collect_vec()
1055 }),
1056 _ => None,
1057 }
1058 }
1059
1060 pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1068 match self {
1069 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1070 Some((elems, domain))
1071 }
1072 Expression::Atomic(
1073 _,
1074 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1075 ) => Some((
1076 elems
1077 .into_iter()
1078 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1079 .collect_vec(),
1080 domain.into(),
1081 )),
1082
1083 _ => None,
1084 }
1085 }
1086
1087 pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1092 match self {
1093 Expression::Root(meta, mut children) => {
1094 children.extend(exprs);
1095 Expression::Root(meta, children)
1096 }
1097 _ => panic!("extend_root called on a non-Root expression"),
1098 }
1099 }
1100
1101 pub fn into_literal(self) -> Option<Literal> {
1103 match self {
1104 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1105 Expression::AbstractLiteral(_, abslit) => {
1106 Some(Literal::AbstractLiteral(abslit.into_literals()?))
1107 }
1108 Expression::Neg(_, e) => {
1109 let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1110 bug!("negated literal should be an int");
1111 };
1112
1113 Some(Literal::Int(-i))
1114 }
1115
1116 _ => None,
1117 }
1118 }
1119
1120 pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1122 TryFrom::try_from(self).ok()
1123 }
1124
1125 pub fn universe_categories(&self) -> HashSet<Category> {
1127 self.universe()
1128 .into_iter()
1129 .map(|x| x.category_of())
1130 .collect()
1131 }
1132}
1133
1134pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1135 let function_domain = function.domain_of()?;
1136 match function_domain.resolve().as_ref() {
1137 Some(d) => {
1138 match d.as_ref() {
1139 GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1140 _ => None,
1142 }
1143 }
1144 None => {
1145 match function_domain.as_unresolved()? {
1146 UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1147 _ => None,
1149 }
1150 }
1151 }
1152}
1153
1154pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1155 let function_domain = function.domain_of()?;
1156 match function_domain.resolve().as_ref() {
1157 Some(d) => {
1158 match d.as_ref() {
1159 GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1160 _ => None,
1162 }
1163 }
1164 None => {
1165 match function_domain.as_unresolved()? {
1166 UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1167 _ => None,
1169 }
1170 }
1171 }
1172}
1173
1174impl TryFrom<&Expression> for i32 {
1175 type Error = ();
1176
1177 fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1178 let Expression::Atomic(_, atom) = value else {
1179 return Err(());
1180 };
1181
1182 let Atom::Literal(lit) = atom else {
1183 return Err(());
1184 };
1185
1186 let Literal::Int(i) = lit else {
1187 return Err(());
1188 };
1189
1190 Ok(*i)
1191 }
1192}
1193
1194impl TryFrom<Expression> for i32 {
1195 type Error = ();
1196
1197 fn try_from(value: Expression) -> Result<Self, Self::Error> {
1198 TryFrom::<&Expression>::try_from(&value)
1199 }
1200}
1201impl From<i32> for Expression {
1202 fn from(i: i32) -> Self {
1203 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1204 }
1205}
1206
1207impl From<bool> for Expression {
1208 fn from(b: bool) -> Self {
1209 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1210 }
1211}
1212
1213impl From<Atom> for Expression {
1214 fn from(value: Atom) -> Self {
1215 Expression::Atomic(Metadata::new(), value)
1216 }
1217}
1218
1219impl From<Literal> for Expression {
1220 fn from(value: Literal) -> Self {
1221 Expression::Atomic(Metadata::new(), value.into())
1222 }
1223}
1224
1225impl From<Moo<Expression>> for Expression {
1226 fn from(val: Moo<Expression>) -> Self {
1227 val.as_ref().clone()
1228 }
1229}
1230
1231impl CategoryOf for Expression {
1232 fn category_of(&self) -> Category {
1233 let category = self.cata(&move |x,children| {
1235
1236 if let Some(max_category) = children.iter().max() {
1237 *max_category
1240 } else {
1241 let mut max_category = Category::Bottom;
1243
1244 if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
1251 return Category::Decision;
1253 }
1254
1255 if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1257 && max_atom_category > max_category{
1259 max_category = max_atom_category;
1261 }
1262
1263 if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1265 && max_declaration_category > max_category{
1267 max_category = max_declaration_category;
1269 }
1270 max_category
1271
1272 }
1273 });
1274
1275 if cfg!(debug_assertions) {
1276 trace!(
1277 category= %category,
1278 expression= %self,
1279 "Called Expression::category_of()"
1280 );
1281 };
1282 category
1283 }
1284}
1285
1286impl Display for Expression {
1287 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1288 match &self {
1289 Expression::Union(_, box1, box2) => {
1290 write!(f, "({} union {})", box1.clone(), box2.clone())
1291 }
1292 Expression::In(_, e1, e2) => {
1293 write!(f, "{e1} in {e2}")
1294 }
1295 Expression::Intersect(_, box1, box2) => {
1296 write!(f, "({} intersect {})", box1.clone(), box2.clone())
1297 }
1298 Expression::Supset(_, box1, box2) => {
1299 write!(f, "({} supset {})", box1.clone(), box2.clone())
1300 }
1301 Expression::SupsetEq(_, box1, box2) => {
1302 write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1303 }
1304 Expression::Subset(_, box1, box2) => {
1305 write!(f, "({} subset {})", box1.clone(), box2.clone())
1306 }
1307 Expression::SubsetEq(_, box1, box2) => {
1308 write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1309 }
1310
1311 Expression::AbstractLiteral(_, l) => l.fmt(f),
1312 Expression::Comprehension(_, c) => c.fmt(f),
1313 Expression::AbstractComprehension(_, c) => c.fmt(f),
1314 Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1315 write!(f, "{e1}{}", pretty_vec(e2))
1316 }
1317 Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1318 let args = es
1319 .iter()
1320 .map(|x| match x {
1321 Some(x) => format!("{x}"),
1322 None => "..".into(),
1323 })
1324 .join(",");
1325
1326 write!(f, "{e1}[{args}]")
1327 }
1328 Expression::InDomain(_, e, domain) => {
1329 write!(f, "__inDomain({e},{domain})")
1330 }
1331 Expression::Root(_, exprs) => {
1332 write!(f, "{}", pretty_expressions_as_top_level(exprs))
1333 }
1334 Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1335 Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1336 Expression::Metavar(_, name) => write!(f, "&{name}"),
1337 Expression::Atomic(_, atom) => atom.fmt(f),
1338 Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1339 Expression::Abs(_, a) => write!(f, "|{a}|"),
1340 Expression::Sum(_, e) => {
1341 write!(f, "sum({e})")
1342 }
1343 Expression::Product(_, e) => {
1344 write!(f, "product({e})")
1345 }
1346 Expression::Min(_, e) => {
1347 write!(f, "min({e})")
1348 }
1349 Expression::Max(_, e) => {
1350 write!(f, "max({e})")
1351 }
1352 Expression::Not(_, expr_box) => {
1353 write!(f, "!({})", expr_box.clone())
1354 }
1355 Expression::Or(_, e) => {
1356 write!(f, "or({e})")
1357 }
1358 Expression::And(_, e) => {
1359 write!(f, "and({e})")
1360 }
1361 Expression::Imply(_, box1, box2) => {
1362 write!(f, "({box1}) -> ({box2})")
1363 }
1364 Expression::Iff(_, box1, box2) => {
1365 write!(f, "({box1}) <-> ({box2})")
1366 }
1367 Expression::Eq(_, box1, box2) => {
1368 write!(f, "({} = {})", box1.clone(), box2.clone())
1369 }
1370 Expression::Neq(_, box1, box2) => {
1371 write!(f, "({} != {})", box1.clone(), box2.clone())
1372 }
1373 Expression::Geq(_, box1, box2) => {
1374 write!(f, "({} >= {})", box1.clone(), box2.clone())
1375 }
1376 Expression::Leq(_, box1, box2) => {
1377 write!(f, "({} <= {})", box1.clone(), box2.clone())
1378 }
1379 Expression::Gt(_, box1, box2) => {
1380 write!(f, "({} > {})", box1.clone(), box2.clone())
1381 }
1382 Expression::Lt(_, box1, box2) => {
1383 write!(f, "({} < {})", box1.clone(), box2.clone())
1384 }
1385 Expression::FlatSumGeq(_, box1, box2) => {
1386 write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1387 }
1388 Expression::FlatSumLeq(_, box1, box2) => {
1389 write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1390 }
1391 Expression::FlatIneq(_, box1, box2, box3) => write!(
1392 f,
1393 "Ineq({}, {}, {})",
1394 box1.clone(),
1395 box2.clone(),
1396 box3.clone()
1397 ),
1398 Expression::Flatten(_, n, m) => {
1399 if let Some(n) = n {
1400 write!(f, "flatten({n}, {m})")
1401 } else {
1402 write!(f, "flatten({m})")
1403 }
1404 }
1405 Expression::AllDiff(_, e) => {
1406 write!(f, "allDiff({e})")
1407 }
1408 Expression::Bubble(_, box1, box2) => {
1409 write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1410 }
1411 Expression::SafeDiv(_, box1, box2) => {
1412 write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1413 }
1414 Expression::UnsafeDiv(_, box1, box2) => {
1415 write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1416 }
1417 Expression::UnsafePow(_, box1, box2) => {
1418 write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1419 }
1420 Expression::SafePow(_, box1, box2) => {
1421 write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1422 }
1423 Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1424 write!(
1425 f,
1426 "DivEq({}, {}, {})",
1427 box1.clone(),
1428 box2.clone(),
1429 box3.clone()
1430 )
1431 }
1432 Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1433 write!(
1434 f,
1435 "ModEq({}, {}, {})",
1436 box1.clone(),
1437 box2.clone(),
1438 box3.clone()
1439 )
1440 }
1441 Expression::FlatWatchedLiteral(_, x, l) => {
1442 write!(f, "WatchedLiteral({x},{l})")
1443 }
1444 Expression::MinionReify(_, box1, box2) => {
1445 write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1446 }
1447 Expression::MinionReifyImply(_, box1, box2) => {
1448 write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1449 }
1450 Expression::MinionWInIntervalSet(_, atom, intervals) => {
1451 let intervals = intervals.iter().join(",");
1452 write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1453 }
1454 Expression::MinionWInSet(_, atom, values) => {
1455 let values = values.iter().join(",");
1456 write!(f, "__minion_w_inset({atom},{values})")
1457 }
1458 Expression::AuxDeclaration(_, reference, e) => {
1459 write!(f, "{} =aux {}", reference, e.clone())
1460 }
1461 Expression::UnsafeMod(_, a, b) => {
1462 write!(f, "{} % {}", a.clone(), b.clone())
1463 }
1464 Expression::SafeMod(_, a, b) => {
1465 write!(f, "SafeMod({},{})", a.clone(), b.clone())
1466 }
1467 Expression::Neg(_, a) => {
1468 write!(f, "-({})", a.clone())
1469 }
1470 Expression::Minus(_, a, b) => {
1471 write!(f, "({} - {})", a.clone(), b.clone())
1472 }
1473 Expression::FlatAllDiff(_, es) => {
1474 write!(f, "__flat_alldiff({})", pretty_vec(es))
1475 }
1476 Expression::FlatAbsEq(_, a, b) => {
1477 write!(f, "AbsEq({},{})", a.clone(), b.clone())
1478 }
1479 Expression::FlatMinusEq(_, a, b) => {
1480 write!(f, "MinusEq({},{})", a.clone(), b.clone())
1481 }
1482 Expression::FlatProductEq(_, a, b, c) => {
1483 write!(
1484 f,
1485 "FlatProductEq({},{},{})",
1486 a.clone(),
1487 b.clone(),
1488 c.clone()
1489 )
1490 }
1491 Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1492 write!(
1493 f,
1494 "FlatWeightedSumLeq({},{},{})",
1495 pretty_vec(cs),
1496 pretty_vec(vs),
1497 total.clone()
1498 )
1499 }
1500 Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1501 write!(
1502 f,
1503 "FlatWeightedSumGeq({},{},{})",
1504 pretty_vec(cs),
1505 pretty_vec(vs),
1506 total.clone()
1507 )
1508 }
1509 Expression::MinionPow(_, atom, atom1, atom2) => {
1510 write!(f, "MinionPow({atom},{atom1},{atom2})")
1511 }
1512 Expression::MinionElementOne(_, atoms, atom, atom1) => {
1513 let atoms = atoms.iter().join(",");
1514 write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1515 }
1516
1517 Expression::ToInt(_, expr) => {
1518 write!(f, "toInt({expr})")
1519 }
1520
1521 Expression::SATInt(_, encoding, bits, (min, max)) => {
1522 write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
1523 }
1524
1525 Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1526 Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1527
1528 Expression::Defined(_, function) => write!(f, "defined({function})"),
1529 Expression::Range(_, function) => write!(f, "range({function})"),
1530 Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1531 Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1532 Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1533 Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1534 Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1535
1536 Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1537 Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1538 Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1539 Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1540 Expression::FlatLexLt(_, a, b) => {
1541 write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1542 }
1543 Expression::FlatLexLeq(_, a, b) => {
1544 write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1545 }
1546 }
1547 }
1548}
1549
1550impl Typeable for Expression {
1551 fn return_type(&self) -> ReturnType {
1552 match self {
1553 Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1554 Expression::Intersect(_, subject, _) => {
1555 ReturnType::Set(Box::new(subject.return_type()))
1556 }
1557 Expression::In(_, _, _) => ReturnType::Bool,
1558 Expression::Supset(_, _, _) => ReturnType::Bool,
1559 Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1560 Expression::Subset(_, _, _) => ReturnType::Bool,
1561 Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1562 Expression::AbstractLiteral(_, lit) => lit.return_type(),
1563 Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1564 let subject_ty = subject.return_type();
1565 match subject_ty {
1566 ReturnType::Matrix(_) => {
1567 let mut elem_typ = subject_ty;
1570 let mut idx_len = idx.len();
1571 while idx_len > 0
1572 && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1573 {
1574 elem_typ = *new_elem_typ.clone();
1575 idx_len -= 1;
1576 }
1577 elem_typ
1578 }
1579 ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1581 _ => bug!(
1582 "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1583 ),
1584 }
1585 }
1586 Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1587 ReturnType::Matrix(Box::new(subject.return_type()))
1588 }
1589 Expression::InDomain(_, _, _) => ReturnType::Bool,
1590 Expression::Comprehension(_, comp) => comp.return_type(),
1591 Expression::AbstractComprehension(_, comp) => comp.return_type(),
1592 Expression::Root(_, _) => ReturnType::Bool,
1593 Expression::DominanceRelation(_, _) => ReturnType::Bool,
1594 Expression::FromSolution(_, expr) => expr.return_type(),
1595 Expression::Metavar(_, _) => ReturnType::Unknown,
1596 Expression::Atomic(_, atom) => atom.return_type(),
1597 Expression::Scope(_, scope) => scope.return_type(),
1598 Expression::Abs(_, _) => ReturnType::Int,
1599 Expression::Sum(_, _) => ReturnType::Int,
1600 Expression::Product(_, _) => ReturnType::Int,
1601 Expression::Min(_, _) => ReturnType::Int,
1602 Expression::Max(_, _) => ReturnType::Int,
1603 Expression::Not(_, _) => ReturnType::Bool,
1604 Expression::Or(_, _) => ReturnType::Bool,
1605 Expression::Imply(_, _, _) => ReturnType::Bool,
1606 Expression::Iff(_, _, _) => ReturnType::Bool,
1607 Expression::And(_, _) => ReturnType::Bool,
1608 Expression::Eq(_, _, _) => ReturnType::Bool,
1609 Expression::Neq(_, _, _) => ReturnType::Bool,
1610 Expression::Geq(_, _, _) => ReturnType::Bool,
1611 Expression::Leq(_, _, _) => ReturnType::Bool,
1612 Expression::Gt(_, _, _) => ReturnType::Bool,
1613 Expression::Lt(_, _, _) => ReturnType::Bool,
1614 Expression::SafeDiv(_, _, _) => ReturnType::Int,
1615 Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1616 Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1617 Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1618 Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1619 Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1620 Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1621 Expression::Flatten(_, _, matrix) => {
1622 let matrix_type = matrix.return_type();
1623 match matrix_type {
1624 ReturnType::Matrix(_) => {
1625 let mut elem_type = matrix_type;
1627 while let ReturnType::Matrix(new_elem_type) = &elem_type {
1628 elem_type = *new_elem_type.clone();
1629 }
1630 ReturnType::Matrix(Box::new(elem_type))
1631 }
1632 _ => bug!(
1633 "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1634 ),
1635 }
1636 }
1637 Expression::AllDiff(_, _) => ReturnType::Bool,
1638 Expression::Bubble(_, inner, _) => inner.return_type(),
1639 Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1640 Expression::MinionReify(_, _, _) => ReturnType::Bool,
1641 Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1642 Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1643 Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1644 Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1645 Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1646 Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1647 Expression::SafeMod(_, _, _) => ReturnType::Int,
1648 Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1649 Expression::Neg(_, _) => ReturnType::Int,
1650 Expression::UnsafePow(_, _, _) => ReturnType::Int,
1651 Expression::SafePow(_, _, _) => ReturnType::Int,
1652 Expression::Minus(_, _, _) => ReturnType::Int,
1653 Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1654 Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1655 Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1656 Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1657 Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1658 Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1659 Expression::ToInt(_, _) => ReturnType::Int,
1660 Expression::SATInt(..) => ReturnType::Int,
1661 Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1662 Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1663 Expression::Defined(_, function) => {
1664 let subject = function.return_type();
1665 match subject {
1666 ReturnType::Function(domain, _) => *domain,
1667 _ => bug!(
1668 "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1669 ),
1670 }
1671 }
1672 Expression::Range(_, function) => {
1673 let subject = function.return_type();
1674 match subject {
1675 ReturnType::Function(_, codomain) => *codomain,
1676 _ => bug!(
1677 "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1678 ),
1679 }
1680 }
1681 Expression::Image(_, function, _) => {
1682 let subject = function.return_type();
1683 match subject {
1684 ReturnType::Function(_, codomain) => *codomain,
1685 _ => bug!(
1686 "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1687 ),
1688 }
1689 }
1690 Expression::ImageSet(_, function, _) => {
1691 let subject = function.return_type();
1692 match subject {
1693 ReturnType::Function(_, codomain) => *codomain,
1694 _ => bug!(
1695 "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1696 ),
1697 }
1698 }
1699 Expression::PreImage(_, function, _) => {
1700 let subject = function.return_type();
1701 match subject {
1702 ReturnType::Function(domain, _) => *domain,
1703 _ => bug!(
1704 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1705 ),
1706 }
1707 }
1708 Expression::Restrict(_, function, new_domain) => {
1709 let subject = function.return_type();
1710 match subject {
1711 ReturnType::Function(_, codomain) => {
1712 ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1713 }
1714 _ => bug!(
1715 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1716 ),
1717 }
1718 }
1719 Expression::Inverse(..) => ReturnType::Bool,
1720 Expression::LexLt(..) => ReturnType::Bool,
1721 Expression::LexGt(..) => ReturnType::Bool,
1722 Expression::LexLeq(..) => ReturnType::Bool,
1723 Expression::LexGeq(..) => ReturnType::Bool,
1724 Expression::FlatLexLt(..) => ReturnType::Bool,
1725 Expression::FlatLexLeq(..) => ReturnType::Bool,
1726 }
1727 }
1728}
1729
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    /// Wraps a literal in an atomic expression carrying fresh metadata.
    fn literal_expr(lit: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(lit))
    }

    /// A sum of two integer constants has the single-value domain of their total.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal_expr(Literal::Int(1));
        let two = literal_expr(Literal::Int(2));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, two]));

        let expected = Some(Domain::int(vec![Range::Single(3)]));
        assert_eq!(total.domain_of(), expected);
    }

    /// A sum mixing an integer with a boolean constant has no domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_term = literal_expr(Literal::Int(1));
        let bool_term = literal_expr(Literal::Bool(true));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![int_term, bool_term]));

        assert_eq!(total.domain_of(), None);
    }

    /// A sum over an empty matrix has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));

        assert_eq!(total.domain_of(), None);
    }
}