1use crate::ast::declaration::serde::DeclarationPtrAsId;
2use serde_with::serde_as;
3use std::collections::{HashSet, VecDeque};
4use std::fmt::{Display, Formatter};
5use tracing::trace;
6
7use crate::ast::Atom;
8use crate::ast::Moo;
9use crate::ast::Name;
10use crate::ast::ReturnType;
11use crate::ast::SetAttr;
12use crate::ast::literals::AbstractLiteral;
13use crate::ast::literals::Literal;
14use crate::ast::pretty::{pretty_expressions_as_top_level, pretty_vec};
15use crate::bug;
16use crate::metadata::Metadata;
17use enum_compatability_macro::document_compatibility;
18use itertools::Itertools;
19use serde::{Deserialize, Serialize};
20
21use uniplate::{Biplate, Uniplate};
22
23use super::ac_operators::ACOperatorKind;
24use super::categories::{Category, CategoryOf};
25use super::comprehension::Comprehension;
26use super::domains::HasDomain as _;
27use super::records::RecordValue;
28use super::{DeclarationPtr, Domain, Range, SubModel, Typeable};
29
// Compile-time guard on the memory footprint of `Expression`: the build fails if
// `size_of::<Expression>()` is not exactly 104 bytes (the size of `[u8; 104]`).
static_assertions::assert_eq_size!([u8; 104], Expression);
53
/// A node of the expression tree.
///
/// Every variant stores a [`Metadata`] value as its first field, followed by the
/// variant's operands. Sub-expressions are held behind [`Moo`] or inside `Vec`s.
///
/// The `#[compatible(...)]` attributes on the variants are consumed by the
/// `document_compatibility` macro applied to this enum.
#[document_compatibility]
#[serde_as]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Uniplate)]
#[biplate(to=Metadata)]
#[biplate(to=Atom)]
#[biplate(to=DeclarationPtr)]
#[biplate(to=Name)]
#[biplate(to=Vec<Expression>)]
#[biplate(to=Option<Expression>)]
#[biplate(to=SubModel)]
#[biplate(to=Comprehension)]
#[biplate(to=AbstractLiteral<Expression>)]
#[biplate(to=AbstractLiteral<Literal>)]
#[biplate(to=RecordValue<Expression>)]
#[biplate(to=RecordValue<Literal>)]
#[biplate(to=Literal)]
pub enum Expression {
    /// An abstract literal whose elements are themselves expressions.
    AbstractLiteral(Metadata, AbstractLiteral<Expression>),
    /// A list of expressions that is displayed as top-level expressions; it has no
    /// domain (`domain_of` returns `None`) and is typed as boolean.
    Root(Metadata, Vec<Expression>),

    /// A pair of expressions displayed as `{a @ b}`. Domain and type are taken from
    /// the first expression; `is_safe` treats any bubble as unsafe.
    Bubble(Metadata, Moo<Expression>, Moo<Expression>),

    /// A comprehension; its domain is delegated to the wrapped [`Comprehension`].
    Comprehension(Metadata, Moo<Comprehension>),

    /// Boolean-valued wrapper displayed as `DominanceRelation(e)`.
    DominanceRelation(Metadata, Moo<Expression>),
    /// Wrapper displayed as `FromSolution(e)`; domain and type are those of `e`.
    FromSolution(Metadata, Moo<Expression>),

    /// A single [`Atom`] (a literal or a reference).
    Atomic(Metadata, Atom),

    /// Indexing of the first operand by the list of index expressions.
    /// Counted as unsafe by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// Indexing of the first operand by the list of index expressions (safe form).
    SafeIndex(Metadata, Moo<Expression>, Vec<Expression>),

    /// Slicing of the first operand. Each entry is either an index expression or
    /// `None`, which is displayed as `..`. Counted as unsafe by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Slicing of the first operand (safe form); `None` entries are displayed as `..`.
    SafeSlice(Metadata, Moo<Expression>, Vec<Option<Expression>>),

    /// Boolean-valued test displayed as `__inDomain(e,domain)`.
    InDomain(Metadata, Moo<Expression>, Domain),

    /// Integer-valued conversion displayed as `toInt(e)`; its domain is `0..1`.
    ToInt(Metadata, Moo<Expression>),

    /// A nested [`SubModel`], displayed inside braces; its domain is boolean.
    Scope(Metadata, Moo<SubModel>),

    /// Absolute value, displayed as `|e|`.
    #[compatible(JsonInput)]
    Abs(Metadata, Moo<Expression>),

    /// Displayed as `sum(e)`; `domain_of` expects the operand to be a matrix literal.
    #[compatible(JsonInput)]
    Sum(Metadata, Moo<Expression>),

    /// Displayed as `product(e)`; `domain_of` expects the operand to be a matrix literal.
    #[compatible(JsonInput)]
    Product(Metadata, Moo<Expression>),

    /// Displayed as `min(e)`; `domain_of` expects the operand to be a matrix literal.
    #[compatible(JsonInput)]
    Min(Metadata, Moo<Expression>),

    /// Displayed as `max(e)`; `domain_of` expects the operand to be a matrix literal.
    #[compatible(JsonInput)]
    Max(Metadata, Moo<Expression>),

    /// Boolean-valued, displayed as `!(e)`.
    #[compatible(JsonInput, SAT)]
    Not(Metadata, Moo<Expression>),

    /// Boolean-valued, displayed as `or(e)`.
    #[compatible(JsonInput, SAT)]
    Or(Metadata, Moo<Expression>),

    /// Boolean-valued, displayed as `and(e)`.
    #[compatible(JsonInput, SAT)]
    And(Metadata, Moo<Expression>),

