1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Display, Formatter};
3use tracing::trace;
4
5use crate::ast::ReturnType;
6use crate::ast::SetAttr;
7use crate::ast::literals::AbstractLiteral;
8use crate::ast::literals::Literal;
9use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
10use crate::ast::{Atom, DomainPtr};
11use crate::ast::{GroundDomain, Metadata, UnresolvedDomain};
12use crate::ast::{IntVal, Moo};
13use crate::ast::{Name, matrix};
14use crate::bug;
15use conjure_cp_enum_compatibility_macro::document_compatibility;
16use itertools::Itertools;
17use serde::{Deserialize, Serialize};
18use ustr::Ustr;
19
20use polyquine::Quine;
21use uniplate::{Biplate, Uniplate};
22
23use super::ac_operators::ACOperatorKind;
24use super::categories::{Category, CategoryOf};
25use super::comprehension::Comprehension;
26use super::domains::HasDomain as _;
27use super::records::RecordValue;
28use super::{DeclarationPtr, Domain, Range, Reference, SubModel, Typeable};
29
// Pin the in-memory size of `Expression`. An enum is as large as its largest variant, so
// adding a big inline field to any variant makes this assertion fail at compile time
// (the size must be exactly 104 bytes).
static_assertions::assert_eq_size!([u8; 104], Expression);
53
#[document_compatibility]
/// An expression in a Conjure model.
///
/// Every variant stores its [`Metadata`] as the first field. The variants range from
/// high-level operators (`Sum`, `Eq`, `Union`, ...) to solver-specific constraint forms
/// (`Flat*`, `Minion*`, `SATInt`, `Pairwise*`).
///
/// `#[compatible(...)]` lists the targets (`JsonInput`, `Minion`, `SAT`, `SMT`) a variant is
/// marked compatible with; `#[document_compatibility]` presumably turns these into generated
/// documentation. `#[polyquine_skip]` / `#[polyquine_with]` customise the derived `Quine`
/// implementation for individual variants.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Uniplate, Quine)]
#[biplate(to=Metadata)]
#[biplate(to=Atom)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=Name)]
#[biplate(to=Reference)]
#[biplate(to=Vec<Expression>)]
#[biplate(to=Option<Expression>)]
#[biplate(to=SubModel)]
#[biplate(to=Comprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Literal)]
#[biplate(to=DomainPtr)]
#[path_prefix(conjure_cp::ast)]
pub enum Expression {
    /// An abstract literal (e.g. a matrix literal) whose elements are expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// The root of a model's constraints: a list of top-level expressions
    /// (see `Expression::extend_root`). Has no domain.
    Root(Metadata, Vec<Expression>),

    /// `{inner @ condition}`: an expression paired with a second expression, presumably a
    /// condition under which `inner` is defined. Its domain is that of `inner`; it is
    /// reported as unsafe by `Expression::is_safe`.
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    #[polyquine_skip]
    /// A comprehension; its domain is delegated to `Comprehension::domain_of`.
    Comprehension(Metadata, Moo<Comprehension>),

    /// `DominanceRelation(e)`; boolean-valued.
    DominanceRelation(Metadata, Moo<Expression>),
    /// `FromSolution(atom)`; has the domain of the wrapped atom.
    FromSolution(Metadata, Moo<Atom>),

