1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use tracing::trace;
4
5use conjure_cp_enum_compatibility_macro::document_compatibility;
6use itertools::Itertools;
7use serde::{Deserialize, Serialize};
8use ustr::Ustr;
9
10use polyquine::Quine;
11use uniplate::{Biplate, Uniplate};
12
13use crate::bug;
14
15use super::abstract_comprehension::AbstractComprehension;
16use super::ac_operators::ACOperatorKind;
17use super::categories::{Category, CategoryOf};
18use super::comprehension::Comprehension;
19use super::domains::HasDomain as _;
20use super::pretty::{pretty_expressions_as_top_level, pretty_vec};
21use super::records::RecordValue;
22use super::sat_encoding::SATIntEncoding;
23use super::{
24 AbstractLiteral, Atom, DeclarationPtr, Domain, DomainPtr, GroundDomain, IntVal, Literal,
25 Metadata, Moo, Name, Range, Reference, ReturnType, SetAttr, SubModel, SymbolTable,
26 SymbolTablePtr, Typeable, UnresolvedDomain, matrix,
27};
28
// Compile-time size guard: the build fails unless `Expression` is exactly 104 bytes, so any
// change to the enum's size has to be made (and this number updated) deliberately.
static_assertions::assert_eq_size!([u8; 104], Expression);

/// An expression in the abstract syntax tree of a model.
///
/// Every variant stores its [`Metadata`] as its first field (see [`Expression::get_meta`]).
/// `#[compatible(...)]` lists the targets a variant is marked compatible with. The `Unsafe*`
/// variants and `Bubble` are the ones rejected by [`Expression::is_safe`].
#[document_compatibility]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
#[biplate(to=AbstractComprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=Atom)]
#[biplate(to=Comprehension)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=DomainPtr)]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Name)]
#[biplate(to=Option<Expression>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Reference)]
#[biplate(to=SubModel)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[biplate(to=Vec<Expression>)]
#[path_prefix(conjure_cp::ast)]
pub enum Expression {
    /// An abstract literal (for example a matrix literal) whose elements are expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// The top-level expressions of a model, printed with `pretty_expressions_as_top_level`
    /// and extended by [`Expression::extend_root`].
    Root(Metadata, Vec<Expression>),

    /// `{first @ second}`: its domain is the domain of the first expression, and it is
    /// rejected by `is_safe`. NOTE(review): the second expression is presumably the condition
    /// under which the first is defined — confirm.
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    /// A comprehension; its domain is delegated to [`Comprehension`].
    #[polyquine_skip]
    Comprehension(Metadata, Moo<Comprehension>),

    /// An abstract comprehension; its domain is delegated to [`AbstractComprehension`].
    #[polyquine_skip]
    AbstractComprehension(Metadata, Moo<AbstractComprehension>),

    /// Displayed as `DominanceRelation(expr)`; boolean-valued.
    DominanceRelation(Metadata, Moo<Expression>),
    /// Displayed as `FromSolution(atom)`; has the domain of the wrapped atom.
    /// NOTE(review): presumably refers to the atom's value in an earlier solution — confirm.
    FromSolution(Metadata, Moo<Atom>),

