1use std::fmt::Display;
2
3use conjure_core::ast::SymbolTable;
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
7use crate::ast::pretty::pretty_vec;
8use uniplate::{derive::Uniplate, Uniplate};
9
10use super::{types::Typeable, AbstractLiteral, Literal, Name, ReturnType};
11
12#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub enum Range<A>
14where
15 A: Ord,
16{
17 Single(A),
18 Bounded(A, A),
19
20 UnboundedR(A),
22
23 UnboundedL(A),
25}
26
27impl<A: Ord> Range<A> {
28 pub fn contains(&self, val: &A) -> bool {
29 match self {
30 Range::Single(x) => x == val,
31 Range::Bounded(x, y) => x <= val && val <= y,
32 Range::UnboundedR(x) => x <= val,
33 Range::UnboundedL(x) => x >= val,
34 }
35 }
36}
37
38impl<A: Ord + Display> Display for Range<A> {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 match self {
41 Range::Single(i) => write!(f, "{i}"),
42 Range::Bounded(i, j) => write!(f, "{i}..{j}"),
43 Range::UnboundedR(i) => write!(f, "{i}.."),
44 Range::UnboundedL(i) => write!(f, "..{i}"),
45 }
46 }
47}
48
/// A domain: the set of values a variable or expression may take.
///
/// Membership of a literal is checked with [`Domain::contains`], and finite
/// domains can be enumerated with [`Domain::values`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Uniplate)]
#[uniplate()]
pub enum Domain {
    /// The boolean domain, `bool`, containing `false` and `true`.
    BoolDomain,

