1use std::fmt::Display;
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5
6use crate::ast::pretty::pretty_vec;
7
8use super::{types::Typeable, AbstractLiteral, Literal, Name, ReturnType};
9
/// A set of values of an ordered type `A`: either a single value or an interval that is
/// closed on every side where a bound is given.
///
/// Membership is decided by [`Range::contains`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Range<A>
where
    A: Ord,
{
    /// Exactly one value.
    Single(A),
    /// Every value from the first field up to the second field, both ends inclusive.
    Bounded(A, A),

    /// Every value greater than or equal to the field (no upper bound); displayed as `i..`.
    UnboundedR(A),

    /// Every value less than or equal to the field (no lower bound); displayed as `..i`.
    UnboundedL(A),
}
24
25impl<A: Ord> Range<A> {
26 pub fn contains(&self, val: &A) -> bool {
27 match self {
28 Range::Single(x) => x == val,
29 Range::Bounded(x, y) => x <= val && val <= y,
30 Range::UnboundedR(x) => x <= val,
31 Range::UnboundedL(x) => x >= val,
32 }
33 }
34}
35
36impl<A: Ord + Display> Display for Range<A> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 Range::Single(i) => write!(f, "{i}"),
40 Range::Bounded(i, j) => write!(f, "{i}..{j}"),
41 Range::UnboundedR(i) => write!(f, "{i}.."),
42 Range::UnboundedL(i) => write!(f, "..{i}"),
43 }
44 }
45}
46
/// The domain of a value: the set of literals it may take.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Domain {
    /// The boolean domain; displayed as `bool`.
    BoolDomain,

    /// Integers lying in at least one of the listed ranges.
    ///
    /// An empty list is treated by [`Domain::contains`] as admitting every integer and is
    /// displayed as plain `int`.
    IntDomain(Vec<Range<i32>>),
    /// A domain referred to by name. [`Domain::contains`] cannot decide membership for it
    /// and returns `None`.
    DomainReference(Name),
    /// A set domain, displayed as `set of (<domain>)`; the boxed domain is presumably the
    /// domain of the set's elements — confirm against the callers.
    DomainSet(SetAttr, Box<Domain>),
    /// A matrix domain: the domain of the elements, followed by the list of index domains.
    DomainMatrix(Box<Domain>, Vec<Domain>),
}
64
/// Attribute attached to a [`Domain::DomainSet`].
///
/// NOTE(review): nothing in this part of the file reads these attributes; the descriptions
/// below are inferred from the variant names and should be confirmed.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SetAttr {
    /// No attribute.
    None,
    /// Presumably the exact size of the set.
    Size(i32),
    /// Presumably a lower bound on the size of the set.
    MinSize(i32),
    /// Presumably an upper bound on the size of the set.
    MaxSize(i32),
    /// Presumably a lower bound followed by an upper bound on the size — confirm the order.
    MinMaxSize(i32, i32),
}
73impl Domain {
74 pub fn contains(&self, lit: &Literal) -> Option<bool> {
78 match (self, lit) {
81 (Domain::IntDomain(ranges), Literal::Int(x)) => {
82 if ranges.is_empty() {
84 return Some(true);
85 };
86
87 Some(ranges.iter().any(|range| range.contains(x)))
88 }
89 (Domain::IntDomain(_), _) => Some(false),
90 (Domain::BoolDomain, Literal::Bool(_)) => Some(true),
91 (Domain::BoolDomain, _) => Some(false),
92 (Domain::DomainReference(_), _) => None,
93
94 (
95 Domain::DomainMatrix(elem_domain, index_domains),
96 Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
97 ) => {
98 let mut index_domains = index_domains.clone();
99 if index_domains
100 .pop()
101 .expect("a matrix should have atleast one index domain")
102 != *idx_domain
103 {
104 return Some(false);
105 };
106
107 let next_elem_domain = if index_domains.is_empty() {
110 elem_domain.as_ref().clone()
111 } else {
112 Domain::DomainMatrix(elem_domain.clone(), index_domains)
113 };
114
115 for elem in elems {
116 if !next_elem_domain.contains(elem)? {
117 return Some(false);
118 }
119 }
120
121 Some(true)
122 }
123 (Domain::DomainMatrix(_, _), _) => Some(false),
124 (Domain::DomainSet(_, _), Literal::AbstractLiteral(AbstractLiteral::Set(_))) => {
125 todo!()
126 }
127 (Domain::DomainSet(_, _), _) => Some(false),
128 }
129 }
130 pub fn values_i32(&self) -> Option<Vec<i32>> {
133 match self {
134 Domain::IntDomain(ranges) => Some(
135 ranges
136 .iter()
137 .map(|r| match r {
138 Range::Single(i) => Some(vec![*i]),
139 Range::Bounded(i, j) => Some((*i..=*j).collect()),
140 Range::UnboundedR(_) => None,
141 Range::UnboundedL(_) => None,
142 })
143 .while_some()
144 .flatten()
145 .collect_vec(),
146 ),
147 _ => None,
148 }
149 }
150
151 pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
158 if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
159 let mut new_ranges = vec![];
161 for (v1, v2) in itertools::iproduct!(vs1, vs2) {
162 if let Some(v) = op(v1, v2) {
163 new_ranges.push(Range::Single(v))
164 }
165 }
166 return Some(Domain::IntDomain(new_ranges));
167 }
168 None
169 }
170}
171
172impl Display for Domain {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 match self {
175 Domain::BoolDomain => {
176 write!(f, "bool")
177 }
178 Domain::IntDomain(vec) => {
179 let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
180
181 if domain_ranges.is_empty() {
182 write!(f, "int")
183 } else {
184 write!(f, "int({domain_ranges})")
185 }
186 }
187 Domain::DomainReference(name) => write!(f, "{}", name),
188 Domain::DomainSet(_, domain) => {
189 write!(f, "set of ({})", domain)
190 }
191 Domain::DomainMatrix(value_domain, index_domains) => {
192 write!(
193 f,
194 "matrix indexed by [{}] of {value_domain}",
195 pretty_vec(&index_domains.iter().collect_vec())
196 )
197 }
198 }
199 }
200}
201
impl Typeable for Domain {
    /// Not implemented yet: the body is `todo!()`, so every call panics.
    fn return_type(&self) -> Option<ReturnType> {
        todo!()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `domain` is an `IntDomain` whose range list holds `Range::Single(v)`
    /// for every `v` in `present` and for no `v` in `absent`.
    fn assert_single_values(domain: Domain, present: &[i32], absent: &[i32]) {
        assert!(matches!(domain, Domain::IntDomain(_)));
        if let Domain::IntDomain(ranges) = domain {
            for value in present {
                assert!(ranges.contains(&Range::Single(*value)));
            }
            for value in absent {
                assert!(!ranges.contains(&Range::Single(*value)));
            }
        }
    }

    /// Multiplying `-2..1` by `-2..1` must produce exactly the products reachable from
    /// those operands.
    #[test]
    fn test_negative_product() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let product = lhs.apply_i32(|a, b| Some(a * b), &rhs).unwrap();

        assert_single_values(product, &[-2, -1, 0, 1, 2, 4], &[-4, -3, 3]);
    }

    /// Dividing `-2..1` by `-2..1` must skip division by zero and produce exactly the
    /// reachable quotients.
    #[test]
    fn test_negative_div() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let quotient = lhs
            .apply_i32(|a, b| if b != 0 { Some(a / b) } else { None }, &rhs)
            .unwrap();

        assert_single_values(quotient, &[-2, -1, 0, 1, 2], &[-4, -3, 3, 4]);
    }
}