1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use tracing::trace;
4
5use conjure_cp_enum_compatibility_macro::document_compatibility;
6use itertools::Itertools;
7use serde::{Deserialize, Serialize};
8use ustr::Ustr;
9
10use polyquine::Quine;
11use uniplate::{Biplate, Uniplate};
12
13use crate::bug;
14
15use super::abstract_comprehension::AbstractComprehension;
16use super::ac_operators::ACOperatorKind;
17use super::categories::{Category, CategoryOf};
18use super::comprehension::Comprehension;
19use super::domains::HasDomain as _;
20use super::pretty::{pretty_expressions_as_top_level, pretty_vec};
21use super::records::RecordValue;
22use super::sat_encoding::SATIntEncoding;
23use super::{
24 AbstractLiteral, Atom, DeclarationPtr, Domain, DomainPtr, GroundDomain, IntVal, Literal,
25 Metadata, Model, Moo, Name, Range, Reference, ReturnType, SetAttr, SymbolTable, SymbolTablePtr,
26 Typeable, UnresolvedDomain, matrix,
27};
28
// Compile-time guard: the build fails if `Expression` is not exactly 112 bytes, so any change
// to the size of this type has to be made deliberately.
static_assertions::assert_eq_size!([u8; 112], Expression);

#[document_compatibility]
/// An expression in the AST.
///
/// Every variant stores its [`Metadata`] as the first field. Variants prefixed with `Flat` or
/// `Minion` mostly take atoms rather than nested expressions. `Safe*`/`Unsafe*` pairs
/// distinguish operations that [`Expression::is_safe`] treats as unsafe (the `Unsafe*` forms).
/// The `#[compatible(..)]` attributes presumably record which front/back-ends a variant is
/// supported by — NOTE(review): confirm against `document_compatibility`.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
#[biplate(to=AbstractComprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=Atom)]
#[biplate(to=Comprehension)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=DomainPtr)]
#[biplate(to=Literal)]
#[biplate(to=Metadata)]
#[biplate(to=Name)]
#[biplate(to=Option<Expression>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Reference)]
#[biplate(to=Model)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[biplate(to=Vec<Expression>)]
#[path_prefix(conjure_cp::ast)]
pub enum Expression {
    /// An abstract literal (e.g. a matrix literal) whose elements are expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// The root of a model; holds the top-level expressions, printed one per line via
    /// `pretty_expressions_as_top_level`.
    Root(Metadata, Vec<Expression>),

    /// A value (first) paired with a second expression, printed as `{a @ b}`; its domain is the
    /// domain of the first. `is_safe` treats every bubble as unsafe.
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    #[polyquine_skip]
    /// A comprehension expression.
    Comprehension(Metadata, Moo<Comprehension>),

    #[polyquine_skip]
    /// An abstract comprehension expression.
    AbstractComprehension(Metadata, Moo<AbstractComprehension>),

    /// A dominance relation (boolean), printed as `DominanceRelation(e)`.
    DominanceRelation(Metadata, Moo<Expression>),
    /// Wraps an atom, printed as `FromSolution(a)`; its domain is the atom's domain.
    FromSolution(Metadata, Moo<Atom>),