    /// Boolean-valued, displayed as `(a) -> (b)`.
    #[compatible(JsonInput)]
    Imply(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a) <-> (b)`.
    #[compatible(JsonInput)]
    Iff(Metadata, Moo<Expression>, Moo<Expression>),

    /// Set-valued, displayed as `(a union b)`.
    #[compatible(JsonInput)]
    Union(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `a in b`.
    #[compatible(JsonInput)]
    In(Metadata, Moo<Expression>, Moo<Expression>),

    /// Set-valued, displayed as `(a intersect b)`.
    #[compatible(JsonInput)]
    Intersect(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a supset b)`.
    #[compatible(JsonInput)]
    Supset(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a supsetEq b)`.
    #[compatible(JsonInput)]
    SupsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a subset b)`.
    #[compatible(JsonInput)]
    Subset(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a subsetEq b)`.
    #[compatible(JsonInput)]
    SubsetEq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a = b)`.
    #[compatible(JsonInput)]
    Eq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a != b)`.
    #[compatible(JsonInput)]
    Neq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a >= b)`.
    #[compatible(JsonInput)]
    Geq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a <= b)`.
    #[compatible(JsonInput)]
    Leq(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a > b)`.
    #[compatible(JsonInput)]
    Gt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `(a < b)`.
    #[compatible(JsonInput)]
    Lt(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer division (safe form); `domain_of` adds `0` to the computed domain.
    SafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer division; counted as unsafe by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeDiv(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer modulo (safe form); `domain_of` adds `0` to the computed domain.
    SafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer modulo, displayed as `a % b`; counted as unsafe by `is_safe`.
    #[compatible(JsonInput)]
    UnsafeMod(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer negation, displayed as `-(e)`.
    #[compatible(JsonInput)]
    Neg(Metadata, Moo<Expression>),

    /// Integer exponentiation; counted as unsafe by `is_safe`.
    #[compatible(JsonInput)]
    UnsafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// Integer exponentiation (safe form).
    SafePow(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued, displayed as `allDiff(e)`.
    #[compatible(JsonInput)]
    AllDiff(Metadata, Moo<Expression>),

    /// Integer subtraction, displayed as `(a - b)`.
    #[compatible(JsonInput)]
    Minus(Metadata, Moo<Expression>, Moo<Expression>),

    /// Boolean-valued flat constraint over atoms, displayed as `AbsEq(a,b)`.
    #[compatible(Minion)]
    FlatAbsEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued flat constraint over atoms, displayed as `__flat_alldiff(...)`.
    #[compatible(Minion)]
    FlatAllDiff(Metadata, Vec<Atom>),

    /// Boolean-valued flat constraint over atoms, displayed as `SumGeq(<atoms>, <atom>)`.
    #[compatible(Minion)]
    FlatSumGeq(Metadata, Vec<Atom>, Atom),

    /// Boolean-valued flat constraint over atoms, displayed as `SumLeq(<atoms>, <atom>)`.
    #[compatible(Minion)]
    FlatSumLeq(Metadata, Vec<Atom>, Atom),

    /// Boolean-valued flat constraint, displayed as `Ineq(a, b, literal)`.
    #[compatible(Minion)]
    FlatIneq(Metadata, Moo<Atom>, Moo<Atom>, Box<Literal>),

    /// Boolean-valued flat constraint, displayed as `WatchedLiteral(<name>,<literal>)`.
    /// The declaration is (de)serialised through `DeclarationPtrAsId`.
    #[compatible(Minion)]
    FlatWatchedLiteral(
        Metadata,
        #[serde_as(as = "DeclarationPtrAsId")] DeclarationPtr,
        Literal,
    ),

    /// Boolean-valued flat constraint, displayed as
    /// `FlatWeightedSumLeq(<literals>,<atoms>,<atom>)`.
    FlatWeightedSumLeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Boolean-valued flat constraint, displayed as
    /// `FlatWeightedSumGeq(<literals>,<atoms>,<atom>)`.
    FlatWeightedSumGeq(Metadata, Vec<Literal>, Vec<Atom>, Moo<Atom>),

    /// Boolean-valued flat constraint over atoms, displayed as `MinusEq(a,b)`.
    #[compatible(Minion)]
    FlatMinusEq(Metadata, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued flat constraint over atoms, displayed as `FlatProductEq(a,b,c)`.
    #[compatible(Minion)]
    FlatProductEq(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued flat constraint over atoms, displayed as `DivEq(a, b, c)`.
    #[compatible(Minion)]
    MinionDivEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued flat constraint over atoms, displayed as `ModEq(a, b, c)`.
    #[compatible(Minion)]
    MinionModuloEqUndefZero(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued flat constraint over atoms, displayed as `MinionPow(a,b,c)`.
    MinionPow(Metadata, Moo<Atom>, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued constraint displayed as `Reify(<expression>, <atom>)`.
    #[compatible(Minion)]
    MinionReify(Metadata, Moo<Expression>, Atom),

    /// Boolean-valued constraint displayed as `ReifyImply(<expression>, <atom>)`.
    #[compatible(Minion)]
    MinionReifyImply(Metadata, Moo<Expression>, Atom),

    /// Boolean-valued constraint displayed as `__minion_w_inintervalset(<atom>,[...])`.
    #[compatible(Minion)]
    MinionWInIntervalSet(Metadata, Atom, Vec<i32>),

    /// Boolean-valued constraint displayed as `__minion_w_inset(<atom>,...)`.
    #[compatible(Minion)]
    MinionWInSet(Metadata, Atom, Vec<i32>),

    /// Boolean-valued constraint displayed as `__minion_element_one([...],<atom>,<atom>)`.
    #[compatible(Minion)]
    MinionElementOne(Metadata, Vec<Atom>, Moo<Atom>, Moo<Atom>),

    /// Boolean-valued node displayed as `<name> =aux <expression>`.
    /// The declaration is (de)serialised through `DeclarationPtrAsId`.
    #[compatible(Minion)]
    AuxDeclaration(
        Metadata,
        #[serde_as(as = "DeclarationPtrAsId")] DeclarationPtr,
        Moo<Expression>,
    ),
}
475
476fn bounded_i32_domain_for_matrix_literal_monotonic(
483 e: &Expression,
484 op: fn(i32, i32) -> Option<i32>,
485) -> Option<Domain> {
486 let (mut exprs, _) = e.clone().unwrap_matrix_unchecked()?;
488
489 let expr = exprs.pop()?;
505 let Some(Domain::Int(ranges)) = expr.domain_of() else {
506 return None;
507 };
508
509 let (mut current_min, mut current_max) = range_vec_bounds_i32(&ranges)?;
510
511 for expr in exprs {
512 let Some(Domain::Int(ranges)) = expr.domain_of() else {
513 return None;
514 };
515
516 let (min, max) = range_vec_bounds_i32(&ranges)?;
517
518 let minmax = op(min, current_max)?;
520 let minmin = op(min, current_min)?;
521 let maxmin = op(max, current_min)?;
522 let maxmax = op(max, current_max)?;
523 let vals = [minmax, minmin, maxmin, maxmax];
524
525 current_min = *vals
526 .iter()
527 .min()
528 .expect("vals iterator should not be empty, and should have a minimum.");
529 current_max = *vals
530 .iter()
531 .max()
532 .expect("vals iterator should not be empty, and should have a maximum.");
533 }
534
535 if current_min == current_max {
536 Some(Domain::Int(vec![Range::Single(current_min)]))
537 } else {
538 Some(Domain::Int(vec![Range::Bounded(current_min, current_max)]))
539 }
540}
541
542fn range_vec_bounds_i32(ranges: &Vec<Range<i32>>) -> Option<(i32, i32)> {
544 let mut min = i32::MAX;
545 let mut max = i32::MIN;
546 for r in ranges {
547 match r {
548 Range::Single(i) => {
549 if *i < min {
550 min = *i;
551 }
552 if *i > max {
553 max = *i;
554 }
555 }
556 Range::Bounded(i, j) => {
557 if *i < min {
558 min = *i;
559 }
560 if *j > max {
561 max = *j;
562 }
563 }
564 Range::UnboundedR(_) | Range::UnboundedL(_) => return None,
565 }
566 }
567 Some((min, max))
568}
569
570impl Expression {
571 pub fn domain_of(&self) -> Option<Domain> {
573 let ret = match self {
574 Expression::Union(_, a, b) => Some(Domain::Set(
575 SetAttr::None,
576 Box::new(a.domain_of()?.union(&b.domain_of()?).ok()?),
577 )),
578 Expression::Intersect(_, a, b) => Some(Domain::Set(
579 SetAttr::None,
580 Box::new(a.domain_of()?.intersect(&b.domain_of()?).ok()?),
581 )),
582 Expression::In(_, _, _) => Some(Domain::Bool),
583 Expression::Supset(_, _, _) => Some(Domain::Bool),
584 Expression::SupsetEq(_, _, _) => Some(Domain::Bool),
585 Expression::Subset(_, _, _) => Some(Domain::Bool),
586 Expression::SubsetEq(_, _, _) => Some(Domain::Bool),
587 Expression::AbstractLiteral(_, abslit) => abslit.domain_of(),
588 Expression::DominanceRelation(_, _) => Some(Domain::Bool),
589 Expression::FromSolution(_, expr) => expr.domain_of(),
590 Expression::Comprehension(_, comprehension) => comprehension.domain_of(),
591 Expression::UnsafeIndex(_, matrix, _) | Expression::SafeIndex(_, matrix, _) => {
592 match matrix.domain_of()? {
593 Domain::Matrix(elem_domain, _) => Some(*elem_domain),
594 Domain::Tuple(_) => None,
595 Domain::Record(_) => None,
596 _ => {
597 bug!("subject of an index operation should support indexing")
598 }
599 }
600 }
601 Expression::UnsafeSlice(_, matrix, indices)
602 | Expression::SafeSlice(_, matrix, indices) => {
603 let sliced_dimension = indices.iter().position(Option::is_none);
604
605 let Domain::Matrix(elem_domain, index_domains) = matrix.domain_of()? else {
606 bug!("subject of an index operation should be a matrix");
607 };
608
609 match sliced_dimension {
610 Some(dimension) => Some(Domain::Matrix(
611 elem_domain,
612 vec![index_domains[dimension].clone()],
613 )),
614
615 None => Some(*elem_domain),
617 }
618 }
619 Expression::InDomain(_, _, _) => Some(Domain::Bool),
620 Expression::Atomic(_, Atom::Reference(ptr)) => ptr.domain(),
621 Expression::Atomic(_, atom) => Some(atom.domain_of()),
622 Expression::Scope(_, _) => Some(Domain::Bool),
623 Expression::Sum(_, e) => {
624 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x + y))
625 }
626 Expression::Product(_, e) => {
627 bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| Some(x * y))
628 }
629 Expression::Min(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
630 Some(if x < y { x } else { y })
631 }),
632 Expression::Max(_, e) => bounded_i32_domain_for_matrix_literal_monotonic(e, |x, y| {
633 Some(if x > y { x } else { y })
634 }),
635 Expression::UnsafeDiv(_, a, b) => a
636 .domain_of()?
637 .apply_i32(
638 |x, y| {
641 if y != 0 {
642 Some((x as f32 / y as f32).floor() as i32)
643 } else {
644 None
645 }
646 },
647 &b.domain_of()?,
648 )
649 .ok(),
650 Expression::SafeDiv(_, a, b) => {
651 let domain = a.domain_of()?.apply_i32(
654 |x, y| {
655 if y != 0 {
656 Some((x as f32 / y as f32).floor() as i32)
657 } else {
658 None
659 }
660 },
661 &b.domain_of()?,
662 );
663
664 match domain {
665 Ok(Domain::Int(ranges)) => {
666 let mut ranges = ranges;
667 ranges.push(Range::Single(0));
668 Some(Domain::Int(ranges))
669 }
670 Err(_) => todo!(),
671 _ => unreachable!(),
672 }
673 }
674 Expression::UnsafeMod(_, a, b) => a
675 .domain_of()?
676 .apply_i32(
677 |x, y| if y != 0 { Some(x % y) } else { None },
678 &b.domain_of()?,
679 )
680 .ok(),
681 Expression::SafeMod(_, a, b) => {
682 let domain = a.domain_of()?.apply_i32(
683 |x, y| if y != 0 { Some(x % y) } else { None },
684 &b.domain_of()?,
685 );
686
687 match domain {
688 Ok(Domain::Int(ranges)) => {
689 let mut ranges = ranges;
690 ranges.push(Range::Single(0));
691 Some(Domain::Int(ranges))
692 }
693 Err(_) => todo!(),
694 _ => unreachable!(),
695 }
696 }
697 Expression::SafePow(_, a, b) | Expression::UnsafePow(_, a, b) => a
698 .domain_of()?
699 .apply_i32(
700 |x, y| {
701 if (x != 0 || y != 0) && y >= 0 {
702 Some(x.pow(y as u32))
703 } else {
704 None
705 }
706 },
707 &b.domain_of()?,
708 )
709 .ok(),
710 Expression::Root(_, _) => None,
711 Expression::Bubble(_, inner, _) => inner.domain_of(),
712 Expression::AuxDeclaration(_, _, _) => Some(Domain::Bool),
713 Expression::And(_, _) => Some(Domain::Bool),
714 Expression::Not(_, _) => Some(Domain::Bool),
715 Expression::Or(_, _) => Some(Domain::Bool),
716 Expression::Imply(_, _, _) => Some(Domain::Bool),
717 Expression::Iff(_, _, _) => Some(Domain::Bool),
718 Expression::Eq(_, _, _) => Some(Domain::Bool),
719 Expression::Neq(_, _, _) => Some(Domain::Bool),
720 Expression::Geq(_, _, _) => Some(Domain::Bool),
721 Expression::Leq(_, _, _) => Some(Domain::Bool),
722 Expression::Gt(_, _, _) => Some(Domain::Bool),
723 Expression::Lt(_, _, _) => Some(Domain::Bool),
724 Expression::FlatAbsEq(_, _, _) => Some(Domain::Bool),
725 Expression::FlatSumGeq(_, _, _) => Some(Domain::Bool),
726 Expression::FlatSumLeq(_, _, _) => Some(Domain::Bool),
727 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(Domain::Bool),
728 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(Domain::Bool),
729 Expression::FlatIneq(_, _, _, _) => Some(Domain::Bool),
730 Expression::AllDiff(_, _) => Some(Domain::Bool),
731 Expression::FlatWatchedLiteral(_, _, _) => Some(Domain::Bool),
732 Expression::MinionReify(_, _, _) => Some(Domain::Bool),
733 Expression::MinionReifyImply(_, _, _) => Some(Domain::Bool),
734 Expression::MinionWInIntervalSet(_, _, _) => Some(Domain::Bool),
735 Expression::MinionWInSet(_, _, _) => Some(Domain::Bool),
736 Expression::MinionElementOne(_, _, _, _) => Some(Domain::Bool),
737 Expression::Neg(_, x) => {
738 let Some(Domain::Int(mut ranges)) = x.domain_of() else {
739 return None;
740 };
741
742 for range in ranges.iter_mut() {
743 *range = match range {
744 Range::Single(x) => Range::Single(-*x),
745 Range::Bounded(x, y) => Range::Bounded(-*y, -*x),
746 Range::UnboundedR(i) => Range::UnboundedL(-*i),
747 Range::UnboundedL(i) => Range::UnboundedR(-*i),
748 };
749 }
750
751 Some(Domain::Int(ranges))
752 }
753 Expression::Minus(_, a, b) => a
754 .domain_of()?
755 .apply_i32(|x, y| Some(x - y), &b.domain_of()?)
756 .ok(),
757 Expression::FlatAllDiff(_, _) => Some(Domain::Bool),
758 Expression::FlatMinusEq(_, _, _) => Some(Domain::Bool),
759 Expression::FlatProductEq(_, _, _, _) => Some(Domain::Bool),
760 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(Domain::Bool),
761 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(Domain::Bool),
762 Expression::Abs(_, a) => a
763 .domain_of()?
764 .apply_i32(|a, _| Some(a.abs()), &a.domain_of()?)
765 .ok(),
766 Expression::MinionPow(_, _, _, _) => Some(Domain::Bool),
767 Expression::ToInt(_, _) => Some(Domain::Int(vec![Range::Bounded(0, 1)])),
768 };
769 match ret {
770 Some(Domain::Int(ranges)) if ranges.len() > 1 => {
773 let (min, max) = range_vec_bounds_i32(&ranges)?;
774 Some(Domain::Int(vec![Range::Bounded(min, max)]))
775 }
776 _ => ret,
777 }
778 }
779
780 pub fn get_meta(&self) -> Metadata {
781 let metas: VecDeque<Metadata> = self.children_bi();
782 metas[0].clone()
783 }
784
785 pub fn set_meta(&self, meta: Metadata) {
786 self.transform_bi(&|_| meta.clone());
787 }
788
789 pub fn is_safe(&self) -> bool {
796 for expr in self.universe() {
798 match expr {
799 Expression::UnsafeDiv(_, _, _)
800 | Expression::UnsafeMod(_, _, _)
801 | Expression::UnsafePow(_, _, _)
802 | Expression::UnsafeIndex(_, _, _)
803 | Expression::Bubble(_, _, _)
804 | Expression::UnsafeSlice(_, _, _) => {
805 return false;
806 }
807 _ => {}
808 }
809 }
810 true
811 }
812
813 pub fn is_clean(&self) -> bool {
814 let metadata = self.get_meta();
815 metadata.clean
816 }
817
818 pub fn set_clean(&mut self, bool_value: bool) {
819 let mut metadata = self.get_meta();
820 metadata.clean = bool_value;
821 self.set_meta(metadata);
822 }
823
824 pub fn is_associative_commutative_operator(&self) -> bool {
826 TryInto::<ACOperatorKind>::try_into(self).is_ok()
827 }
828
829 pub fn is_matrix_literal(&self) -> bool {
834 matches!(
835 self,
836 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(_, _))
837 | Expression::Atomic(
838 _,
839 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(_, _))),
840 )
841 )
842 }
843
844 pub fn identical_atom_to(&self, other: &Expression) -> bool {
850 let atom1: Result<&Atom, _> = self.try_into();
851 let atom2: Result<&Atom, _> = other.try_into();
852
853 if let (Ok(atom1), Ok(atom2)) = (atom1, atom2) {
854 atom2 == atom1
855 } else {
856 false
857 }
858 }
859
860 pub fn unwrap_list(self) -> Option<Vec<Expression>> {
865 match self {
866 Expression::AbstractLiteral(_, matrix @ AbstractLiteral::Matrix(_, _)) => {
867 matrix.unwrap_list().cloned()
868 }
869 Expression::Atomic(
870 _,
871 Atom::Literal(Literal::AbstractLiteral(matrix @ AbstractLiteral::Matrix(_, _))),
872 ) => matrix.unwrap_list().map(|elems| {
873 elems
874 .clone()
875 .into_iter()
876 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
877 .collect_vec()
878 }),
879 _ => None,
880 }
881 }
882
883 pub fn unwrap_matrix_unchecked(self) -> Option<(Vec<Expression>, Domain)> {
891 match self {
892 Expression::AbstractLiteral(_, AbstractLiteral::Matrix(elems, domain)) => {
893 Some((elems.clone(), *domain))
894 }
895 Expression::Atomic(
896 _,
897 Atom::Literal(Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, domain))),
898 ) => Some((
899 elems
900 .clone()
901 .into_iter()
902 .map(|x: Literal| Expression::Atomic(Metadata::new(), Atom::Literal(x)))
903 .collect_vec(),
904 *domain,
905 )),
906
907 _ => None,
908 }
909 }
910
911 pub fn extend_root(self, exprs: Vec<Expression>) -> Expression {
916 match self {
917 Expression::Root(meta, mut children) => {
918 children.extend(exprs);
919 Expression::Root(meta, children)
920 }
921 _ => panic!("extend_root called on a non-Root expression"),
922 }
923 }
924
925 pub fn into_literal(self) -> Option<Literal> {
927 match self {
928 Expression::Atomic(_, Atom::Literal(lit)) => Some(lit),
929 Expression::AbstractLiteral(_, abslit) => {
930 Some(Literal::AbstractLiteral(abslit.clone().into_literals()?))
931 }
932 Expression::Neg(_, e) => {
933 let Literal::Int(i) = Moo::unwrap_or_clone(e).into_literal()? else {
934 bug!("negated literal should be an int");
935 };
936
937 Some(Literal::Int(-i))
938 }
939
940 _ => None,
941 }
942 }
943
944 pub fn to_ac_operator_kind(&self) -> Option<ACOperatorKind> {
946 TryFrom::try_from(self).ok()
947 }
948
949 pub fn universe_categories(&self) -> HashSet<Category> {
951 self.universe()
952 .into_iter()
953 .map(|x| x.category_of())
954 .collect()
955 }
956}
957
958impl TryFrom<&Expression> for i32 {
959 type Error = ();
960
961 fn try_from(value: &Expression) -> Result<Self, Self::Error> {
962 let Expression::Atomic(_, atom) = value else {
963 return Err(());
964 };
965
966 let Atom::Literal(lit) = atom else {
967 return Err(());
968 };
969
970 let Literal::Int(i) = lit else {
971 return Err(());
972 };
973
974 Ok(*i)
975 }
976}
977
978impl TryFrom<Expression> for i32 {
979 type Error = ();
980
981 fn try_from(value: Expression) -> Result<Self, Self::Error> {
982 TryFrom::<&Expression>::try_from(&value)
983 }
984}
985impl From<i32> for Expression {
986 fn from(i: i32) -> Self {
987 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(i)))
988 }
989}
990
991impl From<bool> for Expression {
992 fn from(b: bool) -> Self {
993 Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(b)))
994 }
995}
996
997impl From<Atom> for Expression {
998 fn from(value: Atom) -> Self {
999 Expression::Atomic(Metadata::new(), value)
1000 }
1001}
1002
1003impl From<Moo<Expression>> for Expression {
1004 fn from(val: Moo<Expression>) -> Self {
1005 val.as_ref().clone()
1006 }
1007}
1008
1009impl CategoryOf for Expression {
1010 fn category_of(&self) -> Category {
1011 let category = self.cata(&move |x,children| {
1013
1014 if let Some(max_category) = children.iter().max() {
1015 *max_category
1018 } else {
1019 let mut max_category = Category::Bottom;
1021
1022 if !Biplate::<SubModel>::universe_bi(&x).is_empty() {
1029 return Category::Decision;
1031 }
1032
1033 if let Some(max_atom_category) = Biplate::<Atom>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1035 && max_atom_category > max_category{
1037 max_category = max_atom_category;
1039 }
1040
1041 if let Some(max_declaration_category) = Biplate::<DeclarationPtr>::universe_bi(&x).iter().map(|x| x.category_of()).max()
1043 && max_declaration_category > max_category{
1045 max_category = max_declaration_category;
1047 }
1048 max_category
1049
1050 }
1051 });
1052
1053 if cfg!(debug_assertions) {
1054 trace!(
1055 category= %category,
1056 expression= %self,
1057 "Called Expression::category_of()"
1058 );
1059 };
1060 category
1061 }
1062}
1063
1064impl Display for Expression {
1065 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1066 match &self {
1067 Expression::Union(_, box1, box2) => {
1068 write!(f, "({} union {})", box1.clone(), box2.clone())
1069 }
1070 Expression::In(_, e1, e2) => {
1071 write!(f, "{e1} in {e2}")
1072 }
1073 Expression::Intersect(_, box1, box2) => {
1074 write!(f, "({} intersect {})", box1.clone(), box2.clone())
1075 }
1076 Expression::Supset(_, box1, box2) => {
1077 write!(f, "({} supset {})", box1.clone(), box2.clone())
1078 }
1079 Expression::SupsetEq(_, box1, box2) => {
1080 write!(f, "({} supsetEq {})", box1.clone(), box2.clone())
1081 }
1082 Expression::Subset(_, box1, box2) => {
1083 write!(f, "({} subset {})", box1.clone(), box2.clone())
1084 }
1085 Expression::SubsetEq(_, box1, box2) => {
1086 write!(f, "({} subsetEq {})", box1.clone(), box2.clone())
1087 }
1088
1089 Expression::AbstractLiteral(_, l) => l.fmt(f),
1090 Expression::Comprehension(_, c) => c.fmt(f),
1091 Expression::UnsafeIndex(_, e1, e2) | Expression::SafeIndex(_, e1, e2) => {
1092 write!(f, "{e1}{}", pretty_vec(e2))
1093 }
1094 Expression::UnsafeSlice(_, e1, es) | Expression::SafeSlice(_, e1, es) => {
1095 let args = es
1096 .iter()
1097 .map(|x| match x {
1098 Some(x) => format!("{x}"),
1099 None => "..".into(),
1100 })
1101 .join(",");
1102
1103 write!(f, "{e1}[{args}]")
1104 }
1105 Expression::InDomain(_, e, domain) => {
1106 write!(f, "__inDomain({e},{domain})")
1107 }
1108 Expression::Root(_, exprs) => {
1109 write!(f, "{}", pretty_expressions_as_top_level(exprs))
1110 }
1111 Expression::DominanceRelation(_, expr) => write!(f, "DominanceRelation({expr})"),
1112 Expression::FromSolution(_, expr) => write!(f, "FromSolution({expr})"),
1113 Expression::Atomic(_, atom) => atom.fmt(f),
1114 Expression::Scope(_, submodel) => write!(f, "{{\n{submodel}\n}}"),
1115 Expression::Abs(_, a) => write!(f, "|{a}|"),
1116 Expression::Sum(_, e) => {
1117 write!(f, "sum({e})")
1118 }
1119 Expression::Product(_, e) => {
1120 write!(f, "product({e})")
1121 }
1122 Expression::Min(_, e) => {
1123 write!(f, "min({e})")
1124 }
1125 Expression::Max(_, e) => {
1126 write!(f, "max({e})")
1127 }
1128 Expression::Not(_, expr_box) => {
1129 write!(f, "!({})", expr_box.clone())
1130 }
1131 Expression::Or(_, e) => {
1132 write!(f, "or({e})")
1133 }
1134 Expression::And(_, e) => {
1135 write!(f, "and({e})")
1136 }
1137 Expression::Imply(_, box1, box2) => {
1138 write!(f, "({box1}) -> ({box2})")
1139 }
1140 Expression::Iff(_, box1, box2) => {
1141 write!(f, "({box1}) <-> ({box2})")
1142 }
1143 Expression::Eq(_, box1, box2) => {
1144 write!(f, "({} = {})", box1.clone(), box2.clone())
1145 }
1146 Expression::Neq(_, box1, box2) => {
1147 write!(f, "({} != {})", box1.clone(), box2.clone())
1148 }
1149 Expression::Geq(_, box1, box2) => {
1150 write!(f, "({} >= {})", box1.clone(), box2.clone())
1151 }
1152 Expression::Leq(_, box1, box2) => {
1153 write!(f, "({} <= {})", box1.clone(), box2.clone())
1154 }
1155 Expression::Gt(_, box1, box2) => {
1156 write!(f, "({} > {})", box1.clone(), box2.clone())
1157 }
1158 Expression::Lt(_, box1, box2) => {
1159 write!(f, "({} < {})", box1.clone(), box2.clone())
1160 }
1161 Expression::FlatSumGeq(_, box1, box2) => {
1162 write!(f, "SumGeq({}, {})", pretty_vec(box1), box2.clone())
1163 }
1164 Expression::FlatSumLeq(_, box1, box2) => {
1165 write!(f, "SumLeq({}, {})", pretty_vec(box1), box2.clone())
1166 }
1167 Expression::FlatIneq(_, box1, box2, box3) => write!(
1168 f,
1169 "Ineq({}, {}, {})",
1170 box1.clone(),
1171 box2.clone(),
1172 box3.clone()
1173 ),
1174 Expression::AllDiff(_, e) => {
1175 write!(f, "allDiff({e})")
1176 }
1177 Expression::Bubble(_, box1, box2) => {
1178 write!(f, "{{{} @ {}}}", box1.clone(), box2.clone())
1179 }
1180 Expression::SafeDiv(_, box1, box2) => {
1181 write!(f, "SafeDiv({}, {})", box1.clone(), box2.clone())
1182 }
1183 Expression::UnsafeDiv(_, box1, box2) => {
1184 write!(f, "UnsafeDiv({}, {})", box1.clone(), box2.clone())
1185 }
1186 Expression::UnsafePow(_, box1, box2) => {
1187 write!(f, "UnsafePow({}, {})", box1.clone(), box2.clone())
1188 }
1189 Expression::SafePow(_, box1, box2) => {
1190 write!(f, "SafePow({}, {})", box1.clone(), box2.clone())
1191 }
1192 Expression::MinionDivEqUndefZero(_, box1, box2, box3) => {
1193 write!(
1194 f,
1195 "DivEq({}, {}, {})",
1196 box1.clone(),
1197 box2.clone(),
1198 box3.clone()
1199 )
1200 }
1201 Expression::MinionModuloEqUndefZero(_, box1, box2, box3) => {
1202 write!(
1203 f,
1204 "ModEq({}, {}, {})",
1205 box1.clone(),
1206 box2.clone(),
1207 box3.clone()
1208 )
1209 }
1210 Expression::FlatWatchedLiteral(_, x, l) => {
1211 write!(f, "WatchedLiteral({x},{l})", x = &x.name() as &Name)
1212 }
1213 Expression::MinionReify(_, box1, box2) => {
1214 write!(f, "Reify({}, {})", box1.clone(), box2.clone())
1215 }
1216 Expression::MinionReifyImply(_, box1, box2) => {
1217 write!(f, "ReifyImply({}, {})", box1.clone(), box2.clone())
1218 }
1219 Expression::MinionWInIntervalSet(_, atom, intervals) => {
1220 let intervals = intervals.iter().join(",");
1221 write!(f, "__minion_w_inintervalset({atom},[{intervals}])")
1222 }
1223 Expression::MinionWInSet(_, atom, values) => {
1224 let values = values.iter().join(",");
1225 write!(f, "__minion_w_inset({atom},{values})")
1226 }
1227 Expression::AuxDeclaration(_, decl, e) => {
1228 write!(f, "{} =aux {}", &decl.name() as &Name, e.clone())
1229 }
1230 Expression::UnsafeMod(_, a, b) => {
1231 write!(f, "{} % {}", a.clone(), b.clone())
1232 }
1233 Expression::SafeMod(_, a, b) => {
1234 write!(f, "SafeMod({},{})", a.clone(), b.clone())
1235 }
1236 Expression::Neg(_, a) => {
1237 write!(f, "-({})", a.clone())
1238 }
1239 Expression::Minus(_, a, b) => {
1240 write!(f, "({} - {})", a.clone(), b.clone())
1241 }
1242 Expression::FlatAllDiff(_, es) => {
1243 write!(f, "__flat_alldiff({})", pretty_vec(es))
1244 }
1245 Expression::FlatAbsEq(_, a, b) => {
1246 write!(f, "AbsEq({},{})", a.clone(), b.clone())
1247 }
1248 Expression::FlatMinusEq(_, a, b) => {
1249 write!(f, "MinusEq({},{})", a.clone(), b.clone())
1250 }
1251 Expression::FlatProductEq(_, a, b, c) => {
1252 write!(
1253 f,
1254 "FlatProductEq({},{},{})",
1255 a.clone(),
1256 b.clone(),
1257 c.clone()
1258 )
1259 }
1260 Expression::FlatWeightedSumLeq(_, cs, vs, total) => {
1261 write!(
1262 f,
1263 "FlatWeightedSumLeq({},{},{})",
1264 pretty_vec(cs),
1265 pretty_vec(vs),
1266 total.clone()
1267 )
1268 }
1269 Expression::FlatWeightedSumGeq(_, cs, vs, total) => {
1270 write!(
1271 f,
1272 "FlatWeightedSumGeq({},{},{})",
1273 pretty_vec(cs),
1274 pretty_vec(vs),
1275 total.clone()
1276 )
1277 }
1278 Expression::MinionPow(_, atom, atom1, atom2) => {
1279 write!(f, "MinionPow({atom},{atom1},{atom2})")
1280 }
1281 Expression::MinionElementOne(_, atoms, atom, atom1) => {
1282 let atoms = atoms.iter().join(",");
1283 write!(f, "__minion_element_one([{atoms}],{atom},{atom1})")
1284 }
1285
1286 Expression::ToInt(_, expr) => {
1287 write!(f, "toInt({expr})")
1288 }
1289 }
1290 }
1291}
1292
1293impl Typeable for Expression {
1294 fn return_type(&self) -> Option<ReturnType> {
1295 match self {
1296 Expression::Union(_, subject, _) => {
1297 Some(ReturnType::Set(Box::new(subject.return_type()?)))
1298 }
1299 Expression::Intersect(_, subject, _) => {
1300 Some(ReturnType::Set(Box::new(subject.return_type()?)))
1301 }
1302 Expression::In(_, _, _) => Some(ReturnType::Bool),
1303 Expression::Supset(_, _, _) => Some(ReturnType::Bool),
1304 Expression::SupsetEq(_, _, _) => Some(ReturnType::Bool),
1305 Expression::Subset(_, _, _) => Some(ReturnType::Bool),
1306 Expression::SubsetEq(_, _, _) => Some(ReturnType::Bool),
1307 Expression::AbstractLiteral(_, lit) => lit.return_type(),
1308 Expression::UnsafeIndex(_, subject, _) | Expression::SafeIndex(_, subject, _) => {
1309 let mut elem_typ = subject.return_type()?;
1310 let ReturnType::Matrix(_) = elem_typ else {
1311 return None;
1312 };
1313
1314 while let ReturnType::Matrix(new_elem_typ) = elem_typ {
1316 elem_typ = *new_elem_typ;
1317 }
1318
1319 Some(elem_typ)
1320 }
1321 Expression::UnsafeSlice(_, subject, _) | Expression::SafeSlice(_, subject, _) => {
1322 Some(ReturnType::Matrix(Box::new(subject.return_type()?)))
1323 }
1324 Expression::InDomain(_, _, _) => Some(ReturnType::Bool),
1325 Expression::Comprehension(_, _) => None,
1326 Expression::Root(_, _) => Some(ReturnType::Bool),
1327 Expression::DominanceRelation(_, _) => Some(ReturnType::Bool),
1328 Expression::FromSolution(_, expr) => expr.return_type(),
1329 Expression::Atomic(_, atom) => atom.return_type(),
1330 Expression::Scope(_, scope) => scope.return_type(),
1331 Expression::Abs(_, _) => Some(ReturnType::Int),
1332 Expression::Sum(_, _) => Some(ReturnType::Int),
1333 Expression::Product(_, _) => Some(ReturnType::Int),
1334 Expression::Min(_, _) => Some(ReturnType::Int),
1335 Expression::Max(_, _) => Some(ReturnType::Int),
1336 Expression::Not(_, _) => Some(ReturnType::Bool),
1337 Expression::Or(_, _) => Some(ReturnType::Bool),
1338 Expression::Imply(_, _, _) => Some(ReturnType::Bool),
1339 Expression::Iff(_, _, _) => Some(ReturnType::Bool),
1340 Expression::And(_, _) => Some(ReturnType::Bool),
1341 Expression::Eq(_, _, _) => Some(ReturnType::Bool),
1342 Expression::Neq(_, _, _) => Some(ReturnType::Bool),
1343 Expression::Geq(_, _, _) => Some(ReturnType::Bool),
1344 Expression::Leq(_, _, _) => Some(ReturnType::Bool),
1345 Expression::Gt(_, _, _) => Some(ReturnType::Bool),
1346 Expression::Lt(_, _, _) => Some(ReturnType::Bool),
1347 Expression::SafeDiv(_, _, _) => Some(ReturnType::Int),
1348 Expression::UnsafeDiv(_, _, _) => Some(ReturnType::Int),
1349 Expression::FlatAllDiff(_, _) => Some(ReturnType::Bool),
1350 Expression::FlatSumGeq(_, _, _) => Some(ReturnType::Bool),
1351 Expression::FlatSumLeq(_, _, _) => Some(ReturnType::Bool),
1352 Expression::MinionDivEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1353 Expression::FlatIneq(_, _, _, _) => Some(ReturnType::Bool),
1354 Expression::AllDiff(_, _) => Some(ReturnType::Bool),
1355 Expression::Bubble(_, inner, _) => inner.return_type(),
1356 Expression::FlatWatchedLiteral(_, _, _) => Some(ReturnType::Bool),
1357 Expression::MinionReify(_, _, _) => Some(ReturnType::Bool),
1358 Expression::MinionReifyImply(_, _, _) => Some(ReturnType::Bool),
1359 Expression::MinionWInIntervalSet(_, _, _) => Some(ReturnType::Bool),
1360 Expression::MinionWInSet(_, _, _) => Some(ReturnType::Bool),
1361 Expression::MinionElementOne(_, _, _, _) => Some(ReturnType::Bool),
1362 Expression::AuxDeclaration(_, _, _) => Some(ReturnType::Bool),
1363 Expression::UnsafeMod(_, _, _) => Some(ReturnType::Int),
1364 Expression::SafeMod(_, _, _) => Some(ReturnType::Int),
1365 Expression::MinionModuloEqUndefZero(_, _, _, _) => Some(ReturnType::Bool),
1366 Expression::Neg(_, _) => Some(ReturnType::Int),
1367 Expression::UnsafePow(_, _, _) => Some(ReturnType::Int),
1368 Expression::SafePow(_, _, _) => Some(ReturnType::Int),
1369 Expression::Minus(_, _, _) => Some(ReturnType::Int),
1370 Expression::FlatAbsEq(_, _, _) => Some(ReturnType::Bool),
1371 Expression::FlatMinusEq(_, _, _) => Some(ReturnType::Bool),
1372 Expression::FlatProductEq(_, _, _, _) => Some(ReturnType::Bool),
1373 Expression::FlatWeightedSumLeq(_, _, _, _) => Some(ReturnType::Bool),
1374 Expression::FlatWeightedSumGeq(_, _, _, _) => Some(ReturnType::Bool),
1375 Expression::MinionPow(_, _, _, _) => Some(ReturnType::Bool),
1376 Expression::ToInt(_, _) => Some(ReturnType::Int),
1377 }
1378 }
1379}
1380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_expr;

    /// Wraps a literal as an atomic expression with fresh metadata.
    fn literal(lit: Literal) -> Expression {
        Expression::Atomic(Metadata::new(), Atom::Literal(lit))
    }

    /// Builds a `Sum` over the given operand with fresh metadata.
    fn sum_of(operand: Expression) -> Expression {
        Expression::Sum(Metadata::new(), Moo::new(operand))
    }

    /// Summing two integer constants gives a single-value domain.
    #[test]
    fn test_domain_of_constant_sum() {
        let one = literal(Literal::Int(1));
        let two = literal(Literal::Int(2));
        let sum = sum_of(matrix_expr![one.clone(), two.clone()]);

        let expected = Some(Domain::Int(vec![Range::Single(3)]));
        assert_eq!(sum.domain_of(), expected);
    }

    /// A non-integer operand makes the domain of the sum undefined.
    #[test]
    fn test_domain_of_constant_invalid_type() {
        let int_operand = literal(Literal::Int(1));
        let bool_operand = literal(Literal::Bool(true));
        let sum = sum_of(matrix_expr![int_operand.clone(), bool_operand.clone()]);

        assert_eq!(sum.domain_of(), None);
    }

    /// A sum over an empty matrix has no domain.
    #[test]
    fn test_domain_of_empty_sum() {
        let sum = sum_of(matrix_expr![]);

        assert_eq!(sum.domain_of(), None);
    }
}