    #[polyquine_with(arm = (_, name) => {
        let ident = proc_macro2::Ident::new(name.as_str(), proc_macro2::Span::call_site());
        quote::quote! { #ident.clone().into() }
    })]
    /// `&name`: a named metavariable. The derived `Quine` emits `name.clone().into()`,
    /// referring to a Rust identifier called `name`. Has no domain.
    Metavar(Metadata, Ustr),

    /// A single atom (e.g. a literal).
    Atomic(Metadata, Atom),

    #[compatible(JsonInput)]
    /// `m[i, j, ...]`: indexing that may be undefined; reported as unsafe by `is_safe`.
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    #[compatible(SMT)]
    /// Indexing treated as always defined; printed like `UnsafeIndex`.
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    #[compatible(JsonInput)]
    /// `m[i, .., k]`: slicing; a `None` index is printed as `..` and marks the kept
    /// dimension. Reported as unsafe by `is_safe`.
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Slicing treated as always defined; printed like `UnsafeSlice`.
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// `__inDomain(e, domain)`; boolean-valued.
    InDomain(Metadata, Moo<Expression>, DomainPtr),

    #[compatible(SMT)]
    /// Integer conversion with domain `int(0..1)` (presumably boolean to 0/1).
    ToInt(Metadata, Moo<Expression>),

    #[polyquine_skip]
    /// A nested sub-model, printed in braces; boolean-valued.
    Scope(Metadata, Moo<SubModel>),

    #[compatible(JsonInput, SMT)]
    /// `|e|`: absolute value.
    Abs(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `sum(e)` over the elements of the list/matrix `e`.
    Sum(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `product(e)` over the elements of the list/matrix `e`.
    Product(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `min(e)` over the elements of the list/matrix `e`.
    Min(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `max(e)` over the elements of the list/matrix `e`.
    Max(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// `!(e)`: boolean negation.
    Not(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// `or(e)` over the elements of the list `e`.
    Or(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SAT, SMT)]
    /// `and(e)` over the elements of the list `e`.
    And(Metadata, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a) -> (b)`.
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a) <-> (b)`.
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `(a union b)`; its domain is a set over the union of the operands' domains.
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `a in b`; boolean-valued.
    In(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `(a intersect b)`; its domain is a set over the intersection of the operands' domains.
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `(a supset b)`; boolean-valued.
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `(a supsetEq b)`; boolean-valued.
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `(a subset b)`; boolean-valued.
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `(a subsetEq b)`; boolean-valued.
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a = b)`.
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a != b)`.
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a >= b)`.
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a <= b)`.
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a > b)`.
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `(a < b)`.
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// `SafeDiv(a, b)`: floor division treated as always defined; its domain always
    /// includes 0.
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `UnsafeDiv(a, b)`: floor division, undefined when `b` is 0; reported as unsafe.
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// `SafeMod(a, b)`: remainder treated as always defined; its domain always includes 0.
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `a % b`: remainder, undefined when `b` is 0; reported as unsafe.
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput, SMT)]
    /// `-(e)`: integer negation.
    Neg(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Applied to a function; has the function's domain (see `get_function_domain`).
    Defined(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Applied to a function; has the function's codomain (see `get_function_codomain`).
    Range(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `UnsafePow(a, b)`: exponentiation, undefined for `0^0` and negative exponents;
    /// reported as unsafe.
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// `SafePow(a, b)`: exponentiation treated as always defined.
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// `flatten(m)` or `flatten(n, m)`; the optional first argument is a depth expression.
    Flatten(Metadata, Option<Moo<Expression>>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `allDiff(e)` over the elements of the list `e`; boolean-valued.
    AllDiff(Metadata, Moo<Expression>),

    #[compatible(JsonInput)]
    /// `(a - b)`: binary subtraction.
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(Minion)]
    /// Minion absolute-value constraint over two atoms, printed `AbsEq(a,b)`.
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Minion all-different constraint over atoms, printed `__flat_alldiff([...])`.
    FlatAllDiff(Metadata, Vec<Atom>),

    #[compatible(Minion)]
    /// `SumGeq(xs, t)`: presumably `sum(xs) >= t`.
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// `SumLeq(xs, t)`: presumably `sum(xs) <= t`.
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    #[compatible(Minion)]
    /// `Ineq(x, y, k)`: presumably Minion's `ineq` constraint (`x <= y + k`).
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    #[compatible(Minion)]
    #[polyquine_skip]
    /// `WatchedLiteral(x, l)`: presumably constrains the variable `x` to the literal `l`.
    FlatWatchedLiteral(Metadata, Reference, Literal),

    /// Weighted sum with coefficients, variables and a total; presumably
    /// `sum(c_i * x_i) <= total`.
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Weighted sum with coefficients, variables and a total; presumably
    /// `sum(c_i * x_i) >= total`.
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// Minion minus constraint over two atoms, printed `MinusEq(a,b)`.
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// `FlatProductEq(a,b,c)`: presumably `a * b = c`.
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// `DivEq(a, b, c)`: presumably `a / b = c`, with division by zero yielding 0.
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// `ModEq(a, b, c)`: presumably `a % b = c`, with modulo by zero yielding 0.
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Minion power constraint over three atoms (presumably `a ^ b = c`); boolean-valued.
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    /// `Reify(c, r)`: presumably `r <-> c`.
    MinionReify(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// `ReifyImply(c, r)`: presumably `r -> c`.
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    #[compatible(Minion)]
    /// `__minion_w_inintervalset(x,[...])`: the integers are presumably interval bounds
    /// given in pairs.
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// `__minion_w_inset(x,...)`: membership of `x` in an explicit set of integers.
    MinionWInSet(Metadata, Atom, Vec<i32>),

    #[compatible(Minion)]
    /// Minion element constraint over a vector of atoms, an index and a value
    /// (presumably Minion's 1-indexed `element_one`).
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    #[compatible(Minion)]
    #[polyquine_skip]
    /// `r =aux e`: declares the (presumably auxiliary) variable `r` as equal to `e`.
    AuxDeclaration(Metadata, Reference, Moo<Expression>),

    #[compatible(SAT)]
    /// An integer in SAT encoding; its domain is fixed to `int(-128..127)` (the `i8` range).
    SATInt(Metadata, Moo<Expression>),

    #[compatible(SMT)]
    /// Binary sum; its domain is computed by adding every pair of operand values.
    PairwiseSum(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(SMT)]
    /// Binary product; its domain is computed by multiplying every pair of operand values.
    PairwiseProduct(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Image of a value under a function; has the function's codomain.
    Image(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Image set under a function; its domain is taken to be the function's codomain.
    ImageSet(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Pre-image under a function; its domain is taken to be the function's domain.
    PreImage(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Inverse relationship between two functions; boolean-valued.
    Inverse(Metadata, Moo<Expression>, Moo<Expression>),

    #[compatible(JsonInput)]
    /// Restriction of a function to the domain of the second operand.
    Restrict(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographic `<`; boolean-valued.
    LexLt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographic `<=`; boolean-valued.
    LexLeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographic `>`; boolean-valued.
    LexGt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographic `>=`; boolean-valued.
    LexGeq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Lexicographic `<` over two vectors of atoms; boolean-valued.
    FlatLexLt(Metadata, Vec<Atom>, Vec<Atom>),

    /// Lexicographic `<=` over two vectors of atoms; boolean-valued.
    FlatLexLeq(Metadata, Vec<Atom>, Vec<Atom>),
}
555
556fn bounded_i32_domain_for_matrix_literal_monotonic(
563 e: &Expression,
564 op: fn(i32, i32) -> Option<i32>,
565) -> Option<DomainPtr> {
566 let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
568
569 let expr = exprs.pop()?;
585 let dom = expr.domain_of()?;
586 let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
587 return None;
588 };
589
590 let (mut current_min, mut current_max) = range_vec_bounds_i32(ranges)?;
591
592 for expr in exprs {
593 let dom = expr.domain_of()?;
594 let Some(GroundDomain::Int(ranges)) = dom.as_ground() else {
595 return None;
596 };
597
598 let (min, max) = range_vec_bounds_i32(ranges)?;
599
600 let minmax = op(min, current_max)?;
602 let minmin = op(min, current_min)?;
603 let maxmin = op(max, current_min)?;
604 let maxmax = op(max, current_max)?;
605 let vals = [minmax, minmin, maxmin, maxmax];
606
607 current_min = *vals
608 .iter()
609 .min()
610 .expect("vals iterator should not be empty, and should have a minimum.");
611 current_max = *vals
612 .iter()
613 .max()
614 .expect("vals iterator should not be empty, and should have a maximum.");
615 }
616
617 if current_min == current_max {
618 Some(Domain::int(vec![Range::Single(current_min)]))
619 } else {
620 Some(Domain::int(vec![Range::Bounded(current_min, current_max)]))
621 }
622}
623
624fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
626 let mut min = i32::MAX;
627 let mut max = i32::MIN;
628 for r in ranges {
629 match r {
630 Range::Single(i) => {
631 if *i < min {
632 min = *i;
633 }
634 if *i > max {
635 max = *i;
636 }
637 }
638 Range::Bounded(i, j) => {
639 if *i < min {
640 min = *i;
641 }
642 if *j > max {
643 max = *j;
644 }
645 }
646 Range::UnboundedR(_) | Range::UnboundedL(_) | Range::Unbounded => return None,
647 }
648 }
649 Some((min, max))
650}
651
652impl Expression {
653 pub fn domain_of(&self) -> Option<DomainPtr> {
655 let ret = match self {
657 Expression::Union(_, a, b) => Some(Domain::set(
658 SetAttr::<IntVal>::default(),
659 a.domain_of()?.union(&b.domain_of()?).ok()?,
660 )),
661 Expression::Intersect(_, a, b) => Some(Domain::set(
662 SetAttr::<IntVal>::default(),
663 a.domain_of()?.intersect(&b.domain_of()?).ok()?,
664 )),
665 Expression::In(_, _, _) => Some(Domain::bool()),
666 Expression::Supset(_, _, _) => Some(Domain::bool()),
667 Expression::SupsetEq(_, _, _) => Some(Domain::bool()),
668 Expression::Subset(_, _, _) => Some(Domain::bool()),
669 Expression::SubsetEq(_, _, _) => Some(Domain::bool()),
670 Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
671 Expression::DominanceRelation(_, _) => Some(Domain::bool()),
672 Expression::FromSolution(_, expr) => Some(expr.domain_of()),
673 Expression::Metavar(_, _) => None,
674 Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
675 Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
676 let dom = matrix.domain_of()?;
677 if let Some((elem_domain, _)) = dom.as_matrix() {
678 return Some(elem_domain);
679 }
680
681 #[allow(clippy::redundant_pattern_matching)]
683 if let Some(_) = dom.as_tuple() {
684 return None;
686 }
687
688 #[allow(clippy::redundant_pattern_matching)]
690 if let Some(_) = dom.as_record() {
691 return None;
693 }
694
695 bug!("subject of an index operation should support indexing")
696 }
697 Expression::UnsafeSlice(_, matrix, indices)
698 | Expression::SafeSlice(_, matrix, indices) => {
699 let sliced_dimension = indices.iter().position(Option::is_none);
700
701 let dom = matrix.domain_of()?;
702 let Some((elem_domain, index_domains)) = dom.as_matrix() else {
703 bug!("subject of an index operation should be a matrix");
704 };
705
706 match sliced_dimension {
707 Some(dimension) => Some(Domain::matrix(
708 elem_domain,
709 vec![index_domains[dimension].clone()],
710 )),
711
712 None => Some(elem_domain),
714 }
715 }
716 Expression::InDomain(_, _, _) => Some(Domain::bool()),
717 Expression::Atomic(_, atom) => Some(atom.domain_of()),
718 Expression::Scope(_, _) => Some(Domain::bool()),
719 Expression::Sum(_, e) => {
720 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
721 }
722 Expression::Product(_, e) => {
723 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
724 }
725 Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
726 Some(if x < y { x } else { y })
727 }),
728 Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
729 Some(if x > y { x } else { y })
730 }),
731 Expression::UnsafeDiv(_, a, b) => a
732 .domain_of()?
733 .resolve()?
734 .apply_i32(
735 |x, y| {
738 if y != 0 {
739 Some((x as f32 / y as f32).floor() as i32)
740 } else {
741 None
742 }
743 },
744 b.domain_of()?.resolve()?.as_ref(),
745 )
746 .map(DomainPtr::from)
747 .ok(),
748 Expression::SafeDiv(_, a, b) => {
749 let domain = a
752 .domain_of()?
753 .resolve()?
754 .apply_i32(
755 |x, y| {
756 if y != 0 {
757 Some((x as f32 / y as f32).floor() as i32)
758 } else {
759 None
760 }
761 },
762 b.domain_of()?.resolve()?.as_ref(),
763 )
764 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
765
766 if let GroundDomain::Int(ranges) = domain {
767 let mut ranges = ranges;
768 ranges.push(Range::Single(0));
769 return Some(Domain::int(ranges));
770 } else {
771 bug!("Domain of {self} was not integer")
772 }
773 }
774 Expression::UnsafeMod(_, a, b) => a
775 .domain_of()?
776 .resolve()?
777 .apply_i32(
778 |x, y| if y != 0 { Some(x % y) } else { None },
779 b.domain_of()?.resolve()?.as_ref(),
780 )
781 .map(DomainPtr::from)
782 .ok(),
783 Expression::SafeMod(_, a, b) => {
784 let domain = a
785 .domain_of()?
786 .resolve()?
787 .apply_i32(
788 |x, y| if y != 0 { Some(x % y) } else { None },
789 b.domain_of()?.resolve()?.as_ref(),
790 )
791 .unwrap_or_else(|err| bug!("Got {err} when computing domain of {self}"));
792
793 if let GroundDomain::Int(ranges) = domain {
794 let mut ranges = ranges;
795 ranges.push(Range::Single(0));
796 return Some(Domain::int(ranges));
797 } else {
798 bug!("Domain of {self} was not integer")
799 }
800 }
801 Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
802 .domain_of()?
803 .resolve()?
804 .apply_i32(
805 |x, y| {
806 if (x != 0 || y != 0) && y >= 0 {
807 Some(x.pow(y as u32))
808 } else {
809 None
810 }
811 },
812 b.domain_of()?.resolve()?.as_ref(),
813 )
814 .map(DomainPtr::from)
815 .ok(),
816 Expression::Root(_, _) => None,
817 Expression::Bubble(_, inner, _) => inner.domain_of(),
818 Expression::AuxDeclaration(_, _, _) => Some(Domain::bool()),
819 Expression::And(_, _) => Some(Domain::bool()),
820 Expression::Not(_, _) => Some(Domain::bool()),
821 Expression::Or(_, _) => Some(Domain::bool()),
822 Expression::Imply(_, _, _) => Some(Domain::bool()),
823 Expression::Iff(_, _, _) => Some(Domain::bool()),
824 Expression::Eq(_, _, _) => Some(Domain::bool()),
825 Expression::Neq(_, _, _) => Some(Domain::bool()),
826 Expression::Geq(_, _, _) => Some(Domain::bool()),
827 Expression::Leq(_, _, _) => Some(Domain::bool()),
828 Expression::Gt(_, _, _) => Some(Domain::bool()),
829 Expression::Lt(_, _, _) => Some(Domain::bool()),
830 Expression::FlatAbsEq(_, _, _) => Some(Domain::bool()),
831 Expression::FlatSumGeq(_, _, _) => Some(Domain::bool()),
832 Expression::FlatSumLeq(_, _, _) => Some(Domain::bool()),
833 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::bool()),
834 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::bool()),
835 Expression::FlatIneq(_, _, _, _) => Some(Domain::bool()),
836 Expression::Flatten(_, n, m) => {
837 if let Some(expr) = n {
838 if expr.return_type() == ReturnType::Int {
839 return None;
841 }
842 } else {
843 let dom = m.domain_of()?.resolve()?;
845 let (val_dom, idx_doms) = match dom.as_ref() {
846 GroundDomain::Matrix(val, idx) => (val, idx),
847 _ => return None,
848 };
849 let num_elems = matrix::num_elements(idx_doms).ok()? as i32;
850
851 let new_index_domain = Domain::int(vec![Range::Bounded(1, num_elems)]);
852 return Some(Domain::matrix(
853 val_dom.clone().into(),
854 vec![new_index_domain],
855 ));
856 }
857 None
858 }
859 Expression::AllDiff(_, _) => Some(Domain::bool()),
860 Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::bool()),
861 Expression::MinionReify(_, _, _) => Some(Domain::bool()),
862 Expression::MinionReifyImply(_, _, _) => Some(Domain::bool()),
863 Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::bool()),
864 Expression::MinionWInSet(_, _, _) => Some(Domain::bool()),
865 Expression::MinionElementOne(_, _, _, _) => Some(Domain::bool()),
866 Expression::Neg(_, x) => {
867 let dom = x.domain_of()?;
868 let mut ranges = dom.as_int()?;
869
870 ranges = ranges
871 .into_iter()
872 .map(|r| match r {
873 Range::Single(x) => Range::Single(-x),
874 Range::Bounded(x, y) => Range::Bounded(-y, -x),
875 Range::UnboundedR(i) => Range::UnboundedL(-i),
876 Range::UnboundedL(i) => Range::UnboundedR(-i),
877 Range::Unbounded => Range::Unbounded,
878 })
879 .collect();
880
881 Some(Domain::int(ranges))
882 }
883 Expression::Minus(_, a, b) => a
884 .domain_of()?
885 .resolve()?
886 .apply_i32(|x, y| Some(x - y), b.domain_of()?.resolve()?.as_ref())
887 .map(DomainPtr::from)
888 .ok(),
889 Expression::FlatAllDiff(_, _) => Some(Domain::bool()),
890 Expression::FlatMinusEq(_, _, _) => Some(Domain::bool()),
891 Expression::FlatProductEq(_, _, _, _) => Some(Domain::bool()),
892 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::bool()),
893 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::bool()),
894 Expression::Abs(_, a) => a
895 .domain_of()?
896 .resolve()?
897 .apply_i32(|a, _| Some(a.abs()), a.domain_of()?.resolve()?.as_ref())
898 .map(DomainPtr::from)
899 .ok(),
900 Expression::MinionPow(_, _, _, _) => Some(Domain::bool()),
901 Expression::ToInt(_, _) => Some(Domain::int(vec![Range::Bounded(0, 1)])),
902 Expression::SATInt(_, _) => {
903 Some(Domain::int_ground(vec![Range::Bounded(
904 i8::MIN.into(),
905 i8::MAX.into(),
906 )])) } Expression::PairwiseSum(_, a, b) => a
912 .domain_of()?
913 .resolve()?
914 .apply_i32(|a, b| Some(a + b), b.domain_of()?.resolve()?.as_ref())
915 .map(DomainPtr::from)
916 .ok(),
917 Expression::PairwiseProduct(_, a, b) => a
918 .domain_of()?
919 .resolve()?
920 .apply_i32(|a, b| Some(a * b), b.domain_of()?.resolve()?.as_ref())
921 .map(DomainPtr::from)
922 .ok(),
923 Expression::Defined(_, function) => get_function_domain(function),
924 Expression::Range(_, function) => get_function_codomain(function),
925 Expression::Image(_, function, _) => get_function_codomain(function),
926 Expression::ImageSet(_, function, _) => get_function_codomain(function),
927 Expression::PreImage(_, function, _) => get_function_domain(function),
928 Expression::Restrict(_, function, new_domain) => {
929 let function_domain = function.domain_of()?;
930 match function_domain.resolve().as_ref() {
931 Some(d) => {
932 match d.as_ref() {
933 GroundDomain::Function(attrs, _, codomain) => Some(Domain::function(
934 attrs.clone(),
935 new_domain.domain_of()?,
936 codomain.clone().into(),
937 )),
938 _ => None,
940 }
941 }
942 None => {
943 match function_domain.as_unresolved()? {
944 UnresolvedDomain::Function(attrs, _, codomain) => {
945 Some(Domain::function(
946 attrs.clone(),
947 new_domain.domain_of()?,
948 codomain.clone(),
949 ))
950 }
951 _ => None,
953 }
954 }
955 }
956 }
957 Expression::Inverse(..) => Some(Domain::bool()),
958 Expression::LexLt(..) => Some(Domain::bool()),
959 Expression::LexLeq(..) => Some(Domain::bool()),
960 Expression::LexGt(..) => Some(Domain::bool()),
961 Expression::LexGeq(..) => Some(Domain::bool()),
962 Expression::FlatLexLt(..) => Some(Domain::bool()),
963 Expression::FlatLexLeq(..) => Some(Domain::bool()),
964 };
965 if let Some(dom) = &ret
966 && let Some(ranges) = dom.as_int_ground()
967 && ranges.len() > 1
968 {
969 let (min, max) = range_vec_bounds_i32(ranges)?;
972 return Some(Domain::int(vec![Range::Bounded(min, max)]));
973 }
974 ret
975 }
976
977 pub fn get_meta(&self) -> Metadata {
978 let metas: VecDeque<Metadata> = self.children_bi();
979 metas[0].clone()
980 }
981
982 pub fn set_meta(&self, meta: Metadata) {
983 self.transform_bi(&|_| meta.clone());
984 }
985
986 pub fn is_safe(&self) -> bool {
993 for expr in self.universe() {
995 match expr {
996 Expression::UnsafeDiv(_, _, _)
997 | Expression::UnsafeMod(_, _, _)
998 | Expression::UnsafePow(_, _, _)
999 | Expression::UnsafeIndex(_, _, _)
1000 | Expression::Bubble(_, _, _)
1001 | Expression::UnsafeSlice(_, _, _) => {
1002 return false;
1003 }
1004 _ => {}
1005 }
1006 }
1007 true
1008 }
1009
1010 pub fn is_clean(&self) -> bool {
1011 let metadata = self.get_meta();
1012 metadata.clean
1013 }
1014
1015 pub fn set_clean(&mut self, bool_value: bool) {
1016 let mut metadata = self.get_meta();
1017 metadata.clean = bool_value;
1018 self.set_meta(metadata);
1019 }
1020
1021 pub fn is_associative_commutative_operator(&self) -> bool {
1023 TryInto::<ACOperatorKind>::try_into(self).is_ok()
1024 }
1025
1026 pub fn is_matrix_literal(&self) -> bool {
1031 matches!(
1032 self,
1033 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
1034 | Expression::Atomic(
1035 _,
1036 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
1037 )
1038 )
1039 }
1040
1041 pub fn identical_atom_to(&self, other: &Expression) -> bool {
1047 let atom1: Result<&Atom, _> = self.try_into();
1048 let atom2: Result<&Atom, _> = other.try_into();
1049
1050 if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
1051 atom2 == atom1
1052 } else {
1053 false
1054 }
1055 }
1056
1057 pub fn unwrap_list(self) -> Option<Vec<Expression>> {
1062 match self {
1063 Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
1064 matrix.unwrap_list().cloned()
1065 }
1066 Expression::Atomic(
1067 _,
1068 Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
1069 ) => matrix.unwrap_list().map(|elems| {
1070 elems
1071 .clone()
1072 .into_iter()
1073 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1074 .collect_vec()
1075 }),
1076 _ => None,
1077 }
1078 }
1079
1080 pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, DomainPtr)> {
1088 match self {
1089 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
1090 Some((elems, domain))
1091 }
1092 Expression::Atomic(
1093 _,
1094 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
1095 ) => Some((
1096 elems
1097 .into_iter()
1098 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
1099 .collect_vec(),
1100 domain.into(),
1101 )),
1102
1103 _ => None,
1104 }
1105 }
1106
1107 pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
1112 match self {
1113 Expression::Root(meta, mut children) => {
1114 children.extend(exprs);
1115 Expression::Root(meta, children)
1116 }
1117 _ => panic!("extend_root called on a non-Root expression"),
1118 }
1119 }
1120
1121 pub fn into_literal(self) -> Option<Literal> {
1123 match self {
1124 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
1125 Expression::AbstractLiteral(_, abslit) => {
1126 Some(Literal::AbstractLiteral(abslit.into_literals()?))
1127 }
1128 Expression::Neg(_, e) => {
1129 let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
1130 bug!("negated literal should be an int");
1131 };
1132
1133 Some(Literal::Int(-i))
1134 }
1135
1136 _ => None,
1137 }
1138 }
1139
1140 pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
1142 TryFrom::try_from(self).ok()
1143 }
1144
1145 pub fn universe_categories(&self) -> HashSet<Category> {
1147 self.universe()
1148 .into_iter()
1149 .map(|x| x.category_of())
1150 .collect()
1151 }
1152}
1153
1154pub fn get_function_domain(function: &Moo<Expression>) -> Option<DomainPtr> {
1155 let function_domain = function.domain_of()?;
1156 match function_domain.resolve().as_ref() {
1157 Some(d) => {
1158 match d.as_ref() {
1159 GroundDomain::Function(_, domain, _) => Some(domain.clone().into()),
1160 _ => None,
1162 }
1163 }
1164 None => {
1165 match function_domain.as_unresolved()? {
1166 UnresolvedDomain::Function(_, domain, _) => Some(domain.clone()),
1167 _ => None,
1169 }
1170 }
1171 }
1172}
1173
1174pub fn get_function_codomain(function: &Moo<Expression>) -> Option<DomainPtr> {
1175 let function_domain = function.domain_of()?;
1176 match function_domain.resolve().as_ref() {
1177 Some(d) => {
1178 match d.as_ref() {
1179 GroundDomain::Function(_, _, codomain) => Some(codomain.clone().into()),
1180 _ => None,
1182 }
1183 }
1184 None => {
1185 match function_domain.as_unresolved()? {
1186 UnresolvedDomain::Function(_, _, codomain) => Some(codomain.clone()),
1187 _ => None,
1189 }
1190 }
1191 }
1192}
1193
1194impl TryFrom<&Expression> for i32 {
1195 type Error = ();
1196
1197 fn try_from(value: &Expression) -> Result<Self, Self::Error> {
1198 let Expression::Atomic(_, atom) = value else {
1199 return Err(());
1200 };
1201
1202 let Atom::Literal(lit) = atom else {
1203 return Err(());
1204 };
1205
1206 let Literal::Int(i) = lit else {
1207 return Err(());
1208 };
1209
1210 Ok(*i)
1211 }
1212}
1213
1214impl TryFrom<Expression> for i32 {
1215 type Error = ();
1216
1217 fn try_from(value: Expression) -> Result<Self, Self::Error> {
1218 TryFrom::<&Expression>::try_from(&value)
1219 }
1220}
1221impl From<i32> for Expression {
1222 fn from(i: i32) -> Self {
1223 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
1224 }
1225}
1226
1227impl From<bool> for Expression {
1228 fn from(b: bool) -> Self {
1229 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
1230 }
1231}
1232
1233impl From<Atom> for Expression {
1234 fn from(value: Atom) -> Self {
1235 Expression::Atomic(Metadata::new(), value)
1236 }
1237}
1238
1239impl From<Literal> for Expression {
1240 fn from(value: Literal) -> Self {
1241 Expression::Atomic(Metadata::new(), value.into())
1242 }
1243}
1244
1245impl From<Moo<Expression>> for Expression {
1246 fn from(val: Moo<Expression>) -> Self {
1247 val.as_ref().clone()
1248 }
1249}
1250
impl CategoryOf for Expression {
    /// Computes the category of this expression with a bottom-up fold (`cata`).
    ///
    /// A node with expression children takes the maximum of its children's categories.
    /// A leaf node is `Category::Decision` if it contains any `SubModel`; otherwise it is
    /// the maximum of `Category::Bottom` and the categories of all atoms and declaration
    /// pointers it contains.
    ///
    /// NOTE(review): in the non-leaf branch `x` itself is not inspected, so atoms or
    /// references stored directly in a node that also has expression children (e.g. the
    /// `Atom` of `MinionReify`, the `Reference` of `AuxDeclaration`) appear not to
    /// contribute to its category. Confirm this is intended.
    fn category_of(&self) -> Category {
        let category = self.cata(&move |x,children| {

            // Non-leaf: the category is the most restrictive category among the children.
            if let Some(max_category) = children.iter().max() {
                *max_category
            } else {
                // Leaf: start from the lowest category and raise it as needed.
                let mut max_category = Category::Bottom;

                // Anything containing a sub-model is treated as a decision expression.
                if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
                    return Category::Decision;
                }

                // Raise to the highest category among the contained atoms.
                if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
                    && max_atom_category > max_category{
                    max_category = max_atom_category;
                }

                // Raise to the highest category among the referenced declarations.
                if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
                    && max_declaration_category > max_category{
                    max_category = max_declaration_category;
                }
                max_category

            }
        });

        // Trace only in debug builds: formatting the whole expression can be expensive.
        if cfg!(debug_assertions) {
            trace!(
                category= %category,
                expression= %self,
                "Called Expression::category_of()"
            );
        };
        category
    }
}
1305
1306impl Display for Expression {
1307 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1308 match &self {
1309 Expression::Union(_, box1, box2) => {
1310 write!(f, "({} union {})", box1.clone(), box2.clone())
1311 }
1312 Expression::In(_, e1, e2) => {
1313 write!(f, "{e1} in {e2}")
1314 }
1315 Expression::Intersect(_, box1, box2) => {
1316 write!(f, "({} intersect {})", box1.clone(), box2.clone())
1317 }
1318 Expression::Supset(_, box1, box2) => {
1319 write!(f, "({} supset {})", box1.clone(), box2.clone())
1320 }
1321 Expression::SupsetEq(_, box1, box2) => {
1322 write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1323 }
1324 Expression::Subset(_, box1, box2) => {
1325 write!(f, "({} subset {})", box1.clone(), box2.clone())
1326 }
1327 Expression::SubsetEq(_, box1, box2) => {
1328 write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1329 }
1330
1331 Expression::AbstractLiteral(_, l) => l.fmt(f),
1332 Expression::Comprehension(_, c) => c.fmt(f),
1333 Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1334 write!(f, "{e1}{}", pretty_vec(e2))
1335 }
1336 Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1337 let args = es
1338 .iter()
1339 .map(|x| match x {
1340 Some(x) => format!("{x}"),
1341 None => "..".into(),
1342 })
1343 .join(",");
1344
1345 write!(f, "{e1}[{args}]")
1346 }
1347 Expression::InDomain(_, e, domain) => {
1348 write!(f, "__inDomain({e},{domain})")
1349 }
1350 Expression::Root(_, exprs) => {
1351 write!(f, "{}", pretty_expressions_as_top_level(exprs))
1352 }
1353 Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1354 Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1355 Expression::Metavar(_, name) => write!(f, "&{name}"),
1356 Expression::Atomic(_, atom) => atom.fmt(f),
1357 Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1358 Expression::Abs(_, a) => write!(f, "|{a}|"),
1359 Expression::Sum(_, e) => {
1360 write!(f, "sum({e})")
1361 }
1362 Expression::Product(_, e) => {
1363 write!(f, "product({e})")
1364 }
1365 Expression::Min(_, e) => {
1366 write!(f, "min({e})")
1367 }
1368 Expression::Max(_, e) => {
1369 write!(f, "max({e})")
1370 }
1371 Expression::Not(_, expr_box) => {
1372 write!(f, "!({})", expr_box.clone())
1373 }
1374 Expression::Or(_, e) => {
1375 write!(f, "or({e})")
1376 }
1377 Expression::And(_, e) => {
1378 write!(f, "and({e})")
1379 }
1380 Expression::Imply(_, box1, box2) => {
1381 write!(f, "({box1}) -> ({box2})")
1382 }
1383 Expression::Iff(_, box1, box2) => {
1384 write!(f, "({box1}) <-> ({box2})")
1385 }
1386 Expression::Eq(_, box1, box2) => {
1387 write!(f, "({} = {})", box1.clone(), box2.clone())
1388 }
1389 Expression::Neq(_, box1, box2) => {
1390 write!(f, "({} != {})", box1.clone(), box2.clone())
1391 }
1392 Expression::Geq(_, box1, box2) => {
1393 write!(f, "({} >= {})", box1.clone(), box2.clone())
1394 }
1395 Expression::Leq(_, box1, box2) => {
1396 write!(f, "({} <= {})", box1.clone(), box2.clone())
1397 }
1398 Expression::Gt(_, box1, box2) => {
1399 write!(f, "({} > {})", box1.clone(), box2.clone())
1400 }
1401 Expression::Lt(_, box1, box2) => {
1402 write!(f, "({} < {})", box1.clone(), box2.clone())
1403 }
1404 Expression::FlatSumGeq(_, box1, box2) => {
1405 write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1406 }
1407 Expression::FlatSumLeq(_, box1, box2) => {
1408 write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1409 }
1410 Expression::FlatIneq(_, box1, box2, box3) => write!(
1411 f,
1412 "Ineq({}, {}, {})",
1413 box1.clone(),
1414 box2.clone(),
1415 box3.clone()
1416 ),
1417 Expression::Flatten(_, n, m) => {
1418 if let Some(n) = n {
1419 write!(f, "flatten({n}, {m})")
1420 } else {
1421 write!(f, "flatten({m})")
1422 }
1423 }
1424 Expression::AllDiff(_, e) => {
1425 write!(f, "allDiff({e})")
1426 }
1427 Expression::Bubble(_, box1, box2) => {
1428 write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1429 }
1430 Expression::SafeDiv(_, box1, box2) => {
1431 write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1432 }
1433 Expression::UnsafeDiv(_, box1, box2) => {
1434 write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1435 }
1436 Expression::UnsafePow(_, box1, box2) => {
1437 write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1438 }
1439 Expression::SafePow(_, box1, box2) => {
1440 write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1441 }
1442 Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1443 write!(
1444 f,
1445 "DivEq({}, {}, {})",
1446 box1.clone(),
1447 box2.clone(),
1448 box3.clone()
1449 )
1450 }
1451 Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1452 write!(
1453 f,
1454 "ModEq({}, {}, {})",
1455 box1.clone(),
1456 box2.clone(),
1457 box3.clone()
1458 )
1459 }
1460 Expression::FlatWatchedLiteral(_, x, l) => {
1461 write!(f, "WatchedLiteral({x},{l})")
1462 }
1463 Expression::MinionReify(_, box1, box2) => {
1464 write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1465 }
1466 Expression::MinionReifyImply(_, box1, box2) => {
1467 write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1468 }
1469 Expression::MinionWInIntervalSet(_, atom, intervals) => {
1470 let intervals = intervals.iter().join(",");
1471 write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1472 }
1473 Expression::MinionWInSet(_, atom, values) => {
1474 let values = values.iter().join(",");
1475 write!(f, "__minion_w_inset({atom},{values})")
1476 }
1477 Expression::AuxDeclaration(_, reference, e) => {
1478 write!(f, "{} =aux {}", reference, e.clone())
1479 }
1480 Expression::UnsafeMod(_, a, b) => {
1481 write!(f, "{} % {}", a.clone(), b.clone())
1482 }
1483 Expression::SafeMod(_, a, b) => {
1484 write!(f, "SafeMod({},{})", a.clone(), b.clone())
1485 }
1486 Expression::Neg(_, a) => {
1487 write!(f, "-({})", a.clone())
1488 }
1489 Expression::Minus(_, a, b) => {
1490 write!(f, "({} - {})", a.clone(), b.clone())
1491 }
1492 Expression::FlatAllDiff(_, es) => {
1493 write!(f, "__flat_alldiff({})", pretty_vec(es))
1494 }
1495 Expression::FlatAbsEq(_, a, b) => {
1496 write!(f, "AbsEq({},{})", a.clone(), b.clone())
1497 }
1498 Expression::FlatMinusEq(_, a, b) => {
1499 write!(f, "MinusEq({},{})", a.clone(), b.clone())
1500 }
1501 Expression::FlatProductEq(_, a, b, c) => {
1502 write!(
1503 f,
1504 "FlatProductEq({},{},{})",
1505 a.clone(),
1506 b.clone(),
1507 c.clone()
1508 )
1509 }
1510 Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1511 write!(
1512 f,
1513 "FlatWeightedSumLeq({},{},{})",
1514 pretty_vec(cs),
1515 pretty_vec(vs),
1516 total.clone()
1517 )
1518 }
1519 Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1520 write!(
1521 f,
1522 "FlatWeightedSumGeq({},{},{})",
1523 pretty_vec(cs),
1524 pretty_vec(vs),
1525 total.clone()
1526 )
1527 }
1528 Expression::MinionPow(_, atom, atom1, atom2) => {
1529 write!(f, "MinionPow({atom},{atom1},{atom2})")
1530 }
1531 Expression::MinionElementOne(_, atoms, atom, atom1) => {
1532 let atoms = atoms.iter().join(",");
1533 write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1534 }
1535
1536 Expression::ToInt(_, expr) => {
1537 write!(f, "toInt({expr})")
1538 }
1539
1540 Expression::SATInt(_, e) => {
1541 write!(f, "SATInt({e})")
1542 }
1543
1544 Expression::PairwiseSum(_, a, b) => write!(f, "PairwiseSum({a}, {b})"),
1545 Expression::PairwiseProduct(_, a, b) => write!(f, "PairwiseProduct({a}, {b})"),
1546
1547 Expression::Defined(_, function) => write!(f, "defined({function})"),
1548 Expression::Range(_, function) => write!(f, "range({function})"),
1549 Expression::Image(_, function, elems) => write!(f, "image({function},{elems})"),
1550 Expression::ImageSet(_, function, elems) => write!(f, "imageSet({function},{elems})"),
1551 Expression::PreImage(_, function, elems) => write!(f, "preImage({function},{elems})"),
1552 Expression::Inverse(_, a, b) => write!(f, "inverse({a},{b})"),
1553 Expression::Restrict(_, function, domain) => write!(f, "restrict({function},{domain})"),
1554
1555 Expression::LexLt(_, a, b) => write!(f, "({a} <lex {b})"),
1556 Expression::LexLeq(_, a, b) => write!(f, "({a} <=lex {b})"),
1557 Expression::LexGt(_, a, b) => write!(f, "({a} >lex {b})"),
1558 Expression::LexGeq(_, a, b) => write!(f, "({a} >=lex {b})"),
1559 Expression::FlatLexLt(_, a, b) => {
1560 write!(f, "FlatLexLt({}, {})", pretty_vec(a), pretty_vec(b))
1561 }
1562 Expression::FlatLexLeq(_, a, b) => {
1563 write!(f, "FlatLexLeq({}, {})", pretty_vec(a), pretty_vec(b))
1564 }
1565 }
1566 }
1567}
1568
1569impl Typeable for Expression {
1570 fn return_type(&self) -> ReturnType {
1571 match self {
1572 Expression::Union(_, subject, _) => ReturnType::Set(Box::new(subject.return_type())),
1573 Expression::Intersect(_, subject, _) => {
1574 ReturnType::Set(Box::new(subject.return_type()))
1575 }
1576 Expression::In(_, _, _) => ReturnType::Bool,
1577 Expression::Supset(_, _, _) => ReturnType::Bool,
1578 Expression::SupsetEq(_, _, _) => ReturnType::Bool,
1579 Expression::Subset(_, _, _) => ReturnType::Bool,
1580 Expression::SubsetEq(_, _, _) => ReturnType::Bool,
1581 Expression::AbstractLiteral(_, lit) => lit.return_type(),
1582 Expression::UnsafeIndex(_, subject, idx) | Expression::SafeIndex(_, subject, idx) => {
1583 let subject_ty = subject.return_type();
1584 match subject_ty {
1585 ReturnType::Matrix(_) => {
1586 let mut elem_typ = subject_ty;
1589 let mut idx_len = idx.len();
1590 while idx_len > 0
1591 && let ReturnType::Matrix(new_elem_typ) = &elem_typ
1592 {
1593 elem_typ = *new_elem_typ.clone();
1594 idx_len -= 1;
1595 }
1596 elem_typ
1597 }
1598 ReturnType::Record(_) | ReturnType::Tuple(_) => ReturnType::Unknown,
1600 _ => bug!(
1601 "Invalid indexing operation: expected the operand to be a collection, got {self}: {subject_ty}"
1602 ),
1603 }
1604 }
1605 Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1606 ReturnType::Matrix(Box::new(subject.return_type()))
1607 }
1608 Expression::InDomain(_, _, _) => ReturnType::Bool,
1609 Expression::Comprehension(_, comp) => comp.return_type(),
1610 Expression::Root(_, _) => ReturnType::Bool,
1611 Expression::DominanceRelation(_, _) => ReturnType::Bool,
1612 Expression::FromSolution(_, expr) => expr.return_type(),
1613 Expression::Metavar(_, _) => ReturnType::Unknown,
1614 Expression::Atomic(_, atom) => atom.return_type(),
1615 Expression::Scope(_, scope) => scope.return_type(),
1616 Expression::Abs(_, _) => ReturnType::Int,
1617 Expression::Sum(_, _) => ReturnType::Int,
1618 Expression::Product(_, _) => ReturnType::Int,
1619 Expression::Min(_, _) => ReturnType::Int,
1620 Expression::Max(_, _) => ReturnType::Int,
1621 Expression::Not(_, _) => ReturnType::Bool,
1622 Expression::Or(_, _) => ReturnType::Bool,
1623 Expression::Imply(_, _, _) => ReturnType::Bool,
1624 Expression::Iff(_, _, _) => ReturnType::Bool,
1625 Expression::And(_, _) => ReturnType::Bool,
1626 Expression::Eq(_, _, _) => ReturnType::Bool,
1627 Expression::Neq(_, _, _) => ReturnType::Bool,
1628 Expression::Geq(_, _, _) => ReturnType::Bool,
1629 Expression::Leq(_, _, _) => ReturnType::Bool,
1630 Expression::Gt(_, _, _) => ReturnType::Bool,
1631 Expression::Lt(_, _, _) => ReturnType::Bool,
1632 Expression::SafeDiv(_, _, _) => ReturnType::Int,
1633 Expression::UnsafeDiv(_, _, _) => ReturnType::Int,
1634 Expression::FlatAllDiff(_, _) => ReturnType::Bool,
1635 Expression::FlatSumGeq(_, _, _) => ReturnType::Bool,
1636 Expression::FlatSumLeq(_, _, _) => ReturnType::Bool,
1637 Expression::MinionDivEqUndefZero(_, _, _, _) => ReturnType::Bool,
1638 Expression::FlatIneq(_, _, _, _) => ReturnType::Bool,
1639 Expression::Flatten(_, _, matrix) => {
1640 let matrix_type = matrix.return_type();
1641 match matrix_type {
1642 ReturnType::Matrix(_) => {
1643 let mut elem_type = matrix_type;
1645 while let ReturnType::Matrix(new_elem_type) = &elem_type {
1646 elem_type = *new_elem_type.clone();
1647 }
1648 ReturnType::Matrix(Box::new(elem_type))
1649 }
1650 _ => bug!(
1651 "Invalid indexing operation: expected the operand to be a collection, got {self}: {matrix_type}"
1652 ),
1653 }
1654 }
1655 Expression::AllDiff(_, _) => ReturnType::Bool,
1656 Expression::Bubble(_, inner, _) => inner.return_type(),
1657 Expression::FlatWatchedLiteral(_, _, _) => ReturnType::Bool,
1658 Expression::MinionReify(_, _, _) => ReturnType::Bool,
1659 Expression::MinionReifyImply(_, _, _) => ReturnType::Bool,
1660 Expression::MinionWInIntervalSet(_, _, _) => ReturnType::Bool,
1661 Expression::MinionWInSet(_, _, _) => ReturnType::Bool,
1662 Expression::MinionElementOne(_, _, _, _) => ReturnType::Bool,
1663 Expression::AuxDeclaration(_, _, _) => ReturnType::Bool,
1664 Expression::UnsafeMod(_, _, _) => ReturnType::Int,
1665 Expression::SafeMod(_, _, _) => ReturnType::Int,
1666 Expression::MinionModuloEqUndefZero(_, _, _, _) => ReturnType::Bool,
1667 Expression::Neg(_, _) => ReturnType::Int,
1668 Expression::UnsafePow(_, _, _) => ReturnType::Int,
1669 Expression::SafePow(_, _, _) => ReturnType::Int,
1670 Expression::Minus(_, _, _) => ReturnType::Int,
1671 Expression::FlatAbsEq(_, _, _) => ReturnType::Bool,
1672 Expression::FlatMinusEq(_, _, _) => ReturnType::Bool,
1673 Expression::FlatProductEq(_, _, _, _) => ReturnType::Bool,
1674 Expression::FlatWeightedSumLeq(_, _, _, _) => ReturnType::Bool,
1675 Expression::FlatWeightedSumGeq(_, _, _, _) => ReturnType::Bool,
1676 Expression::MinionPow(_, _, _, _) => ReturnType::Bool,
1677 Expression::ToInt(_, _) => ReturnType::Int,
1678 Expression::SATInt(_, _) => ReturnType::Int,
1679 Expression::PairwiseSum(_, _, _) => ReturnType::Int,
1680 Expression::PairwiseProduct(_, _, _) => ReturnType::Int,
1681 Expression::Defined(_, function) => {
1682 let subject = function.return_type();
1683 match subject {
1684 ReturnType::Function(domain, _) => *domain,
1685 _ => bug!(
1686 "Invalid defined operation: expected the operand to be a function, got {self}: {subject}"
1687 ),
1688 }
1689 }
1690 Expression::Range(_, function) => {
1691 let subject = function.return_type();
1692 match subject {
1693 ReturnType::Function(_, codomain) => *codomain,
1694 _ => bug!(
1695 "Invalid range operation: expected the operand to be a function, got {self}: {subject}"
1696 ),
1697 }
1698 }
1699 Expression::Image(_, function, _) => {
1700 let subject = function.return_type();
1701 match subject {
1702 ReturnType::Function(_, codomain) => *codomain,
1703 _ => bug!(
1704 "Invalid image operation: expected the operand to be a function, got {self}: {subject}"
1705 ),
1706 }
1707 }
1708 Expression::ImageSet(_, function, _) => {
1709 let subject = function.return_type();
1710 match subject {
1711 ReturnType::Function(_, codomain) => *codomain,
1712 _ => bug!(
1713 "Invalid imageSet operation: expected the operand to be a function, got {self}: {subject}"
1714 ),
1715 }
1716 }
1717 Expression::PreImage(_, function, _) => {
1718 let subject = function.return_type();
1719 match subject {
1720 ReturnType::Function(domain, _) => *domain,
1721 _ => bug!(
1722 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1723 ),
1724 }
1725 }
1726 Expression::Restrict(_, function, new_domain) => {
1727 let subject = function.return_type();
1728 match subject {
1729 ReturnType::Function(_, codomain) => {
1730 ReturnType::Function(Box::new(new_domain.return_type()), codomain)
1731 }
1732 _ => bug!(
1733 "Invalid preImage operation: expected the operand to be a function, got {self}: {subject}"
1734 ),
1735 }
1736 }
1737 Expression::Inverse(..) => ReturnType::Bool,
1738 Expression::LexLt(..) => ReturnType::Bool,
1739 Expression::LexGt(..) => ReturnType::Bool,
1740 Expression::LexLeq(..) => ReturnType::Bool,
1741 Expression::LexGeq(..) => ReturnType::Bool,
1742 Expression::FlatLexLt(..) => ReturnType::Bool,
1743 Expression::FlatLexLeq(..) => ReturnType::Bool,
1744 }
1745 }
1746}
1747
#[cfg(test)]
mod tests {
    use super::*;

    use crate::matrix_expr;

    /// Wraps a literal as an atomic expression carrying fresh metadata.
    fn literal_expr(value: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(value))
    }

    /// Builds a `Sum` node over the given (matrix) operand expression.
    fn sum_over(operands: Expression) -> Expression {
        Expression::Sum(Metadata::new(), Moo::new(operands))
    }

    /// Summing two integer constants yields the singleton domain of their total.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal_expr(Literal::Int(1));
        let two = literal_expr(Literal::Int(2));
        let total = sum_over(matrix_expr![one, two]);

        let expected = Domain::int(vec![Range::Single(3)]);
        assert_eq!(total.domain_of(), Some(expected));
    }

    /// Mixing an integer with a boolean operand leaves the sum without a domain.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_operand = literal_expr(Literal::Int(1));
        let bool_operand = literal_expr(Literal::Bool(true));
        let mixed = sum_over(matrix_expr![int_operand, bool_operand]);

        assert_eq!(mixed.domain_of(), None);
    }

    /// A sum over no operands has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let empty = sum_over(matrix_expr![]);

        assert_eq!(empty.domain_of(), None);
    }
}