    #[polyquine_with(arm = (_, name) => {
        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
        quote::quote! { #ident.clone().into() }
    })]
    /// A named metavariable, printed as `&name`; it has no domain. When quined it expands to
    /// `name.clone().into()`, i.e. a reference to a Rust variable with the same name.
    Metavar(Metadata, Ustr),

    /// A single atom.
    Atomic(Metadata, Atom),

    #[compatible(JsonInput)]
    /// Indexing `subject[indices]`; treated as unsafe by `is_safe`.
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    #[compatible(SMT)]
    /// Safe counterpart of `UnsafeIndex`.
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    #[compatible(JsonInput)]
    /// Slicing `subject[..]`; `None` entries are the sliced (`..`) positions. Treated as unsafe
    /// by `is_safe`.
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Safe counterpart of `UnsafeSlice`.
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Boolean membership of an expression in a domain, printed as `__inDomain(e,domain)`.
    InDomain(Metadata, Moo<Expression>, DomainPtr),

    #[compatible(SMT)]
    /// Conversion to an integer; its domain is `0..1`.
    ToInt(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Absolute value, printed as `|e|`.
    Abs(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Sum over a (matrix) expression, printed as `sum(e)`.
    Sum(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Product over a (matrix) expression, printed as `product(e)`.
    Product(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Minimum over a (matrix) expression, printed as `min(e)`.
    Min(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Maximum over a (matrix) expression, printed as `max(e)`.
    Max(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// Logical negation, printed as `!(e)`.
    Not(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// Disjunction over a (list) expression, printed as `or(e)`.
    Or(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// Conjunction over a (list) expression, printed as `and(e)`.
    And(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Implication, printed as `(a) -> (b)`.
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Bi-implication, printed as `(a) <-> (b)`.
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Set union, printed as `(a union b)`.
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Membership test, printed as `a in b`.
    In(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Set intersection, printed as `(a intersect b)`.
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Strict superset test, printed as `(a supset b)`.
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Superset-or-equal test, printed as `(a supsetEq b)`.
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Strict subset test, printed as `(a subset b)`.
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Subset-or-equal test, printed as `(a subsetEq b)`.
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Equality, printed as `(a = b)`.
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Disequality, printed as `(a != b)`.
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a >= b)`.
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a <= b)`.
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a > b)`.
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a < b)`.
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// Safe integer division; `domain_of` always adds 0 to its domain.
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Integer division that `is_safe` treats as unsafe (e.g. division by zero).
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// Safe modulo; `domain_of` always adds 0 to its domain.
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Modulo that `is_safe` treats as unsafe, printed as `a % b`.
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// Integer negation, printed as `-(e)`.
    Neg(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `defined(f)`; `domain_of` returns the domain component of `f`'s function domain.
    Defined(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `range(f)`; `domain_of` returns the codomain component of `f`'s function domain.
    Range(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Exponentiation that `is_safe` treats as unsafe.
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// Safe counterpart of `UnsafePow`.
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// `flatten(n, m)` or `flatten(m)`; the optional first field is the depth argument.
    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// All-different constraint over a (list) expression, printed as `allDiff(e)`.
    AllDiff(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Table constraint `table(tuple, rows)`.
    Table(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Negative table constraint `negativeTable(tuple, rows)`.
    NegativeTable(Metadata, Moo<Expression>, Moo<Expression>),
    #[compatible(JsonInput)]
    /// Binary subtraction, printed as `(a - b)`.
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(Minion)]
    /// Flat Minion constraint over two atoms, printed as `AbsEq(a,b)`.
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat all-different over atoms, printed as `__flat_alldiff([..])`.
    FlatAllDiff(Metadata, Vec<Atom>),

    #[compatible(Minion)]
    /// Flat sum constraint, printed as `SumGeq([..], total)`.
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Flat sum constraint, printed as `SumLeq([..], total)`.
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// Flat inequality over two atoms and a literal, printed as `Ineq(a, b, c)`.
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    #[compatible(Minion)]
    #[polyquine_skip]
    /// Watched-literal constraint on a reference, printed as `WatchedLiteral(x,l)`.
    FlatWatchedLiteral(Metadata, Reference, Literal),

    /// Weighted sum with literal coefficients, atoms and a total atom (presumably
    /// `sum(c_i * v_i) <= total`).
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Weighted sum with literal coefficients, atoms and a total atom (presumably
    /// `sum(c_i * v_i) >= total`).
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat Minion constraint over two atoms, printed as `MinusEq(a,b)`.
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Flat product constraint over three atoms, printed as `FlatProductEq(a,b,c)`.
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Minion division constraint, printed as `DivEq(a, b, c)`.
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Minion modulo constraint, printed as `ModEq(a, b, c)`.
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Minion power constraint over three atoms (boolean).
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Reification of a constraint into an atom, printed as `Reify(c, r)`.
    MinionReify(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// Implication-only reification, printed as `ReifyImply(c, r)`.
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// Atom-in-interval-set constraint, printed as `__minion_w_inintervalset(x,[..])`.
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Atom-in-set constraint, printed as `__minion_w_inset(x,[..])`.
    MinionWInSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Minion element constraint over a vector of atoms, an index atom and a value atom.
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    #[polyquine_skip]
    /// Auxiliary-variable definition, printed as `reference =aux e` (boolean).
    AuxDeclaration(Metadata, Reference, Moo<Expression>),

    #[compatible(SAT)]
    /// SAT-encoded integer; its domain is `low..high` from the last field.
    SATInt(Metadata, SATIntEncoding, Moo<Expression>, (i32, i32)),

    #[compatible(SMT)]
    /// Binary sum of two expressions.
    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// Binary product of two expressions.
    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Function image `image(f, x)`; domain is `f`'s codomain.
    Image(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `imageSet(f, x)`; `domain_of` returns `f`'s codomain.
    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `preImage(f, x)`; `domain_of` returns `f`'s domain component.
    PreImage(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `inverse(f, g)` (boolean).
    Inverse(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `restrict(f, d)`: the function domain of `f` with its domain replaced by `d`'s domain.
    Restrict(Metadata, Moo<Expression>, Moo<Expression>),

    /// Presumably lexicographic `<` between two sequences (boolean).
    LexLt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Presumably lexicographic `<=` between two sequences (boolean).
    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Presumably lexicographic `>` between two sequences (boolean).
    LexGt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Presumably lexicographic `>=` between two sequences (boolean).
    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Flat form of `LexLt` over vectors of atoms.
    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),

    /// Flat form of `LexLeq` over vectors of atoms.
    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
}
570
571fn bounded_i32_domain_for_matrix_literal_monotonic(
578 e: &Expression,
579 op: fn(i32, i32) -> Option<i32>,
580) -> Option<DomainPtr> {
581 let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
583
584 let expr = exprs.pop()?;
600 let dom = expr.domain_of()?;
601 let resolved = dom.resolve()?;
602 let GroundDomain::Int(ranges) = resolved.as_ref() else {
603 return None;
604 };
605
606 let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
607
608 for expr in exprs {
609 let dom = expr.domain_of()?;
610 let resolved = dom.resolve()?;
611 let GroundDomain::Int(ranges) = resolved.as_ref() else {
612 return None;
613 };
614
615 let (min, max) = range_vec_bounds_i32(ranges)?;
616
617 let minmax = op(min, current_max)?;
619 let minmin = op(min, current_min)?;
620 let maxmin = op(max, current_min)?;
621 let maxmax = op(max, current_max)?;
622 let vals = [minmax, minmin, maxmin, maxmax];
623
624 current_min = *vals
625 .iter()
626 .min()
627 .expect("vals iterator should not be empty, and should have a minimum.");
628 current_max = *vals
629 .iter()
630 .max()
631 .expect("vals iterator should not be empty, and should have a maximum.");
632 }
633
634 if current_min == current_max {
635 Some(Domain::int(vec![Range::Single(current_min)]))
636 } else {
637 Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
638 }
639}
640
641fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
643 let mut min = i32::MAX;
644 let mut max = i32::MIN;
645 for r in ranges {
646 match r {
647 Range::Single(i) => {
648 if *i < min {
649 min = *i;
650 }
651 if *i > max {
652 max = *i;
653 }
654 }
655 Range::Bounded(i, j) => {
656 if *i < min {
657 min = *i;
658 }
659 if *j > max {
660 max = *j;
661 }
662 }
663 Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
664 }
665 }
666 Some((min, max))
667}
668
669impl Expression {
670 pub fn domain_of(&self) -> Option<DomainPtr> {
672 match self {
673 Expression::Union(_, a, b) => Some(Domain::set(
674 SetAttr::<IntVal>::default(),
675 a.domain_of()?.union(&b.domain_of()?).ok()?,
676 )),
677 Expression::Intersect(_, a, b) => Some(Domain::set(
678 SetAttr::<IntVal>::default(),
679 a.domain_of()?.intersect(&b.domain_of()?).ok()?,
680 )),
681 Expression::In(_, _, _) => Some(Domain::bool()),
682 Expression::Supset(_, _, _) => Some(Domain::bool()),
683 Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
684 Expression::Subset(_, _, _) => Some(Domain::bool()),
685 Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
686 Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
687 Expression::DominanceRelation(_, _) => Some(Domain::bool()),
688 Expression::FromSolution(_, expr) => Some(expr.domain_of()),
689 Expression::Metavar(_, _) => None,
690 Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
691 Expression::AbstractComprehension(_, comprehension) => comprehension.domain_of(),
692 Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
693 let dom = matrix.domain_of()?;
694 if let Some((elem_domain, _)) = dom.as_matrix() {
695 return Some(elem_domain);
696 }
697
698 #[allow(clippy::redundant_pattern_matching)]
700 if let Some(_) = dom.as_tuple() {
701 return None;
703 }
704
705 #[allow(clippy::redundant_pattern_matching)]
707 if let Some(_) = dom.as_record() {
708 return None;
710 }
711
712 bug!("subject of an index operation should support indexing")
713 }
714 Expression::UnsafeSlice(_, matrix, indices)
715 | Expression::SafeSlice(_, matrix, indices) => {
716 let sliced_dimension = indices.iter().position(Option::is_none);
717
718 let dom = matrix.domain_of()?;
719 let Some((elem_domain, index_domains)) = dom.as_matrix() else {
720 bug!("subject of an index operation should be a matrix");
721 };
722
723 match sliced_dimension {
724 Some(dimension) => Some(Domain::matrix(
725 elem_domain,
726 vec![index_domains[dimension].clone()],
727 )),
728
729 None => Some(elem_domain),
731 }
732 }
733 Expression::InDomain(_, _, _) => Some(Domain::bool()),
734 Expression::Atomic(_, atom) => Some(atom.domain_of()),
735 Expression::Sum(_, e) => {
736 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
737 }
738 Expression::Product(_, e) => {
739 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
740 }
741 Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
742 Some(if x < y { x } else { y })
743 }),
744 Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
745 Some(if x > y { x } else { y })
746 }),
747 Expression::UnsafeDiv(_, a, b) => a
748 .domain_of()?
749 .resolve()?
750 .apply_i32(
751 |x, y| {
754 if y != 0 {
755 Some((x as f32 / y as f32).floor() as i32)
756 } else {
757 None
758 }
759 },
760 b.domain_of()?.resolve()?.as_ref(),
761 )
762 .map(DomainPtr::from)
763 .ok(),
764 Expression::SafeDiv(_, a, b) => {
765 let domain = a
768 .domain_of()?
769 .resolve()?
770 .apply_i32(
771 |x, y| {
772 if y != 0 {
773 Some((x as f32 / y as f32).floor() as i32)
774 } else {
775 None
776 }
777 },
778 b.domain_of()?.resolve()?.as_ref(),
779 )
780 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
781
782 if let GroundDomain::Int(ranges) = domain {
783 let mut ranges = ranges;
784 ranges.push(Range::Single(0));
785 Some(Domain::int(ranges))
786 } else {
787 bug!("Domain of {self} was not integer")
788 }
789 }
790 Expression::UnsafeMod(_, a, b) => a
791 .domain_of()?
792 .resolve()?
793 .apply_i32(
794 |x, y| if y != 0 { Some(x % y) } else { None },
795 b.domain_of()?.resolve()?.as_ref(),
796 )
797 .map(DomainPtr::from)
798 .ok(),
799 Expression::SafeMod(_, a, b) => {
800 let domain = a
801 .domain_of()?
802 .resolve()?
803 .apply_i32(
804 |x, y| if y != 0 { Some(x % y) } else { None },
805 b.domain_of()?.resolve()?.as_ref(),
806 )
807 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
808
809 if let GroundDomain::Int(ranges) = domain {
810 let mut ranges = ranges;
811 ranges.push(Range::Single(0));
812 Some(Domain::int(ranges))
813 } else {
814 bug!("Domain of {self} was not integer")
815 }
816 }
817 Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
818 .domain_of()?
819 .resolve()?
820 .apply_i32(
821 |x, y| {
822 if (x != 0 || y != 0) && y >= 0 {
823 Some(x.pow(y as u32))
824 } else {
825 None
826 }
827 },
828 b.domain_of()?.resolve()?.as_ref(),
829 )
830 .map(DomainPtr::from)
831 .ok(),
832 Expression::Root(_, _) => None,
833 Expression::Bubble(_, inner, _) => inner.domain_of(),
834 Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
835 Expression::And(_, _) => Some(Domain::bool()),
836 Expression::Not(_, _) => Some(Domain::bool()),
837 Expression::Or(_, _) => Some(Domain::bool()),
838 Expression::Imply(_, _, _) => Some(Domain::bool()),
839 Expression::Iff(_, _, _) => Some(Domain::bool()),
840 Expression::Eq(_, _, _) => Some(Domain::bool()),
841 Expression::Neq(_, _, _) => Some(Domain::bool()),
842 Expression::Geq(_, _, _) => Some(Domain::bool()),
843 Expression::Leq(_, _, _) => Some(Domain::bool()),
844 Expression::Gt(_, _, _) => Some(Domain::bool()),
845 Expression::Lt(_, _, _) => Some(Domain::bool()),
846 Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
847 Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
848 Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
849 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
850 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
851 Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
852 Expression::Flatten(_, n, m) => {
853 if let Some(expr) = n {
854 if expr.return_type() == ReturnType::Int {
855 return None;
857 }
858 } else {
859 let dom = m.domain_of()?.resolve()?;
861 let (val_dom, idx_doms) = match dom.as_ref() {
862 GroundDomain::Matrix(val, idx) => (val, idx),
863 _ => return None,
864 };
865 let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
866
867 let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
868 return Some(Domain::matrix(
869 val_dom.clone().into(),
870 vec![new_index_domain],
871 ));
872 }
873 None
874 }
875 Expression::AllDiff(_, _) => Some(Domain::bool()),
876 Expression::Table(_, _, _) => Some(Domain::bool()),
877 Expression::NegativeTable(_, _, _) => Some(Domain::bool()),
878 Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
879 Expression::MinionReify(_, _, _) => Some(Domain::bool()),
880 Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
881 Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
882 Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
883 Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
884 Expression::Neg(_, x) => {
885 let dom = x.domain_of()?;
886 let mut ranges = dom.as_int()?;
887
888 ranges = ranges
889 .into_iter()
890 .map(|r| match r {
891 Range::Single(x) => Range::Single(-x),
892 Range::Bounded(x, y) => Range::Bounded(-y, -x),
893 Range::UnboundedR(i) => Range::UnboundedL(-i),
894 Range::UnboundedL(i) => Range::UnboundedR(-i),
895 Range::Unbounded => Range::Unbounded,
896 })
897 .collect();
898
899 Some(Domain::int(ranges))
900 }
901 Expression::Minus(_, a, b) => a
902 .domain_of()?
903 .resolve()?
904 .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
905 .map(DomainPtr::from)
906 .ok(),
907 Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
908 Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
909 Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
910 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
911 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
912 Expression::Abs(_, a) => a
913 .domain_of()?
914 .resolve()?
915 .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
916 .map(DomainPtr::from)
917 .ok(),
918 Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
919 Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
920 Expression::SATInt(_, _, _, (low, high)) => {
921 Some(Domain::int_ground(vec![Range::Bounded(*low, *high)]))
922 }
923 Expression::PairwiseSum(_, a, b) => a
924 .domain_of()?
925 .resolve()?
926 .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
927 .map(DomainPtr::from)
928 .ok(),
929 Expression::PairwiseProduct(_, a, b) => a
930 .domain_of()?
931 .resolve()?
932 .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
933 .map(DomainPtr::from)
934 .ok(),
935 Expression::Defined(_, function) => get_function_domain(function),
936 Expression::Range(_, function) => get_function_codomain(function),
937 Expression::Image(_, function, _) => get_function_codomain(function),
938 Expression::ImageSet(_, function, _) => get_function_codomain(function),
939 Expression::PreImage(_, function, _) => get_function_domain(function),
940 Expression::Restrict(_, function, new_domain) => {
941 let (attrs, _, codom) = function.domain_of()?.as_function()?;
942 let new_dom = new_domain.domain_of()?;
943 Some(Domain::function(attrs, new_dom, codom))
944 }
945 Expression::Inverse(..) => Some(Domain::bool()),
946 Expression::LexLt(..) => Some(Domain::bool()),
947 Expression::LexLeq(..) => Some(Domain::bool()),
948 Expression::LexGt(..) => Some(Domain::bool()),
949 Expression::LexGeq(..) => Some(Domain::bool()),
950 Expression::FlatLexLt(..) => Some(Domain::bool()),
951 Expression::FlatLexLeq(..) => Some(Domain::bool()),
952 }
953 }
954
955 pub fn get_meta(&self) -> Metadata {
956 let metas: VecDeque<Metadata> = self.children_bi();
957 metas[0].clone()
958 }
959
960 pub fn set_meta(&self, meta: Metadata) {
961 self.transform_bi(&|_| meta.clone());
962 }
963
964 pub fn is_safe(&self) -> bool {
971 for expr in self.universe() {
973 match expr {
974 Expression::UnsafeDiv(_, _, _)
975 | Expression::UnsafeMod(_, _, _)
976 | Expression::UnsafePow(_, _, _)
977 | Expression::UnsafeIndex(_, _, _)
978 | Expression::Bubble(_, _, _)
979 | Expression::UnsafeSlice(_, _, _) => {
980 return false;
981 }
982 _ => {}
983 }
984 }
985 true
986 }
987
988 pub fn is_clean(&self) -> bool {
989 let metadata = self.get_meta();
990 metadata.clean
991 }
992
993 pub fn set_clean(&mut self, bool_value: bool) {
994 let mut metadata = self.get_meta();
995 metadata.clean = bool_value;
996 self.set_meta(metadata);
997 }
998
999 pub fn is_associative_commutative_operator(&self) -> bool {
1001 TryInto::<ACOperatorKind>::try_into(self).is_ok()
1002 }
1003
1004 pub fn is_matrix_literal(&self) -> bool {
1009 matches!(
1010 self,
1011 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1012 | Expression::Atomic(
1013 _,
1014 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1015 )
1016 )
1017 }
1018
1019 pub fn identical_atom_to(&self, other: &Expression) -> bool {
1025 let atom1: Result<&Atom, _> = self.try_into();
1026 let atom2: Result<&Atom, _> = other.try_into();
1027
1028 if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1029 atom2 == atom1
1030 } else {
1031 false
1032 }
1033 }
1034
1035 pub fn unwrap_list(&self) -> Option<Vec<Expression>> {
1040 match self {
1041 Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1042 matrix.unwrap_list().cloned()
1043 }
1044 Expression::Atomic(
1045 _,
1046 Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1047 ) => matrix.unwrap_list().map(|elems| {
1048 elems
1049 .clone()
1050 .into_iter()
1051 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1052 .collect_vec()
1053 }),
1054 _ => None,
1055 }
1056 }
1057
1058 pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1066 match self {
1067 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1068 Some((elems, domain))
1069 }
1070 Expression::Atomic(
1071 _,
1072 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1073 ) => Some((
1074 elems
1075 .into_iter()
1076 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1077 .collect_vec(),
1078 domain.into(),
1079 )),
1080
1081 _ => None,
1082 }
1083 }
1084
1085 pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1090 match self {
1091 Expression::Root(meta, mut children) => {
1092 children.extend(exprs);
1093 Expression::Root(meta, children)
1094 }
1095 _ => panic!("extend_root called on a non-Root expression"),
1096 }
1097 }
1098
1099 pub fn into_literal(self) -> Option<Literal> {
1101 match self {
1102 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1103 Expression::AbstractLiteral(_, abslit) => {
1104 Some(Literal::AbstractLiteral(abslit.into_literals()?))
1105 }
1106 Expression::Neg(_, e) => {
1107 let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1108 bug!("negated literal should be an int");
1109 };
1110
1111 Some(Literal::Int(-i))
1112 }
1113
1114 _ => None,
1115 }
1116 }
1117
1118 pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1120 TryFrom::try_from(self).ok()
1121 }
1122
1123 pub fn universe_categories(&self) -> HashSet<Category> {
1125 self.universe()
1126 .into_iter()
1127 .map(|x| x.category_of())
1128 .collect()
1129 }
1130}
1131
1132pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1133 let function_domain = function.domain_of()?;
1134 match function_domain.resolve().as_ref() {
1135 Some(d) => {
1136 match d.as_ref() {
1137 GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1138 _ => None,
1140 }
1141 }
1142 None => {
1143 match function_domain.as_unresolved()? {
1144 UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1145 _ => None,
1147 }
1148 }
1149 }
1150}
1151
1152pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1153 let function_domain = function.domain_of()?;
1154 match function_domain.resolve().as_ref() {
1155 Some(d) => {
1156 match d.as_ref() {
1157 GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1158 _ => None,
1160 }
1161 }
1162 None => {
1163 match function_domain.as_unresolved()? {
1164 UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1165 _ => None,
1167 }
1168 }
1169 }
1170}
1171
1172impl TryFrom<&Expression> for i32 {
1173 type Error = ();
1174
1175 fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1176 let Expression::Atomic(_, atom) = value else {
1177 return Err(());
1178 };
1179
1180 let Atom::Literal(lit) = atom else {
1181 return Err(());
1182 };
1183
1184 let Literal::Int(i) = lit else {
1185 return Err(());
1186 };
1187
1188 Ok(*i)
1189 }
1190}
1191
1192impl TryFrom<Expression> for i32 {
1193 type Error = ();
1194
1195 fn try_from(value: Expression) -> Result<Self, Self::Error> {
1196 TryFrom::<&Expression>::try_from(&value)
1197 }
1198}
1199impl From<i32> for Expression {
1200 fn from(i: i32) -> Self {
1201 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1202 }
1203}
1204
1205impl From<bool> for Expression {
1206 fn from(b: bool) -> Self {
1207 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1208 }
1209}
1210
1211impl From<Atom> for Expression {
1212 fn from(value: Atom) -> Self {
1213 Expression::Atomic(Metadata::new(), value)
1214 }
1215}
1216
1217impl From<Literal> for Expression {
1218 fn from(value: Literal) -> Self {
1219 Expression::Atomic(Metadata::new(), value.into())
1220 }
1221}
1222
1223impl From<Moo<Expression>> for Expression {
1224 fn from(val: Moo<Expression>) -> Self {
1225 val.as_ref().clone()
1226 }
1227}
1228
1229impl CategoryOf for Expression {
1230 fn category_of(&self) -> Category {
1231 let category = self.cata(&move |x,children| {
1233
1234 if let Some(max_category) = children.iter().max() {
1235 *max_category
1238 } else {
1239 let mut max_category = Category::Bottom;
1241
1242 if !Biplate::<Model>::universe_bi(&x).is_empty() {
1249 return Category::Decision;
1251 }
1252
1253 if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1255 && max_atom_category > max_category{
1257 max_category = max_atom_category;
1259 }
1260
1261 if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1263 && max_declaration_category > max_category{
1265 max_category = max_declaration_category;
1267 }
1268 max_category
1269
1270 }
1271 });
1272
1273 if cfg!(debug_assertions) {
1274 trace!(
1275 category= %category,
1276 expression= %self,
1277 "Called Expression::category_of()"
1278 );
1279 };
1280 category
1281 }
1282}
1283
1284impl Display for Expression {
1285 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1286 match &self {
1287 Expression::Union(_, box1, box2) => {
1288 write!(f, "({} union {})", box1.clone(), box2.clone())
1289 }
1290 Expression::In(_, e1, e2) => {
1291 write!(f, "{e1} in {e2}")
1292 }
1293 Expression::Intersect(_, box1, box2) => {
1294 write!(f, "({} intersect {})", box1.clone(), box2.clone())
1295 }
1296 Expression::Supset(_, box1, box2) => {
1297 write!(f, "({} supset {})", box1.clone(), box2.clone())
1298 }
1299 Expression::SupsetEq(_, box1, box2) => {
1300 write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1301 }
1302 Expression::Subset(_, box1, box2) => {
1303 write!(f, "({} subset {})", box1.clone(), box2.clone())
1304 }
1305 Expression::SubsetEq(_, box1, box2) => {
1306 write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1307 }
1308
1309 Expression::AbstractLiteral(_, l) => l.fmt(f),
1310 Expression::Comprehension(_, c) => c.fmt(f),
1311 Expression::AbstractComprehension(_, c) => c.fmt(f),
1312 Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1313 write!(f, "{e1}{}", pretty_vec(e2))
1314 }
1315 Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1316 let args = es
1317 .iter()
1318 .map(|x| match x {
1319 Some(x) => format!("{x}"),
1320 None => "..".into(),
1321 })
1322 .join(",");
1323
1324 write!(f, "{e1}[{args}]")
1325 }
1326 Expression::InDomain(_, e, domain) => {
1327 write!(f, "__inDomain({e},{domain})")
1328 }
1329 Expression::Root(_, exprs) => {
1330 write!(f, "{}", pretty_expressions_as_top_level(exprs))
1331 }
1332 Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1333 Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1334 Expression::Metavar(_, name) => write!(f, "&{name}"),
1335 Expression::Atomic(_, atom) => atom.fmt(f),
1336 Expression::Abs(_, a) => write!(f, "|{a}|"),
1337 Expression::Sum(_, e) => {
1338 write!(f, "sum({e})")
1339 }
1340 Expression::Product(_, e) => {
1341 write!(f, "product({e})")
1342 }
1343 Expression::Min(_, e) => {
1344 write!(f, "min({e})")
1345 }
1346 Expression::Max(_, e) => {
1347 write!(f, "max({e})")
1348 }
1349 Expression::Not(_, expr_box) => {
1350 write!(f, "!({})", expr_box.clone())
1351 }
1352 Expression::Or(_, e) => {
1353 write!(f, "or({e})")
1354 }
1355 Expression::And(_, e) => {
1356 write!(f, "and({e})")
1357 }
1358 Expression::Imply(_, box1, box2) => {
1359 write!(f, "({box1}) -> ({box2})")
1360 }
1361 Expression::Iff(_, box1, box2) => {
1362 write!(f, "({box1}) <-> ({box2})")
1363 }
1364 Expression::Eq(_, box1, box2) => {
1365 write!(f, "({} = {})", box1.clone(), box2.clone())
1366 }
1367 Expression::Neq(_, box1, box2) => {
1368 write!(f, "({} != {})", box1.clone(), box2.clone())
1369 }
1370 Expression::Geq(_, box1, box2) => {
1371 write!(f, "({} >= {})", box1.clone(), box2.clone())
1372 }
1373 Expression::Leq(_, box1, box2) => {
1374 write!(f, "({} <= {})", box1.clone(), box2.clone())
1375 }
1376 Expression::Gt(_, box1, box2) => {
1377 write!(f, "({} > {})", box1.clone(), box2.clone())
1378 }
1379 Expression::Lt(_, box1, box2) => {
1380 write!(f, "({} < {})", box1.clone(), box2.clone())
1381 }
1382 Expression::FlatSumGeq(_, box1, box2) => {
1383 write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1384 }
1385 Expression::FlatSumLeq(_, box1, box2) => {
1386 write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1387 }
1388 Expression::FlatIneq(_, box1, box2, box3) => write!(
1389 f,
1390 "Ineq({}, {}, {})",
1391 box1.clone(),
1392 box2.clone(),
1393 box3.clone()
1394 ),
1395 Expression::Flatten(_, n, m) => {
1396 if let Some(n) = n {
1397 write!(f, "flatten({n}, {m})")
1398 } else {
1399 write!(f, "flatten({m})")
1400 }
1401 }
1402 Expression::AllDiff(_, e) => {
1403 write!(f, "allDiff({e})")
1404 }
1405 Expression::Table(_, tuple_expr, rows_expr) => {
1406 write!(f, "table({tuple_expr}, {rows_expr})")
1407 }
1408 Expression::NegativeTable(_, tuple_expr, rows_expr) => {
1409 write!(f, "negativeTable({tuple_expr}, {rows_expr})")
1410 }
1411 Expression::Bubble(_, box1, box2) => {
1412 write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1413 }
1414 Expression::SafeDiv(_, box1, box2) => {
1415 write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1416 }
1417 Expression::UnsafeDiv(_, box1, box2) => {
1418 write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1419 }
1420 Expression::UnsafePow(_, box1, box2) => {
1421 write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1422 }
1423 Expression::SafePow(_, box1, box2) => {
1424 write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1425 }
1426 Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1427 write!(
1428 f,
1429 "DivEq({}, {}, {})",
1430 box1.clone(),
1431 box2.clone(),
1432 box3.clone()
1433 )
1434 }
1435 Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1436 write!(
1437 f,
1438 "ModEq({}, {}, {})",
1439 box1.clone(),
1440 box2.clone(),
1441 box3.clone()
1442 )
1443 }
1444 Expression::FlatWatchedLiteral(_, x, l) => {
1445 write!(f, "WatchedLiteral({x},{l})")
1446 }
1447 Expression::MinionReify(_, box1, box2) => {
1448 write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1449 }
1450 Expression::MinionReifyImply(_, box1, box2) => {
1451 write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1452 }
1453 Expression::MinionWInIntervalSet(_, atom, intervals) => {
1454 let intervals = intervals.iter().join(",");
1455 write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1456 }
1457 Expression::MinionWInSet(_, atom, values) => {
1458 let values = values.iter().join(",");
1459 write!(f, "__minion_w_inset({atom},[{values}])")
1460 }
1461 Expression::AuxDeclaration(_, reference, e) => {
1462 write!(f, "{} =aux {}", reference, e.clone())
1463 }
1464 Expression::UnsafeMod(_, a, b) => {
1465 write!(f, "{} % {}", a.clone(), b.clone())
1466 }
1467 Expression::SafeMod(_, a, b) => {
1468 write!(f, "SafeMod({},{})", a.clone(), b.clone())
1469 }
1470 Expression::Neg(_, a) => {
1471 write!(f, "-({})", a.clone())
1472 }
1473 Expression::Minus(_, a, b) => {
1474 write!(f, "({} - {})", a.clone(), b.clone())
1475 }
1476 Expression::FlatAllDiff(_, es) => {
1477 write!(f, "__flat_alldiff({})", pretty_vec(es))
1478 }
1479 Expression::FlatAbsEq(_, a, b) => {
1480 write!(f, "AbsEq({},{})", a.clone(), b.clone())
1481 }
1482 Expression::FlatMinusEq(_, a, b) => {
1483 write!(f, "MinusEq({},{})", a.clone(), b.clone())
1484 }
1485 Expression::FlatProductEq(_, a, b, c) => {
1486 write!(
1487 f,
1488 "FlatProductEq({},{},{})",
1489 a.clone(),
1490 b.clone(),
1491 c.clone()
1492 )
1493 }
1494 Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1495 write!(
1496 f,
1497 "FlatWeightedSumLeq({},{},{})",
1498 pretty_vec(cs),
1499 pretty_vec(vs),
1500 total.clone()
1501 )
1502 }
1503 Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1504 write!(
1505 f,
1506 "FlatWeightedSumGeq({},{},{})",
1507 pretty_vec(cs),
1508 pretty_vec(vs),
1509 total.clone()
1510 )
1511 }
1512 Expression::MinionPow(_, atom, atom1, atom2) => {
1513 write!(f, "MinionPow({atom},{atom1},{atom2})")
1514 }
1515 Expression::MinionElementOne(_, atoms, atom, atom1) => {
1516 let atoms = atoms.iter().join(",");
1517 write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1518 }
1519
1520 Expression::ToInt(_, expr) => {
1521 write!(f, "toInt({expr})")
1522 }
1523
1524 Expression::SATInt(_, encoding, bits, (min, max)) => {
1525 write!(f, "SATInt({encoding:?}, {bits} [{min}, {max}])")
1526 }
1527
1528 Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1529 Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1530
1531 Expression::Defined(_, function) => write!(f, "defined({function})"),
1532 Expression::Range(_, function) => write!(f, "range({function})"),
1533 Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1534 Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1535 Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1536 Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1537 Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1538
1539 Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1540 Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1541 Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1542 Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1543 Expression::FlatLexLt(_, a, b) => {
1544 write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1545 }
1546 Expression::FlatLexLeq(_, a, b) => {
1547 write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1548 }
1549 }
1550 }
1551}
1552
1553impl Typeable for Expression {
1554 fn return_type(&self) -> ReturnType {
1555 match self {
1556 Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1557 Expression::Intersect(_, subject, _) => {
1558 ReturnType::Set(Box::new(subject.return_type()))
1559 }
1560 Expression::In(_, _, _) => ReturnType::Bool,
1561 Expression::Supset(_, _, _) => ReturnType::Bool,
1562 Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1563 Expression::Subset(_, _, _) => ReturnType::Bool,
1564 Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1565 Expression::AbstractLiteral(_, lit) => lit.return_type(),
1566 Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1567 let subject_ty = subject.return_type();
1568 match subject_ty {
1569 ReturnType::Matrix(_) => {
1570 let mut elem_typ = subject_ty;
1573 let mut idx_len = idx.len();
1574 while idx_len > 0
1575 && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1576 {
1577 elem_typ = *new_elem_typ.clone();
1578 idx_len -= 1;
1579 }
1580 elem_typ
1581 }
1582 ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1584 _ => bug!(
1585 "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1586 ),
1587 }
1588 }
1589 Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1590 ReturnType::Matrix(Box::new(subject.return_type()))
1591 }
1592 Expression::InDomain(_, _, _) => ReturnType::Bool,
1593 Expression::Comprehension(_, comp) => comp.return_type(),
1594 Expression::AbstractComprehension(_, comp) => comp.return_type(),
1595 Expression::Root(_, _) => ReturnType::Bool,
1596 Expression::DominanceRelation(_, _) => ReturnType::Bool,
1597 Expression::FromSolution(_, expr) => expr.return_type(),
1598 Expression::Metavar(_, _) => ReturnType::Unknown,
1599 Expression::Atomic(_, atom) => atom.return_type(),
1600 Expression::Abs(_, _) => ReturnType::Int,
1601 Expression::Sum(_, _) => ReturnType::Int,
1602 Expression::Product(_, _) => ReturnType::Int,
1603 Expression::Min(_, _) => ReturnType::Int,
1604 Expression::Max(_, _) => ReturnType::Int,
1605 Expression::Not(_, _) => ReturnType::Bool,
1606 Expression::Or(_, _) => ReturnType::Bool,
1607 Expression::Imply(_, _, _) => ReturnType::Bool,
1608 Expression::Iff(_, _, _) => ReturnType::Bool,
1609 Expression::And(_, _) => ReturnType::Bool,
1610 Expression::Eq(_, _, _) => ReturnType::Bool,
1611 Expression::Neq(_, _, _) => ReturnType::Bool,
1612 Expression::Geq(_, _, _) => ReturnType::Bool,
1613 Expression::Leq(_, _, _) => ReturnType::Bool,
1614 Expression::Gt(_, _, _) => ReturnType::Bool,
1615 Expression::Lt(_, _, _) => ReturnType::Bool,
1616 Expression::SafeDiv(_, _, _) => ReturnType::Int,
1617 Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1618 Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1619 Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1620 Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1621 Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1622 Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1623 Expression::Flatten(_, _, matrix) => {
1624 let matrix_type = matrix.return_type();
1625 match matrix_type {
1626 ReturnType::Matrix(_) => {
1627 let mut elem_type = matrix_type;
1629 while let ReturnType::Matrix(new_elem_type) = &elem_type {
1630 elem_type = *new_elem_type.clone();
1631 }
1632 ReturnType::Matrix(Box::new(elem_type))
1633 }
1634 _ => bug!(
1635 "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1636 ),
1637 }
1638 }
1639 Expression::AllDiff(_, _) => ReturnType::Bool,
1640 Expression::Table(_, _, _) => ReturnType::Bool,
1641 Expression::NegativeTable(_, _, _) => ReturnType::Bool,
1642 Expression::Bubble(_, inner, _) => inner.return_type(),
1643 Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1644 Expression::MinionReify(_, _, _) => ReturnType::Bool,
1645 Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1646 Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1647 Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1648 Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1649 Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1650 Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1651 Expression::SafeMod(_, _, _) => ReturnType::Int,
1652 Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1653 Expression::Neg(_, _) => ReturnType::Int,
1654 Expression::UnsafePow(_, _, _) => ReturnType::Int,
1655 Expression::SafePow(_, _, _) => ReturnType::Int,
1656 Expression::Minus(_, _, _) => ReturnType::Int,
1657 Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1658 Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1659 Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1660 Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1661 Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1662 Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1663 Expression::ToInt(_, _) => ReturnType::Int,
1664 Expression::SATInt(..) => ReturnType::Int,
1665 Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1666 Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1667 Expression::Defined(_, function) => {
1668 let subject = function.return_type();
1669 match subject {
1670 ReturnType::Function(domain, _) => *domain,
1671 _ => bug!(
1672 "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1673 ),
1674 }
1675 }
1676 Expression::Range(_, function) => {
1677 let subject = function.return_type();
1678 match subject {
1679 ReturnType::Function(_, codomain) => *codomain,
1680 _ => bug!(
1681 "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1682 ),
1683 }
1684 }
1685 Expression::Image(_, function, _) => {
1686 let subject = function.return_type();
1687 match subject {
1688 ReturnType::Function(_, codomain) => *codomain,
1689 _ => bug!(
1690 "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1691 ),
1692 }
1693 }
1694 Expression::ImageSet(_, function, _) => {
1695 let subject = function.return_type();
1696 match subject {
1697 ReturnType::Function(_, codomain) => *codomain,
1698 _ => bug!(
1699 "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1700 ),
1701 }
1702 }
1703 Expression::PreImage(_, function, _) => {
1704 let subject = function.return_type();
1705 match subject {
1706 ReturnType::Function(domain, _) => *domain,
1707 _ => bug!(
1708 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1709 ),
1710 }
1711 }
1712 Expression::Restrict(_, function, new_domain) => {
1713 let subject = function.return_type();
1714 match subject {
1715 ReturnType::Function(_, codomain) => {
1716 ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1717 }
1718 _ => bug!(
1719 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1720 ),
1721 }
1722 }
1723 Expression::Inverse(..) => ReturnType::Bool,
1724 Expression::LexLt(..) => ReturnType::Bool,
1725 Expression::LexGt(..) => ReturnType::Bool,
1726 Expression::LexLeq(..) => ReturnType::Bool,
1727 Expression::LexGeq(..) => ReturnType::Bool,
1728 Expression::FlatLexLt(..) => ReturnType::Bool,
1729 Expression::FlatLexLeq(..) => ReturnType::Bool,
1730 }
1731 }
1732}
1733
#[cfg(test)]
mod tests {
    use super::*;

    use crate::matrix_expr;

    /// Builds an `Expression::Atomic` holding the given literal, with fresh metadata.
    fn literal_expr(value: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(value))
    }

    /// `1 + 2` over constant integers should have the singleton domain `{3}`.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal_expr(Literal::Int(1));
        let two = literal_expr(Literal::Int(2));
        let total = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![one, two]));

        let expected = Domain::int(vec![Range::Single(3)]);
        assert_eq!(total.domain_of(), Some(expected));
    }

    /// Summing an integer with a boolean has no well-defined domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_operand = literal_expr(Literal::Int(1));
        let bool_operand = literal_expr(Literal::Bool(true));
        let mixed = Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr![int_operand, bool_operand]),
        );

        assert_eq!(mixed.domain_of(), None);
    }

    /// A sum with no operands has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let empty = Expression::Sum(Metadata::new(), Moo::new(matrix_expr![]));

        assert_eq!(empty.domain_of(), None);
    }
}