1use std::fmt::Display;
2
3use conjure_core::ast::SymbolTable;
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
7use crate::ast::pretty::pretty_vec;
8use uniplate::{derive::Uniplate, Uniplate};
9
10use super::{records::RecordEntry, types::Typeable, AbstractLiteral, Literal, Name, ReturnType};
11
12#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub enum Range<A>
14where
15 A: Ord,
16{
17 Single(A),
18 Bounded(A, A),
19
20 UnboundedR(A),
22
23 UnboundedL(A),
25}
26
27impl<A: Ord> Range<A> {
28 pub fn contains(&self, val: &A) -> bool {
29 match self {
30 Range::Single(x) => x == val,
31 Range::Bounded(x, y) => x <= val && val <= y,
32 Range::UnboundedR(x) => x <= val,
33 Range::UnboundedL(x) => x >= val,
34 }
35 }
36}
37
38impl<A: Ord + Display> Display for Range<A> {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 match self {
41 Range::Single(i) => write!(f, "{i}"),
42 Range::Bounded(i, j) => write!(f, "{i}..{j}"),
43 Range::UnboundedR(i) => write!(f, "{i}.."),
44 Range::UnboundedL(i) => write!(f, "..{i}"),
45 }
46 }
47}
48
49#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Uniplate)]
50#[uniplate()]
51pub enum Domain {
52 BoolDomain,
53
54 IntDomain(Vec<Range<i32>>),
62 DomainReference(Name),
63 DomainSet(SetAttr, Box<Domain>),
64 DomainMatrix(Box<Domain>, Vec<Domain>),
66 DomainTuple(Vec<Domain>),
68
69 DomainRecord(Vec<RecordEntry>),
70}
71
72#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
73pub enum SetAttr {
74 None,
75 Size(i32),
76 MinSize(i32),
77 MaxSize(i32),
78 MinMaxSize(i32, i32),
79}
80impl Domain {
81 pub fn contains(&self, lit: &Literal) -> Option<bool> {
85 match (self, lit) {
88 (Domain::IntDomain(ranges), Literal::Int(x)) => {
89 if ranges.is_empty() {
91 return Some(true);
92 };
93
94 Some(ranges.iter().any(|range| range.contains(x)))
95 }
96 (Domain::IntDomain(_), _) => Some(false),
97 (Domain::BoolDomain, Literal::Bool(_)) => Some(true),
98 (Domain::BoolDomain, _) => Some(false),
99 (Domain::DomainReference(_), _) => None,
100 (
101 Domain::DomainMatrix(elem_domain, index_domains),
102 Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
103 ) => {
104 let mut index_domains = index_domains.clone();
105 if index_domains
106 .pop()
107 .expect("a matrix should have atleast one index domain")
108 != *idx_domain
109 {
110 return Some(false);
111 };
112
113 let next_elem_domain = if index_domains.is_empty() {
116 elem_domain.as_ref().clone()
117 } else {
118 Domain::DomainMatrix(elem_domain.clone(), index_domains)
119 };
120
121 for elem in elems {
122 if !next_elem_domain.contains(elem)? {
123 return Some(false);
124 }
125 }
126
127 Some(true)
128 }
129 (
130 Domain::DomainTuple(elem_domains),
131 Literal::AbstractLiteral(AbstractLiteral::Tuple(literal_elems)),
132 ) => {
133 for (elem_domain, elem) in itertools::izip!(elem_domains, literal_elems) {
135 if !elem_domain.contains(elem)? {
136 return Some(false);
137 }
138 }
139
140 Some(true)
141 }
142 (
143 Domain::DomainSet(_, domain),
144 Literal::AbstractLiteral(AbstractLiteral::Set(literal_elems)),
145 ) => {
146 for elem in literal_elems {
147 if !domain.contains(elem)? {
148 return Some(false);
149 }
150 }
151 Some(true)
152 }
153 (
154 Domain::DomainRecord(entries),
155 Literal::AbstractLiteral(AbstractLiteral::Record(lit_entries)),
156 ) => {
157 for (entry, lit_entry) in itertools::izip!(entries, lit_entries) {
158 if entry.name != lit_entry.name || !(entry.domain.contains(&lit_entry.value)?) {
159 return Some(false);
160 }
161 }
162 Some(true)
163 }
164
165 (Domain::DomainRecord(_), _) => Some(false),
166
167 (Domain::DomainMatrix(_, _), _) => Some(false),
168
169 (Domain::DomainSet(_, _), _) => Some(false),
170
171 (Domain::DomainTuple(_), _) => Some(false),
172 }
173 }
174
175 pub fn values_i32(&self) -> Option<Vec<i32>> {
178 match self {
179 Domain::IntDomain(ranges) => Some(
180 ranges
181 .iter()
182 .map(|r| match r {
183 Range::Single(i) => Some(vec![*i]),
184 Range::Bounded(i, j) => Some((*i..=*j).collect()),
185 Range::UnboundedR(_) => None,
186 Range::UnboundedL(_) => None,
187 })
188 .while_some()
189 .flatten()
190 .collect_vec(),
191 ),
192 _ => None,
193 }
194 }
195
196 pub fn make_int_domain_from_values_i32(&self, vector: &[i32]) -> Option<Domain> {
200 let mut new_ranges = vec![];
201 for values in vector.iter() {
202 new_ranges.push(Range::Single(*values));
203 }
204 Some(Domain::IntDomain(new_ranges))
205 }
206
207 pub fn values(&self) -> Option<Vec<Literal>> {
210 match self {
211 Domain::BoolDomain => Some(vec![false.into(), true.into()]),
212 Domain::IntDomain(_) => self
213 .values_i32()
214 .map(|xs| xs.iter().map(|x| Literal::Int(*x)).collect_vec()),
215
216 Domain::DomainSet(_, _) => todo!(),
219 Domain::DomainMatrix(_, _) => todo!(),
220 Domain::DomainReference(_) => None,
221 Domain::DomainTuple(_) => todo!(), Domain::DomainRecord(_) => todo!(),
223 }
224 }
225
226 pub fn length(&self) -> Option<usize> {
230 self.values().map(|x| x.len())
231 }
232
233 pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
240 if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
241 let mut new_ranges = vec![];
243 for (v1, v2) in itertools::iproduct!(vs1, vs2) {
244 if let Some(v) = op(v1, v2) {
245 new_ranges.push(Range::Single(v))
246 }
247 }
248 return Some(Domain::IntDomain(new_ranges));
249 }
250 None
251 }
252
253 pub fn is_finite(&self) -> Option<bool> {
257 for domain in self.universe() {
258 if let Domain::IntDomain(ranges) = domain {
259 if ranges.is_empty() {
260 return Some(false);
261 }
262
263 if ranges
264 .iter()
265 .any(|range| matches!(range, Range::UnboundedL(_) | Range::UnboundedR(_)))
266 {
267 return Some(false);
268 }
269 } else if let Domain::DomainReference(_) = domain {
270 return None;
271 }
272 }
273 Some(true)
274 }
275
276 pub fn resolve(mut self, symbols: &SymbolTable) -> Domain {
287 let mut done_something = true;
295 while done_something {
296 done_something = false;
297 for (domain, ctx) in self.clone().contexts() {
298 if let Domain::DomainReference(name) = domain {
299 self = ctx(symbols
300 .resolve_domain(&name)
301 .expect("domain reference should exist in the symbol table")
302 .resolve(symbols));
303 done_something = true;
304 }
305 }
306 }
307 self
308 }
309
310 pub fn intersect(&self, other: &Domain) -> Option<Domain> {
314 match (self, other) {
315 (Domain::DomainSet(_, x), Domain::DomainSet(_, y)) => Some(Domain::DomainSet(
316 SetAttr::None,
317 Box::new((*x).intersect(y)?),
318 )),
319 (Domain::IntDomain(_), Domain::IntDomain(_)) => {
320 let mut v: Vec<i32> = vec![];
321 if self.is_finite()? && other.is_finite()? {
322 if let (Some(v1), Some(v2)) = (self.values_i32(), other.values_i32()) {
323 for value1 in v1.iter() {
324 if v2.contains(value1) && !v.contains(value1) {
325 v.push(*value1)
326 }
327 }
328 }
329 self.make_int_domain_from_values_i32(&v)
330 } else {
331 println!("Unbounded domain");
332 None
333 }
334 }
335 _ => None,
336 }
337 }
338
339 pub fn union(&self, other: &Domain) -> Option<Domain> {
343 match (self, other) {
344 (Domain::DomainSet(_, x), Domain::DomainSet(_, y)) => {
345 Some(Domain::DomainSet(SetAttr::None, Box::new((*x).union(y)?)))
346 }
347 (Domain::IntDomain(_), Domain::IntDomain(_)) => {
348 let mut v: Vec<i32> = vec![];
349 if self.is_finite()? && other.is_finite()? {
350 if let (Some(v1), Some(v2)) = (self.values_i32(), other.values_i32()) {
351 for value1 in v1.iter() {
352 v.push(*value1);
353 }
354 for value2 in v2.iter() {
355 if !v.contains(value2) {
356 v.push(*value2);
357 }
358 }
359 }
360 self.make_int_domain_from_values_i32(&v)
361 } else {
362 println!("Unbounded Domain");
363 None
364 }
365 }
366 _ => None,
367 }
368 }
369}
370
371impl Display for Domain {
372 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 match self {
374 Domain::BoolDomain => {
375 write!(f, "bool")
376 }
377 Domain::IntDomain(vec) => {
378 let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
379
380 if domain_ranges.is_empty() {
381 write!(f, "int")
382 } else {
383 write!(f, "int({domain_ranges})")
384 }
385 }
386 Domain::DomainReference(name) => write!(f, "{}", name),
387 Domain::DomainSet(_, domain) => {
388 write!(f, "set of ({})", domain)
389 }
390 Domain::DomainMatrix(value_domain, index_domains) => {
391 write!(
392 f,
393 "matrix indexed by [{}] of {value_domain}",
394 pretty_vec(&index_domains.iter().collect_vec())
395 )
396 }
397 Domain::DomainTuple(domains) => {
398 write!(
399 f,
400 "tuple of ({})",
401 pretty_vec(&domains.iter().collect_vec())
402 )
403 }
404 Domain::DomainRecord(entries) => {
405 write!(
406 f,
407 "record of ({})",
408 pretty_vec(
409 &entries
410 .iter()
411 .map(|entry| format!("{}: {}", entry.name, entry.domain))
412 .collect_vec()
413 )
414 )
415 }
416 }
417 }
418}
419
/// Type information for domains.
///
/// NOTE(review): not implemented yet — `return_type` always panics via `todo!()`.
impl Typeable for Domain {
    /// Unimplemented: panics with `todo!()` on every call.
    fn return_type(&self) -> Option<ReturnType> {
        todo!()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `domain` is an integer domain whose ranges include a
    /// `Range::Single` for every value in `present` and for none in `absent`.
    fn assert_singletons(domain: Domain, present: &[i32], absent: &[i32]) {
        match domain {
            Domain::IntDomain(ranges) => {
                for value in present {
                    assert!(ranges.contains(&Range::Single(*value)));
                }
                for value in absent {
                    assert!(!ranges.contains(&Range::Single(*value)));
                }
            }
            other => panic!("expected an integer domain, got {other:?}"),
        }
    }

    #[test]
    fn test_negative_product() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = lhs.clone();
        let product = lhs.apply_i32(|a, b| Some(a * b), &rhs).unwrap();

        assert_singletons(product, &[-2, -1, 0, 1, 2, 4], &[-4, -3, 3]);
    }

    #[test]
    fn test_negative_div() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = lhs.clone();
        let quotient = lhs
            .apply_i32(|a, b| if b == 0 { None } else { Some(a / b) }, &rhs)
            .unwrap();

        assert_singletons(quotient, &[-2, -1, 0, 1, 2], &[-4, -3, 3, 4]);
    }
}