    /// A named metavariable, displayed as `&name`. When quoted by polyquine it becomes the
    /// Rust identifier `name` followed by `.clone().into()`.
    #[polyquine_with(arm = (_, name) => {
        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
        quote::quote! { #ident.clone().into() }
    })]
    Metavar(Metadata, Ustr),

    /// A single [`Atom`].
    Atomic(Metadata, Atom),

    /// Indexing `subject[indices]`; rejected by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// Indexing `subject[indices]` that `is_safe` accepts.
    #[compatible(SMT)]
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// Slicing `subject[i, .., j]`: `None` entries (displayed `..`) are sliced dimensions.
    /// Rejected by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Slicing as in `UnsafeSlice`, accepted by `is_safe`.
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Displayed as `__inDomain(e,domain)`; boolean-valued (presumably: `e` takes a value in
    /// `domain`).
    InDomain(Metadata, Moo<Expression>, DomainPtr),

    /// Integer-valued with domain `0..1`. NOTE(review): presumably converts a boolean to 0/1.
    #[compatible(SMT)]
    ToInt(Metadata, Moo<Expression>),

    /// A nested sub-model, displayed inside braces; boolean-valued.
    #[polyquine_skip]
    Scope(Metadata, Moo<SubModel>),

    /// Absolute value, displayed as `|a|`.
    #[compatible(JsonInput, SMT)]
    Abs(Metadata, Moo<Expression>),

    /// `sum(e)`.
    #[compatible(JsonInput, SMT)]
    Sum(Metadata, Moo<Expression>),

    /// `product(e)`.
    #[compatible(JsonInput, SMT)]
    Product(Metadata, Moo<Expression>),

    /// `min(e)`.
    #[compatible(JsonInput, SMT)]
    Min(Metadata, Moo<Expression>),

    /// `max(e)`.
    #[compatible(JsonInput, SMT)]
    Max(Metadata, Moo<Expression>),

    /// `!(e)`.
    #[compatible(JsonInput, SAT, SMT)]
    Not(Metadata, Moo<Expression>),

    /// `or(e)`.
    #[compatible(JsonInput, SAT, SMT)]
    Or(Metadata, Moo<Expression>),

    /// `and(e)`.
    #[compatible(JsonInput, SAT, SMT)]
    And(Metadata, Moo<Expression>),

    /// `(a) -> (b)`.
    #[compatible(JsonInput, SMT)]
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a) <-> (b)`.
    #[compatible(JsonInput, SMT)]
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a union b)`.
    #[compatible(JsonInput)]
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a in b`; boolean-valued.
    #[compatible(JsonInput)]
    In(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a intersect b)`.
    #[compatible(JsonInput)]
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a supset b)`; boolean-valued.
    #[compatible(JsonInput)]
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a supsetEq b)`; boolean-valued.
    #[compatible(JsonInput)]
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a subset b)`; boolean-valued.
    #[compatible(JsonInput)]
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a subsetEq b)`; boolean-valued.
    #[compatible(JsonInput)]
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a = b)`.
    #[compatible(JsonInput, SMT)]
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a != b)`.
    #[compatible(JsonInput, SMT)]
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a >= b)`.
    #[compatible(JsonInput, SMT)]
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a <= b)`.
    #[compatible(JsonInput, SMT)]
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a > b)`.
    #[compatible(JsonInput, SMT)]
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    /// `(a < b)`.
    #[compatible(JsonInput, SMT)]
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    /// `SafeDiv(a, b)`: floor division whose computed domain also always includes 0.
    #[compatible(SMT)]
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// `UnsafeDiv(a, b)`: floor division; rejected by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// `SafeMod(a,b)`: modulo whose computed domain also always includes 0.
    #[compatible(SMT)]
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// `a % b`; rejected by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// Negation, displayed as `-(a)`.
    #[compatible(JsonInput, SMT)]
    Neg(Metadata, Moo<Expression>),

    /// Its domain is the domain of the wrapped function (see `get_function_domain`).
    #[compatible(JsonInput)]
    Defined(Metadata, Moo<Expression>),

    /// Its domain is the codomain of the wrapped function (see `get_function_codomain`).
    #[compatible(JsonInput)]
    Range(Metadata, Moo<Expression>),

    /// `UnsafePow(a, b)`; rejected by `is_safe`.
    #[compatible(JsonInput)]
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// `SafePow(a, b)`.
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// `flatten(m)` or `flatten(n, m)`. Without `n`, its domain is a one-dimensional matrix
    /// indexed `1..k`, where `k` is the number of elements of `m`.
    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),

    /// `allDiff(e)`; boolean-valued.
    #[compatible(JsonInput)]
    AllDiff(Metadata, Moo<Expression>),

    /// `(a - b)`.
    #[compatible(JsonInput)]
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    /// Flattened constraint displayed as `AbsEq(a,b)`.
    #[compatible(Minion)]
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Flattened constraint displayed as `__flat_alldiff([...])`.
    #[compatible(Minion)]
    FlatAllDiff(Metadata, Vec<Atom>),

    /// Flattened constraint displayed as `SumGeq([...], total)`.
    #[compatible(Minion)]
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    /// Flattened constraint displayed as `SumLeq([...], total)`.
    #[compatible(Minion)]
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    /// Flattened constraint displayed as `Ineq(a, b, c)`, where `c` is a literal.
    #[compatible(Minion)]
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    /// Flattened constraint displayed as `WatchedLiteral(x,l)`.
    #[compatible(Minion)]
    #[polyquine_skip]
    FlatWatchedLiteral(Metadata, Reference, Literal),

    /// `FlatWeightedSumLeq(coefficients, atoms, total)`.
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// `FlatWeightedSumGeq(coefficients, atoms, total)`.
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Flattened constraint displayed as `MinusEq(a,b)`.
    #[compatible(Minion)]
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Flattened constraint displayed as `FlatProductEq(a,b,c)`.
    #[compatible(Minion)]
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Minion constraint displayed as `DivEq(a, b, c)`.
    #[compatible(Minion)]
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Minion constraint displayed as `ModEq(a, b, c)`.
    #[compatible(Minion)]
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Minion power constraint over three atoms; boolean-valued.
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Minion constraint displayed as `Reify(e, r)`.
    #[compatible(Minion)]
    MinionReify(Metadata, Moo<Expression>, Atom),

    /// Minion constraint displayed as `ReifyImply(e, r)`.
    #[compatible(Minion)]
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    /// Minion constraint displayed as `__minion_w_inintervalset(x,[...])`.
    #[compatible(Minion)]
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    /// Minion constraint displayed as `__minion_w_inset(x,[...])`.
    #[compatible(Minion)]
    MinionWInSet(Metadata, Atom, Vec<i32>),

    /// Minion element constraint over a list of atoms, an index atom and a value atom.
    #[compatible(Minion)]
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    /// Displayed as `reference =aux e`; boolean-valued.
    #[compatible(Minion)]
    #[polyquine_skip]
    AuxDeclaration(Metadata, Reference, Moo<Expression>),

    /// An integer encoded for SAT; its domain is `low..high` from the trailing pair.
    #[compatible(SAT)]
    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),

    /// Sum of exactly two operands.
    #[compatible(SMT)]
    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),

    /// Product of exactly two operands.
    #[compatible(SMT)]
    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),

    /// Its domain is the codomain of the function in the first operand.
    #[compatible(JsonInput)]
    Image(Metadata, Moo<Expression>, Moo<Expression>),

    /// Its domain is the codomain of the function in the first operand.
    #[compatible(JsonInput)]
    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),

    /// Its domain is the domain of the function in the first operand.
    #[compatible(JsonInput)]
    PreImage(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued relation between two functions.
    #[compatible(JsonInput)]
    Inverse(Metadata, Moo<Expression>, Moo<Expression>),

    /// A function whose domain is replaced by the domain of the second operand.
    #[compatible(JsonInput)]
    Restrict(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued lexicographic comparison.
    LexLt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued lexicographic comparison.
    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued lexicographic comparison.
    LexGt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued lexicographic comparison.
    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Flattened lexicographic comparison over atom lists; boolean-valued.
    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),

    /// Flattened lexicographic comparison over atom lists; boolean-valued.
    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
}
561
562fn bounded_i32_domain_for_matrix_literal_monotonic(
569 e: &Expression,
570 op: fn(i32, i32) -> Option<i32>,
571) -> Option<DomainPtr> {
572 let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
574
575 let expr = exprs.pop()?;
591 let dom = expr.domain_of()?;
592 let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
593 return None;
594 };
595
596 let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
597
598 for expr in exprs {
599 let dom = expr.domain_of()?;
600 let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
601 return None;
602 };
603
604 let (min, max) = range_vec_bounds_i32(ranges)?;
605
606 let minmax = op(min, current_max)?;
608 let minmin = op(min, current_min)?;
609 let maxmin = op(max, current_min)?;
610 let maxmax = op(max, current_max)?;
611 let vals = [minmax, minmin, maxmin, maxmax];
612
613 current_min = *vals
614 .iter()
615 .min()
616 .expect("vals iterator should not be empty, and should have a minimum.");
617 current_max = *vals
618 .iter()
619 .max()
620 .expect("vals iterator should not be empty, and should have a maximum.");
621 }
622
623 if current_min == current_max {
624 Some(Domain::int(vec![Range::Single(current_min)]))
625 } else {
626 Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
627 }
628}
629
630fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
632 let mut min = i32::MAX;
633 let mut max = i32::MIN;
634 for r in ranges {
635 match r {
636 Range::Single(i) => {
637 if *i < min {
638 min = *i;
639 }
640 if *i > max {
641 max = *i;
642 }
643 }
644 Range::Bounded(i, j) => {
645 if *i < min {
646 min = *i;
647 }
648 if *j > max {
649 max = *j;
650 }
651 }
652 Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
653 }
654 }
655 Some((min, max))
656}
657
658impl Expression {
659 pub fn domain_of(&self) -> Option<DomainPtr> {
661 match self {
662 Expression::Union(_, a, b) => Some(Domain::set(
663 SetAttr::<IntVal>::default(),
664 a.domain_of()?.union(&b.domain_of()?).ok()?,
665 )),
666 Expression::Intersect(_, a, b) => Some(Domain::set(
667 SetAttr::<IntVal>::default(),
668 a.domain_of()?.intersect(&b.domain_of()?).ok()?,
669 )),
670 Expression::In(_, _, _) => Some(Domain::bool()),
671 Expression::Supset(_, _, _) => Some(Domain::bool()),
672 Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
673 Expression::Subset(_, _, _) => Some(Domain::bool()),
674 Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
675 Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
676 Expression::DominanceRelation(_, _) => Some(Domain::bool()),
677 Expression::FromSolution(_, expr) => Some(expr.domain_of()),
678 Expression::Metavar(_, _) => None,
679 Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
680 Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
681 Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
682 let dom = matrix.domain_of()?;
683 if let Some((elem_domain, _)) = dom.as_matrix() {
684 return Some(elem_domain);
685 }
686
687 #[allow(clippy::redundant_pattern_matching)]
689 if let Some(_) = dom.as_tuple() {
690 return None;
692 }
693
694 #[allow(clippy::redundant_pattern_matching)]
696 if let Some(_) = dom.as_record() {
697 return None;
699 }
700
701 bug!("subject of an index operation should support indexing")
702 }
703 Expression::UnsafeSlice(_, matrix, indices)
704 | Expression::SafeSlice(_, matrix, indices) => {
705 let sliced_dimension = indices.iter().position(Option::is_none);
706
707 let dom = matrix.domain_of()?;
708 let Some((elem_domain, index_domains)) = dom.as_matrix() else {
709 bug!("subject of an index operation should be a matrix");
710 };
711
712 match sliced_dimension {
713 Some(dimension) => Some(Domain::matrix(
714 elem_domain,
715 vec![index_domains[dimension].clone()],
716 )),
717
718 None => Some(elem_domain),
720 }
721 }
722 Expression::InDomain(_, _, _) => Some(Domain::bool()),
723 Expression::Atomic(_, atom) => Some(atom.domain_of()),
724 Expression::Scope(_, _) => Some(Domain::bool()),
725 Expression::Sum(_, e) => {
726 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
727 }
728 Expression::Product(_, e) => {
729 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
730 }
731 Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
732 Some(if x < y { x } else { y })
733 }),
734 Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
735 Some(if x > y { x } else { y })
736 }),
737 Expression::UnsafeDiv(_, a, b) => a
738 .domain_of()?
739 .resolve()?
740 .apply_i32(
741 |x, y| {
744 if y != 0 {
745 Some((x as f32 / y as f32).floor() as i32)
746 } else {
747 None
748 }
749 },
750 b.domain_of()?.resolve()?.as_ref(),
751 )
752 .map(DomainPtr::from)
753 .ok(),
754 Expression::SafeDiv(_, a, b) => {
755 let domain = a
758 .domain_of()?
759 .resolve()?
760 .apply_i32(
761 |x, y| {
762 if y != 0 {
763 Some((x as f32 / y as f32).floor() as i32)
764 } else {
765 None
766 }
767 },
768 b.domain_of()?.resolve()?.as_ref(),
769 )
770 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
771
772 if let GroundDomain::Int(ranges) = domain {
773 let mut ranges = ranges;
774 ranges.push(Range::Single(0));
775 Some(Domain::int(ranges))
776 } else {
777 bug!("Domain of {self} was not integer")
778 }
779 }
780 Expression::UnsafeMod(_, a, b) => a
781 .domain_of()?
782 .resolve()?
783 .apply_i32(
784 |x, y| if y != 0 { Some(x % y) } else { None },
785 b.domain_of()?.resolve()?.as_ref(),
786 )
787 .map(DomainPtr::from)
788 .ok(),
789 Expression::SafeMod(_, a, b) => {
790 let domain = a
791 .domain_of()?
792 .resolve()?
793 .apply_i32(
794 |x, y| if y != 0 { Some(x % y) } else { None },
795 b.domain_of()?.resolve()?.as_ref(),
796 )
797 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
798
799 if let GroundDomain::Int(ranges) = domain {
800 let mut ranges = ranges;
801 ranges.push(Range::Single(0));
802 Some(Domain::int(ranges))
803 } else {
804 bug!("Domain of {self} was not integer")
805 }
806 }
807 Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
808 .domain_of()?
809 .resolve()?
810 .apply_i32(
811 |x, y| {
812 if (x != 0 || y != 0) && y >= 0 {
813 Some(x.pow(y as u32))
814 } else {
815 None
816 }
817 },
818 b.domain_of()?.resolve()?.as_ref(),
819 )
820 .map(DomainPtr::from)
821 .ok(),
822 Expression::Root(_, _) => None,
823 Expression::Bubble(_, inner, _) => inner.domain_of(),
824 Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
825 Expression::And(_, _) => Some(Domain::bool()),
826 Expression::Not(_, _) => Some(Domain::bool()),
827 Expression::Or(_, _) => Some(Domain::bool()),
828 Expression::Imply(_, _, _) => Some(Domain::bool()),
829 Expression::Iff(_, _, _) => Some(Domain::bool()),
830 Expression::Eq(_, _, _) => Some(Domain::bool()),
831 Expression::Neq(_, _, _) => Some(Domain::bool()),
832 Expression::Geq(_, _, _) => Some(Domain::bool()),
833 Expression::Leq(_, _, _) => Some(Domain::bool()),
834 Expression::Gt(_, _, _) => Some(Domain::bool()),
835 Expression::Lt(_, _, _) => Some(Domain::bool()),
836 Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
837 Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
838 Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
839 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
840 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
841 Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
842 Expression::Flatten(_, n, m) => {
843 if let Some(expr) = n {
844 if expr.return_type() == ReturnType::Int {
845 return None;
847 }
848 } else {
849 let dom = m.domain_of()?.resolve()?;
851 let (val_dom, idx_doms) = match dom.as_ref() {
852 GroundDomain::Matrix(val, idx) => (val, idx),
853 _ => return None,
854 };
855 let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
856
857 let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
858 return Some(Domain::matrix(
859 val_dom.clone().into(),
860 vec![new_index_domain],
861 ));
862 }
863 None
864 }
865 Expression::AllDiff(_, _) => Some(Domain::bool()),
866 Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
867 Expression::MinionReify(_, _, _) => Some(Domain::bool()),
868 Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
869 Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
870 Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
871 Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
872 Expression::Neg(_, x) => {
873 let dom = x.domain_of()?;
874 let mut ranges = dom.as_int()?;
875
876 ranges = ranges
877 .into_iter()
878 .map(|r| match r {
879 Range::Single(x) => Range::Single(-x),
880 Range::Bounded(x, y) => Range::Bounded(-y, -x),
881 Range::UnboundedR(i) => Range::UnboundedL(-i),
882 Range::UnboundedL(i) => Range::UnboundedR(-i),
883 Range::Unbounded => Range::Unbounded,
884 })
885 .collect();
886
887 Some(Domain::int(ranges))
888 }
889 Expression::Minus(_, a, b) => a
890 .domain_of()?
891 .resolve()?
892 .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
893 .map(DomainPtr::from)
894 .ok(),
895 Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
896 Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
897 Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
898 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
899 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
900 Expression::Abs(_, a) => a
901 .domain_of()?
902 .resolve()?
903 .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
904 .map(DomainPtr::from)
905 .ok(),
906 Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
907 Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
908 Expression::SATInt(_, _, _, (low, high)) => {
909 Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
910 }
911 Expression::PairwiseSum(_, a, b) => a
912 .domain_of()?
913 .resolve()?
914 .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
915 .map(DomainPtr::from)
916 .ok(),
917 Expression::PairwiseProduct(_, a, b) => a
918 .domain_of()?
919 .resolve()?
920 .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
921 .map(DomainPtr::from)
922 .ok(),
923 Expression::Defined(_, function) => get_function_domain(function),
924 Expression::Range(_, function) => get_function_codomain(function),
925 Expression::Image(_, function, _) => get_function_codomain(function),
926 Expression::ImageSet(_, function, _) => get_function_codomain(function),
927 Expression::PreImage(_, function, _) => get_function_domain(function),
928 Expression::Restrict(_, function, new_domain) => {
929 let (attrs, _, codom) = function.domain_of()?.as_function()?;
930 let new_dom = new_domain.domain_of()?;
931 Some(Domain::function(attrs, new_dom, codom))
932 }
933 Expression::Inverse(..) => Some(Domain::bool()),
934 Expression::LexLt(..) => Some(Domain::bool()),
935 Expression::LexLeq(..) => Some(Domain::bool()),
936 Expression::LexGt(..) => Some(Domain::bool()),
937 Expression::LexGeq(..) => Some(Domain::bool()),
938 Expression::FlatLexLt(..) => Some(Domain::bool()),
939 Expression::FlatLexLeq(..) => Some(Domain::bool()),
940 }
941 }
942
943 pub fn get_meta(&self) -> Metadata {
944 let metas: VecDeque<Metadata> = self.children_bi();
945 metas[0].clone()
946 }
947
948 pub fn set_meta(&self, meta: Metadata) {
949 self.transform_bi(&|_| meta.clone());
950 }
951
952 pub fn is_safe(&self) -> bool {
959 for expr in self.universe() {
961 match expr {
962 Expression::UnsafeDiv(_, _, _)
963 | Expression::UnsafeMod(_, _, _)
964 | Expression::UnsafePow(_, _, _)
965 | Expression::UnsafeIndex(_, _, _)
966 | Expression::Bubble(_, _, _)
967 | Expression::UnsafeSlice(_, _, _) => {
968 return false;
969 }
970 _ => {}
971 }
972 }
973 true
974 }
975
976 pub fn is_clean(&self) -> bool {
977 let metadata = self.get_meta();
978 metadata.clean
979 }
980
981 pub fn set_clean(&mut self, bool_value: bool) {
982 let mut metadata = self.get_meta();
983 metadata.clean = bool_value;
984 self.set_meta(metadata);
985 }
986
987 pub fn is_associative_commutative_operator(&self) -> bool {
989 TryInto::<ACOperatorKind>::try_into(self).is_ok()
990 }
991
992 pub fn is_matrix_literal(&self) -> bool {
997 matches!(
998 self,
999 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1000 | Expression::Atomic(
1001 _,
1002 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1003 )
1004 )
1005 }
1006
1007 pub fn identical_atom_to(&self, other: &Expression) -> bool {
1013 let atom1: Result<&Atom, _> = self.try_into();
1014 let atom2: Result<&Atom, _> = other.try_into();
1015
1016 if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1017 atom2 == atom1
1018 } else {
1019 false
1020 }
1021 }
1022
1023 pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
1028 match self {
1029 Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1030 matrix.unwrap_list().cloned()
1031 }
1032 Expression::Atomic(
1033 _,
1034 Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1035 ) => matrix.unwrap_list().map(|elems| {
1036 elems
1037 .clone()
1038 .into_iter()
1039 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1040 .collect_vec()
1041 }),
1042 _ => None,
1043 }
1044 }
1045
1046 pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1054 match self {
1055 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1056 Some((elems, domain))
1057 }
1058 Expression::Atomic(
1059 _,
1060 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1061 ) => Some((
1062 elems
1063 .into_iter()
1064 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1065 .collect_vec(),
1066 domain.into(),
1067 )),
1068
1069 _ => None,
1070 }
1071 }
1072
1073 pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1078 match self {
1079 Expression::Root(meta, mut children) => {
1080 children.extend(exprs);
1081 Expression::Root(meta, children)
1082 }
1083 _ => panic!("extend_root called on a non-Root expression"),
1084 }
1085 }
1086
1087 pub fn into_literal(self) -> Option<Literal> {
1089 match self {
1090 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1091 Expression::AbstractLiteral(_, abslit) => {
1092 Some(Literal::AbstractLiteral(abslit.into_literals()?))
1093 }
1094 Expression::Neg(_, e) => {
1095 let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1096 bug!("negated literal should be an int");
1097 };
1098
1099 Some(Literal::Int(-i))
1100 }
1101
1102 _ => None,
1103 }
1104 }
1105
1106 pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1108 TryFrom::try_from(self).ok()
1109 }
1110
1111 pub fn universe_categories(&self) -> HashSet<Category> {
1113 self.universe()
1114 .into_iter()
1115 .map(|x| x.category_of())
1116 .collect()
1117 }
1118}
1119
1120pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1121 let function_domain = function.domain_of()?;
1122 match function_domain.resolve().as_ref() {
1123 Some(d) => {
1124 match d.as_ref() {
1125 GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1126 _ => None,
1128 }
1129 }
1130 None => {
1131 match function_domain.as_unresolved()? {
1132 UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1133 _ => None,
1135 }
1136 }
1137 }
1138}
1139
1140pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1141 let function_domain = function.domain_of()?;
1142 match function_domain.resolve().as_ref() {
1143 Some(d) => {
1144 match d.as_ref() {
1145 GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1146 _ => None,
1148 }
1149 }
1150 None => {
1151 match function_domain.as_unresolved()? {
1152 UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1153 _ => None,
1155 }
1156 }
1157 }
1158}
1159
1160impl TryFrom<&Expression> for i32 {
1161 type Error = ();
1162
1163 fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1164 let Expression::Atomic(_, atom) = value else {
1165 return Err(());
1166 };
1167
1168 let Atom::Literal(lit) = atom else {
1169 return Err(());
1170 };
1171
1172 let Literal::Int(i) = lit else {
1173 return Err(());
1174 };
1175
1176 Ok(*i)
1177 }
1178}
1179
1180impl TryFrom<Expression> for i32 {
1181 type Error = ();
1182
1183 fn try_from(value: Expression) -> Result<Self, Self::Error> {
1184 TryFrom::<&Expression>::try_from(&value)
1185 }
1186}
1187impl From<i32> for Expression {
1188 fn from(i: i32) -> Self {
1189 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1190 }
1191}
1192
1193impl From<bool> for Expression {
1194 fn from(b: bool) -> Self {
1195 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1196 }
1197}
1198
1199impl From<Atom> for Expression {
1200 fn from(value: Atom) -> Self {
1201 Expression::Atomic(Metadata::new(), value)
1202 }
1203}
1204
1205impl From<Literal> for Expression {
1206 fn from(value: Literal) -> Self {
1207 Expression::Atomic(Metadata::new(), value.into())
1208 }
1209}
1210
1211impl From<Moo<Expression>> for Expression {
1212 fn from(val: Moo<Expression>) -> Self {
1213 val.as_ref().clone()
1214 }
1215}
1216
1217impl CategoryOf for Expression {
1218 fn category_of(&self) -> Category {
1219 let category = self.cata(&move |x,children| {
1221
1222 if let Some(max_category) = children.iter().max() {
1223 *max_category
1226 } else {
1227 let mut max_category = Category::Bottom;
1229
1230 if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
1237 return Category::Decision;
1239 }
1240
1241 if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1243 && max_atom_category > max_category{
1245 max_category = max_atom_category;
1247 }
1248
1249 if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1251 && max_declaration_category > max_category{
1253 max_category = max_declaration_category;
1255 }
1256 max_category
1257
1258 }
1259 });
1260
1261 if cfg!(debug_assertions) {
1262 trace!(
1263 category= %category,
1264 expression= %self,
1265 "Called Expression::category_of()"
1266 );
1267 };
1268 category
1269 }
1270}
1271
1272impl Display for Expression {
1273 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1274 match &self {
1275 Expression::Union(_, box1, box2) => {
1276 write!(f, "({} union {})", box1.clone(), box2.clone())
1277 }
1278 Expression::In(_, e1, e2) => {
1279 write!(f, "{e1} in {e2}")
1280 }
1281 Expression::Intersect(_, box1, box2) => {
1282 write!(f, "({} intersect {})", box1.clone(), box2.clone())
1283 }
1284 Expression::Supset(_, box1, box2) => {
1285 write!(f, "({} supset {})", box1.clone(), box2.clone())
1286 }
1287 Expression::SupsetEq(_, box1, box2) => {
1288 write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1289 }
1290 Expression::Subset(_, box1, box2) => {
1291 write!(f, "({} subset {})", box1.clone(), box2.clone())
1292 }
1293 Expression::SubsetEq(_, box1, box2) => {
1294 write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1295 }
1296
1297 Expression::AbstractLiteral(_, l) => l.fmt(f),
1298 Expression::Comprehension(_, c) => c.fmt(f),
1299 Expression::AbstractComprehension(_, c) => c.fmt(f),
1300 Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1301 write!(f, "{e1}{}", pretty_vec(e2))
1302 }
1303 Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1304 let args = es
1305 .iter()
1306 .map(|x| match x {
1307 Some(x) => format!("{x}"),
1308 None => "..".into(),
1309 })
1310 .join(",");
1311
1312 write!(f, "{e1}[{args}]")
1313 }
1314 Expression::InDomain(_, e, domain) => {
1315 write!(f, "__inDomain({e},{domain})")
1316 }
1317 Expression::Root(_, exprs) => {
1318 write!(f, "{}", pretty_expressions_as_top_level(exprs))
1319 }
1320 Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1321 Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1322 Expression::Metavar(_, name) => write!(f, "&{name}"),
1323 Expression::Atomic(_, atom) => atom.fmt(f),
1324 Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1325 Expression::Abs(_, a) => write!(f, "|{a}|"),
1326 Expression::Sum(_, e) => {
1327 write!(f, "sum({e})")
1328 }
1329 Expression::Product(_, e) => {
1330 write!(f, "product({e})")
1331 }
1332 Expression::Min(_, e) => {
1333 write!(f, "min({e})")
1334 }
1335 Expression::Max(_, e) => {
1336 write!(f, "max({e})")
1337 }
1338 Expression::Not(_, expr_box) => {
1339 write!(f, "!({})", expr_box.clone())
1340 }
1341 Expression::Or(_, e) => {
1342 write!(f, "or({e})")
1343 }
1344 Expression::And(_, e) => {
1345 write!(f, "and({e})")
1346 }
1347 Expression::Imply(_, box1, box2) => {
1348 write!(f, "({box1}) -> ({box2})")
1349 }
1350 Expression::Iff(_, box1, box2) => {
1351 write!(f, "({box1}) <-> ({box2})")
1352 }
1353 Expression::Eq(_, box1, box2) => {
1354 write!(f, "({} = {})", box1.clone(), box2.clone())
1355 }
1356 Expression::Neq(_, box1, box2) => {
1357 write!(f, "({} != {})", box1.clone(), box2.clone())
1358 }
1359 Expression::Geq(_, box1, box2) => {
1360 write!(f, "({} >= {})", box1.clone(), box2.clone())
1361 }
1362 Expression::Leq(_, box1, box2) => {
1363 write!(f, "({} <= {})", box1.clone(), box2.clone())
1364 }
1365 Expression::Gt(_, box1, box2) => {
1366 write!(f, "({} > {})", box1.clone(), box2.clone())
1367 }
1368 Expression::Lt(_, box1, box2) => {
1369 write!(f, "({} < {})", box1.clone(), box2.clone())
1370 }
1371 Expression::FlatSumGeq(_, box1, box2) => {
1372 write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1373 }
1374 Expression::FlatSumLeq(_, box1, box2) => {
1375 write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1376 }
1377 Expression::FlatIneq(_, box1, box2, box3) => write!(
1378 f,
1379 "Ineq({}, {}, {})",
1380 box1.clone(),
1381 box2.clone(),
1382 box3.clone()
1383 ),
1384 Expression::Flatten(_, n, m) => {
1385 if let Some(n) = n {
1386 write!(f, "flatten({n}, {m})")
1387 } else {
1388 write!(f, "flatten({m})")
1389 }
1390 }
1391 Expression::AllDiff(_, e) => {
1392 write!(f, "allDiff({e})")
1393 }
1394 Expression::Bubble(_, box1, box2) => {
1395 write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1396 }
1397 Expression::SafeDiv(_, box1, box2) => {
1398 write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1399 }
1400 Expression::UnsafeDiv(_, box1, box2) => {
1401 write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1402 }
1403 Expression::UnsafePow(_, box1, box2) => {
1404 write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1405 }
1406 Expression::SafePow(_, box1, box2) => {
1407 write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1408 }
1409 Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1410 write!(
1411 f,
1412 "DivEq({}, {}, {})",
1413 box1.clone(),
1414 box2.clone(),
1415 box3.clone()
1416 )
1417 }
1418 Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1419 write!(
1420 f,
1421 "ModEq({}, {}, {})",
1422 box1.clone(),
1423 box2.clone(),
1424 box3.clone()
1425 )
1426 }
1427 Expression::FlatWatchedLiteral(_, x, l) => {
1428 write!(f, "WatchedLiteral({x},{l})")
1429 }
1430 Expression::MinionReify(_, box1, box2) => {
1431 write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1432 }
1433 Expression::MinionReifyImply(_, box1, box2) => {
1434 write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1435 }
1436 Expression::MinionWInIntervalSet(_, atom, intervals) => {
1437 let intervals = intervals.iter().join(",");
1438 write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1439 }
1440 Expression::MinionWInSet(_, atom, values) => {
1441 let values = values.iter().join(",");
1442 write!(f, "__minion_w_inset({atom},[{values}])")
1443 }
1444 Expression::AuxDeclaration(_, reference, e) => {
1445 write!(f, "{} =aux {}", reference, e.clone())
1446 }
1447 Expression::UnsafeMod(_, a, b) => {
1448 write!(f, "{} % {}", a.clone(), b.clone())
1449 }
1450 Expression::SafeMod(_, a, b) => {
1451 write!(f, "SafeMod({},{})", a.clone(), b.clone())
1452 }
1453 Expression::Neg(_, a) => {
1454 write!(f, "-({})", a.clone())
1455 }
1456 Expression::Minus(_, a, b) => {
1457 write!(f, "({} - {})", a.clone(), b.clone())
1458 }
1459 Expression::FlatAllDiff(_, es) => {
1460 write!(f, "__flat_alldiff({})", pretty_vec(es))
1461 }
1462 Expression::FlatAbsEq(_, a, b) => {
1463 write!(f, "AbsEq({},{})", a.clone(), b.clone())
1464 }
1465 Expression::FlatMinusEq(_, a, b) => {
1466 write!(f, "MinusEq({},{})", a.clone(), b.clone())
1467 }
1468 Expression::FlatProductEq(_, a, b, c) => {
1469 write!(
1470 f,
1471 "FlatProductEq({},{},{})",
1472 a.clone(),
1473 b.clone(),
1474 c.clone()
1475 )
1476 }
1477 Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1478 write!(
1479 f,
1480 "FlatWeightedSumLeq({},{},{})",
1481 pretty_vec(cs),
1482 pretty_vec(vs),
1483 total.clone()
1484 )
1485 }
1486 Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1487 write!(
1488 f,
1489 "FlatWeightedSumGeq({},{},{})",
1490 pretty_vec(cs),
1491 pretty_vec(vs),
1492 total.clone()
1493 )
1494 }
1495 Expression::MinionPow(_, atom, atom1, atom2) => {
1496 write!(f, "MinionPow({atom},{atom1},{atom2})")
1497 }
1498 Expression::MinionElementOne(_, atoms, atom, atom1) => {
1499 let atoms = atoms.iter().join(",");
1500 write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1501 }
1502
1503 Expression::ToInt(_, expr) => {
1504 write!(f, "toInt({expr})")
1505 }
1506
1507 Expression::SATInt(_, encoding, bits, (min, max)) => {
1508 write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
1509 }
1510
1511 Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1512 Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1513
1514 Expression::Defined(_, function) => write!(f, "defined({function})"),
1515 Expression::Range(_, function) => write!(f, "range({function})"),
1516 Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1517 Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1518 Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1519 Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1520 Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1521
1522 Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1523 Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1524 Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1525 Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1526 Expression::FlatLexLt(_, a, b) => {
1527 write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1528 }
1529 Expression::FlatLexLeq(_, a, b) => {
1530 write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1531 }
1532 }
1533 }
1534}
1535
1536impl Typeable for Expression {
1537 fn return_type(&self) -> ReturnType {
1538 match self {
1539 Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1540 Expression::Intersect(_, subject, _) => {
1541 ReturnType::Set(Box::new(subject.return_type()))
1542 }
1543 Expression::In(_, _, _) => ReturnType::Bool,
1544 Expression::Supset(_, _, _) => ReturnType::Bool,
1545 Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1546 Expression::Subset(_, _, _) => ReturnType::Bool,
1547 Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1548 Expression::AbstractLiteral(_, lit) => lit.return_type(),
1549 Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1550 let subject_ty = subject.return_type();
1551 match subject_ty {
1552 ReturnType::Matrix(_) => {
1553 let mut elem_typ = subject_ty;
1556 let mut idx_len = idx.len();
1557 while idx_len > 0
1558 && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1559 {
1560 elem_typ = *new_elem_typ.clone();
1561 idx_len -= 1;
1562 }
1563 elem_typ
1564 }
1565 ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1567 _ => bug!(
1568 "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1569 ),
1570 }
1571 }
1572 Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1573 ReturnType::Matrix(Box::new(subject.return_type()))
1574 }
1575 Expression::InDomain(_, _, _) => ReturnType::Bool,
1576 Expression::Comprehension(_, comp) => comp.return_type(),
1577 Expression::AbstractComprehension(_, comp) => comp.return_type(),
1578 Expression::Root(_, _) => ReturnType::Bool,
1579 Expression::DominanceRelation(_, _) => ReturnType::Bool,
1580 Expression::FromSolution(_, expr) => expr.return_type(),
1581 Expression::Metavar(_, _) => ReturnType::Unknown,
1582 Expression::Atomic(_, atom) => atom.return_type(),
1583 Expression::Scope(_, scope) => scope.return_type(),
1584 Expression::Abs(_, _) => ReturnType::Int,
1585 Expression::Sum(_, _) => ReturnType::Int,
1586 Expression::Product(_, _) => ReturnType::Int,
1587 Expression::Min(_, _) => ReturnType::Int,
1588 Expression::Max(_, _) => ReturnType::Int,
1589 Expression::Not(_, _) => ReturnType::Bool,
1590 Expression::Or(_, _) => ReturnType::Bool,
1591 Expression::Imply(_, _, _) => ReturnType::Bool,
1592 Expression::Iff(_, _, _) => ReturnType::Bool,
1593 Expression::And(_, _) => ReturnType::Bool,
1594 Expression::Eq(_, _, _) => ReturnType::Bool,
1595 Expression::Neq(_, _, _) => ReturnType::Bool,
1596 Expression::Geq(_, _, _) => ReturnType::Bool,
1597 Expression::Leq(_, _, _) => ReturnType::Bool,
1598 Expression::Gt(_, _, _) => ReturnType::Bool,
1599 Expression::Lt(_, _, _) => ReturnType::Bool,
1600 Expression::SafeDiv(_, _, _) => ReturnType::Int,
1601 Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1602 Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1603 Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1604 Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1605 Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1606 Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1607 Expression::Flatten(_, _, matrix) => {
1608 let matrix_type = matrix.return_type();
1609 match matrix_type {
1610 ReturnType::Matrix(_) => {
1611 let mut elem_type = matrix_type;
1613 while let ReturnType::Matrix(new_elem_type) = &elem_type {
1614 elem_type = *new_elem_type.clone();
1615 }
1616 ReturnType::Matrix(Box::new(elem_type))
1617 }
1618 _ => bug!(
1619 "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1620 ),
1621 }
1622 }
1623 Expression::AllDiff(_, _) => ReturnType::Bool,
1624 Expression::Bubble(_, inner, _) => inner.return_type(),
1625 Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1626 Expression::MinionReify(_, _, _) => ReturnType::Bool,
1627 Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1628 Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1629 Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1630 Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1631 Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1632 Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1633 Expression::SafeMod(_, _, _) => ReturnType::Int,
1634 Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1635 Expression::Neg(_, _) => ReturnType::Int,
1636 Expression::UnsafePow(_, _, _) => ReturnType::Int,
1637 Expression::SafePow(_, _, _) => ReturnType::Int,
1638 Expression::Minus(_, _, _) => ReturnType::Int,
1639 Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1640 Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1641 Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1642 Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1643 Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1644 Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1645 Expression::ToInt(_, _) => ReturnType::Int,
1646 Expression::SATInt(..) => ReturnType::Int,
1647 Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1648 Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1649 Expression::Defined(_, function) => {
1650 let subject = function.return_type();
1651 match subject {
1652 ReturnType::Function(domain, _) => *domain,
1653 _ => bug!(
1654 "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1655 ),
1656 }
1657 }
1658 Expression::Range(_, function) => {
1659 let subject = function.return_type();
1660 match subject {
1661 ReturnType::Function(_, codomain) => *codomain,
1662 _ => bug!(
1663 "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1664 ),
1665 }
1666 }
1667 Expression::Image(_, function, _) => {
1668 let subject = function.return_type();
1669 match subject {
1670 ReturnType::Function(_, codomain) => *codomain,
1671 _ => bug!(
1672 "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1673 ),
1674 }
1675 }
1676 Expression::ImageSet(_, function, _) => {
1677 let subject = function.return_type();
1678 match subject {
1679 ReturnType::Function(_, codomain) => *codomain,
1680 _ => bug!(
1681 "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1682 ),
1683 }
1684 }
1685 Expression::PreImage(_, function, _) => {
1686 let subject = function.return_type();
1687 match subject {
1688 ReturnType::Function(domain, _) => *domain,
1689 _ => bug!(
1690 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1691 ),
1692 }
1693 }
1694 Expression::Restrict(_, function, new_domain) => {
1695 let subject = function.return_type();
1696 match subject {
1697 ReturnType::Function(_, codomain) => {
1698 ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1699 }
1700 _ => bug!(
1701 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1702 ),
1703 }
1704 }
1705 Expression::Inverse(..) => ReturnType::Bool,
1706 Expression::LexLt(..) => ReturnType::Bool,
1707 Expression::LexGt(..) => ReturnType::Bool,
1708 Expression::LexLeq(..) => ReturnType::Bool,
1709 Expression::LexGeq(..) => ReturnType::Bool,
1710 Expression::FlatLexLt(..) => ReturnType::Bool,
1711 Expression::FlatLexLeq(..) => ReturnType::Bool,
1712 }
1713 }
1714}
1715
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    // Builds an atomic expression holding `value`, with fresh metadata.
    fn constant_expr(value: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(value))
    }

    #[test]
    fn test_domain_of_constant_sum() {
        // sum([1, 2]) can only ever evaluate to 3.
        let (one, two) = (constant_expr(Literal::Int(1)), constant_expr(Literal::Int(2)));
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, two]));
        let expected = Domain::int(vec![Range::Single(3)]);
        assert_eq!(sum.domain_of(), Some(expected));
    }

    #[test]
    fn test_domain_of_constant_invalid_type() {
        // Mixing an int and a bool inside a sum must not produce a domain.
        let one = constant_expr(Literal::Int(1));
        let truth = constant_expr(Literal::Bool(true));
        let sum = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, truth]));
        assert!(sum.domain_of().is_none());
    }

    #[test]
    fn test_domain_of_empty_sum() {
        // A sum over an empty matrix is expected to have no inferable domain.
        let empty = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));
        assert!(empty.domain_of().is_none());
    }
}