    /// An integer domain: the union of the given ranges.
    ///
    /// An empty vector denotes the unbounded `int` domain (it is displayed as
    /// `int`, contains every integer, and is reported as not finite).
    IntDomain(Vec<Range<i32>>),
    /// A reference, by name, to a domain defined elsewhere; replaced by the
    /// named domain from the symbol table in [`Domain::resolve`].
    DomainReference(Name),
    /// A set domain: its attributes and the domain of its elements.
    DomainSet(SetAttr, Box<Domain>),
    /// A matrix domain: the element domain, then one index domain per
    /// dimension, in the order printed by `Display`
    /// (`matrix indexed by [..] of ..`).
    DomainMatrix(Box<Domain>, Vec<Domain>),
}
67
/// Attributes of a set domain (the first field of `Domain::DomainSet`).
///
/// NOTE(review): these presumably mirror Essence's `size`, `minSize` and
/// `maxSize` set attributes — confirm against the parser.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SetAttr {
    /// No size attribute given.
    None,
    /// Presumably an exact cardinality (`size n`).
    Size(i32),
    /// Presumably a lower bound on cardinality (`minSize n`).
    MinSize(i32),
    /// Presumably an upper bound on cardinality (`maxSize n`).
    MaxSize(i32),
    /// Presumably both bounds: `(minSize, maxSize)`.
    MinMaxSize(i32, i32),
}
76impl Domain {
77 pub fn contains(&self, lit: &Literal) -> Option<bool> {
81 match (self, lit) {
84 (Domain::IntDomain(ranges), Literal::Int(x)) => {
85 if ranges.is_empty() {
87 return Some(true);
88 };
89
90 Some(ranges.iter().any(|range| range.contains(x)))
91 }
92 (Domain::IntDomain(_), _) => Some(false),
93 (Domain::BoolDomain, Literal::Bool(_)) => Some(true),
94 (Domain::BoolDomain, _) => Some(false),
95 (Domain::DomainReference(_), _) => None,
96
97 (
98 Domain::DomainMatrix(elem_domain, index_domains),
99 Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
100 ) => {
101 let mut index_domains = index_domains.clone();
102 if index_domains
103 .pop()
104 .expect("a matrix should have atleast one index domain")
105 != *idx_domain
106 {
107 return Some(false);
108 };
109
110 let next_elem_domain = if index_domains.is_empty() {
113 elem_domain.as_ref().clone()
114 } else {
115 Domain::DomainMatrix(elem_domain.clone(), index_domains)
116 };
117
118 for elem in elems {
119 if !next_elem_domain.contains(elem)? {
120 return Some(false);
121 }
122 }
123
124 Some(true)
125 }
126 (Domain::DomainMatrix(_, _), _) => Some(false),
127 (Domain::DomainSet(_, _), Literal::AbstractLiteral(AbstractLiteral::Set(_))) => {
128 todo!()
129 }
130 (Domain::DomainSet(_, _), _) => Some(false),
131 }
132 }
133
134 pub fn values_i32(&self) -> Option<Vec<i32>> {
137 match self {
138 Domain::IntDomain(ranges) => Some(
139 ranges
140 .iter()
141 .map(|r| match r {
142 Range::Single(i) => Some(vec![*i]),
143 Range::Bounded(i, j) => Some((*i..=*j).collect()),
144 Range::UnboundedR(_) => None,
145 Range::UnboundedL(_) => None,
146 })
147 .while_some()
148 .flatten()
149 .collect_vec(),
150 ),
151 _ => None,
152 }
153 }
154
155 pub fn values(&self) -> Option<Vec<Literal>> {
158 match self {
159 Domain::BoolDomain => Some(vec![false.into(), true.into()]),
160 Domain::IntDomain(_) => self
161 .values_i32()
162 .map(|xs| xs.iter().map(|x| Literal::Int(*x)).collect_vec()),
163
164 Domain::DomainSet(_, _) => todo!(),
167 Domain::DomainMatrix(_, _) => todo!(),
168 Domain::DomainReference(_) => None,
169 }
170 }
171
172 pub fn length(&self) -> Option<usize> {
176 self.values().map(|x| x.len())
177 }
178
179 pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
186 if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
187 let mut new_ranges = vec![];
189 for (v1, v2) in itertools::iproduct!(vs1, vs2) {
190 if let Some(v) = op(v1, v2) {
191 new_ranges.push(Range::Single(v))
192 }
193 }
194 return Some(Domain::IntDomain(new_ranges));
195 }
196 None
197 }
198
199 pub fn is_finite(&self) -> Option<bool> {
203 for domain in self.universe() {
204 if let Domain::IntDomain(ranges) = domain {
205 if ranges.is_empty() {
206 return Some(false);
207 }
208
209 if ranges
210 .iter()
211 .any(|range| matches!(range, Range::UnboundedL(_) | Range::UnboundedR(_)))
212 {
213 return Some(false);
214 }
215 } else if let Domain::DomainReference(_) = domain {
216 return None;
217 }
218 }
219 Some(true)
220 }
221
222 pub fn resolve(mut self, symbols: &SymbolTable) -> Domain {
233 let mut done_something = true;
241 while done_something {
242 done_something = false;
243 for (domain, ctx) in self.clone().contexts() {
244 if let Domain::DomainReference(name) = domain {
245 self = ctx(symbols
246 .resolve_domain(&name)
247 .expect("domain reference should exist in the symbol table")
248 .resolve(symbols));
249 done_something = true;
250 }
251 }
252 }
253 self
254 }
255}
256
257impl Display for Domain {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 match self {
260 Domain::BoolDomain => {
261 write!(f, "bool")
262 }
263 Domain::IntDomain(vec) => {
264 let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
265
266 if domain_ranges.is_empty() {
267 write!(f, "int")
268 } else {
269 write!(f, "int({domain_ranges})")
270 }
271 }
272 Domain::DomainReference(name) => write!(f, "{}", name),
273 Domain::DomainSet(_, domain) => {
274 write!(f, "set of ({})", domain)
275 }
276 Domain::DomainMatrix(value_domain, index_domains) => {
277 write!(
278 f,
279 "matrix indexed by [{}] of {value_domain}",
280 pretty_vec(&index_domains.iter().collect_vec())
281 )
282 }
283 }
284 }
285}
286
impl Typeable for Domain {
    /// Not implemented yet: calling this always panics via `todo!()`.
    fn return_type(&self) -> Option<ReturnType> {
        todo!()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the ranges of an `IntDomain`, failing the test for any other variant.
    fn expect_int_ranges(domain: Domain) -> Vec<Range<i32>> {
        match domain {
            Domain::IntDomain(ranges) => ranges,
            other => panic!("expected an IntDomain, got {other:?}"),
        }
    }

    /// Asserts that every value in `present`, and no value in `absent`, appears
    /// in `ranges` as a `Range::Single`.
    fn assert_singletons(ranges: &[Range<i32>], present: &[i32], absent: &[i32]) {
        for &value in present {
            assert!(
                ranges.contains(&Range::Single(value)),
                "expected {value} in {ranges:?}"
            );
        }
        for &value in absent {
            assert!(
                !ranges.contains(&Range::Single(value)),
                "did not expect {value} in {ranges:?}"
            );
        }
    }

    /// Multiplying `-2..1` by itself covers every pairwise product.
    #[test]
    fn test_negative_product() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let product = lhs
            .apply_i32(|a, b| Some(a * b), &rhs)
            .expect("both operands are finite integer domains");

        assert_singletons(
            &expect_int_ranges(product),
            &[-2, -1, 0, 1, 2, 4],
            &[-4, -3, 3],
        );
    }

    /// Integer division over `-2..1`, skipping division by zero.
    #[test]
    fn test_negative_div() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let quotient = lhs
            .apply_i32(|a, b| (b != 0).then(|| a / b), &rhs)
            .expect("both operands are finite integer domains");

        assert_singletons(
            &expect_int_ranges(quotient),
            &[-2, -1, 0, 1, 2],
            &[-4, -3, 3, 4],
        );
    }
}