conjure_core/ast/
domains.rs

1use std::fmt::Display;
2
3use conjure_core::ast::SymbolTable;
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
7use crate::ast::pretty::pretty_vec;
8use uniplate::{derive::Uniplate, Uniplate};
9
10use super::{records::RecordEntry, types::Typeable, AbstractLiteral, Literal, Name, ReturnType};
11
12#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub enum Range<A>
14where
15    A: Ord,
16{
17    Single(A),
18    Bounded(A, A),
19
20    /// int(i..)
21    UnboundedR(A),
22
23    /// int(..i)
24    UnboundedL(A),
25}
26
27impl<A: Ord> Range<A> {
28    pub fn contains(&self, val: &A) -> bool {
29        match self {
30            Range::Single(x) => x == val,
31            Range::Bounded(x, y) => x <= val && val <= y,
32            Range::UnboundedR(x) => x <= val,
33            Range::UnboundedL(x) => x >= val,
34        }
35    }
36}
37
38impl<A: Ord + Display> Display for Range<A> {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Range::Single(i) => write!(f, "{i}"),
42            Range::Bounded(i, j) => write!(f, "{i}..{j}"),
43            Range::UnboundedR(i) => write!(f, "{i}.."),
44            Range::UnboundedL(i) => write!(f, "..{i}"),
45        }
46    }
47}
48
49#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Uniplate)]
50#[uniplate()]
51pub enum Domain {
52    BoolDomain,
53
54    /// An integer domain.
55    ///
56    /// + If multiple ranges are inside the domain, the values in the domain are the union of these
57    ///   ranges.
58    ///
59    /// + If no ranges are given, the int domain is considered unconstrained, and can take any
60    ///   integer value.
61    IntDomain(Vec<Range<i32>>),
62    DomainReference(Name),
63    DomainSet(SetAttr, Box<Domain>),
64    /// A n-dimensional matrix with a value domain and n-index domains
65    DomainMatrix(Box<Domain>, Vec<Domain>),
66    // A tuple of n domains (e.g. (int, bool))
67    DomainTuple(Vec<Domain>),
68
69    DomainRecord(Vec<RecordEntry>),
70}
71
72#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
73pub enum SetAttr {
74    None,
75    Size(i32),
76    MinSize(i32),
77    MaxSize(i32),
78    MinMaxSize(i32, i32),
79}
80impl Domain {
81    // Whether the literal is a member of this domain.
82    //
83    // Returns `None` if this cannot be determined (e.g. `self` is a `DomainReference`).
84    pub fn contains(&self, lit: &Literal) -> Option<bool> {
85        // not adding a generic wildcard condition for all domains, so that this gives a compile
86        // error when a domain is added.
87        match (self, lit) {
88            (Domain::IntDomain(ranges), Literal::Int(x)) => {
89                // unconstrained int domain
90                if ranges.is_empty() {
91                    return Some(true);
92                };
93
94                Some(ranges.iter().any(|range| range.contains(x)))
95            }
96            (Domain::IntDomain(_), _) => Some(false),
97            (Domain::BoolDomain, Literal::Bool(_)) => Some(true),
98            (Domain::BoolDomain, _) => Some(false),
99            (Domain::DomainReference(_), _) => None,
100            (
101                Domain::DomainMatrix(elem_domain, index_domains),
102                Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
103            ) => {
104                let mut index_domains = index_domains.clone();
105                if index_domains
106                    .pop()
107                    .expect("a matrix should have atleast one index domain")
108                    != *idx_domain
109                {
110                    return Some(false);
111                };
112
113                // matrix literals are represented as nested 1d matrices, so the elements of
114                // the matrix literal will be the inner dimensions of the matrix.
115                let next_elem_domain = if index_domains.is_empty() {
116                    elem_domain.as_ref().clone()
117                } else {
118                    Domain::DomainMatrix(elem_domain.clone(), index_domains)
119                };
120
121                for elem in elems {
122                    if !next_elem_domain.contains(elem)? {
123                        return Some(false);
124                    }
125                }
126
127                Some(true)
128            }
129            (
130                Domain::DomainTuple(elem_domains),
131                Literal::AbstractLiteral(AbstractLiteral::Tuple(literal_elems)),
132            ) => {
133                // for every element in the tuple literal, check if it is in the corresponding domain
134                for (elem_domain, elem) in itertools::izip!(elem_domains, literal_elems) {
135                    if !elem_domain.contains(elem)? {
136                        return Some(false);
137                    }
138                }
139
140                Some(true)
141            }
142            (
143                Domain::DomainSet(_, domain),
144                Literal::AbstractLiteral(AbstractLiteral::Set(literal_elems)),
145            ) => {
146                for elem in literal_elems {
147                    if !domain.contains(elem)? {
148                        return Some(false);
149                    }
150                }
151                Some(true)
152            }
153            (
154                Domain::DomainRecord(entries),
155                Literal::AbstractLiteral(AbstractLiteral::Record(lit_entries)),
156            ) => {
157                for (entry, lit_entry) in itertools::izip!(entries, lit_entries) {
158                    if entry.name != lit_entry.name || !(entry.domain.contains(&lit_entry.value)?) {
159                        return Some(false);
160                    }
161                }
162                Some(true)
163            }
164
165            (Domain::DomainRecord(_), _) => Some(false),
166
167            (Domain::DomainMatrix(_, _), _) => Some(false),
168
169            (Domain::DomainSet(_, _), _) => Some(false),
170
171            (Domain::DomainTuple(_), _) => Some(false),
172        }
173    }
174
175    /// Return a list of all possible i32 values in the domain if it is an IntDomain and is
176    /// bounded.
177    pub fn values_i32(&self) -> Option<Vec<i32>> {
178        match self {
179            Domain::IntDomain(ranges) => Some(
180                ranges
181                    .iter()
182                    .map(|r| match r {
183                        Range::Single(i) => Some(vec![*i]),
184                        Range::Bounded(i, j) => Some((*i..=*j).collect()),
185                        Range::UnboundedR(_) => None,
186                        Range::UnboundedL(_) => None,
187                    })
188                    .while_some()
189                    .flatten()
190                    .collect_vec(),
191            ),
192            _ => None,
193        }
194    }
195
196    // turns vector of integers into a domain
197    // TODO: can be done more compactly in terms of the domain we produce. e.g. instead of int(1,2,3,4,5,8,9,10) produce int(1..5, 8..10)
198    // needs to be tested with domain functions intersect() and uninon() once comprehension rules are written.
199    pub fn make_int_domain_from_values_i32(&self, vector: &[i32]) -> Option<Domain> {
200        let mut new_ranges = vec![];
201        for values in vector.iter() {
202            new_ranges.push(Range::Single(*values));
203        }
204        Some(Domain::IntDomain(new_ranges))
205    }
206
207    /// Gets all the values inside this domain, as a [`Literal`]. Returns `None` if the domain is not
208    /// finite.
209    pub fn values(&self) -> Option<Vec<Literal>> {
210        match self {
211            Domain::BoolDomain => Some(vec![false.into(), true.into()]),
212            Domain::IntDomain(_) => self
213                .values_i32()
214                .map(|xs| xs.iter().map(|x| Literal::Int(*x)).collect_vec()),
215
216            // ~niklasdewally: don't know how to define this for collections, so leaving it for
217            // now... However, it definitely can be done, as matrices can be indexed by matrices.
218            Domain::DomainSet(_, _) => todo!(),
219            Domain::DomainMatrix(_, _) => todo!(),
220            Domain::DomainReference(_) => None,
221            Domain::DomainTuple(_) => todo!(), // TODO: Can this be done?
222            Domain::DomainRecord(_) => todo!(),
223        }
224    }
225
226    /// Gets the length of this domain.
227    ///
228    /// Returns `None` if it is not finite.
229    pub fn length(&self) -> Option<usize> {
230        self.values().map(|x| x.len())
231    }
232
233    /// Return an unoptimised domain that is the result of applying a binary i32 operation to two domains.
234    ///
235    /// The given operator may return None if the operation is not defined for its arguments.
236    /// Undefined values will not be included in the resulting domain.
237    ///
238    /// Returns None if the domains are not valid for i32 operations.
239    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
240        if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
241            // TODO: (flm8) Optimise to use smarter, less brute-force methods
242            let mut new_ranges = vec![];
243            for (v1, v2) in itertools::iproduct!(vs1, vs2) {
244                if let Some(v) = op(v1, v2) {
245                    new_ranges.push(Range::Single(v))
246                }
247            }
248            return Some(Domain::IntDomain(new_ranges));
249        }
250        None
251    }
252
253    /// Whether this domain has a finite number of values.
254    ///
255    /// Returns `None` if this cannot be determined, e.g. if `self` is a domain reference.
256    pub fn is_finite(&self) -> Option<bool> {
257        for domain in self.universe() {
258            if let Domain::IntDomain(ranges) = domain {
259                if ranges.is_empty() {
260                    return Some(false);
261                }
262
263                if ranges
264                    .iter()
265                    .any(|range| matches!(range, Range::UnboundedL(_) | Range::UnboundedR(_)))
266                {
267                    return Some(false);
268                }
269            } else if let Domain::DomainReference(_) = domain {
270                return None;
271            }
272        }
273        Some(true)
274    }
275
276    /// Resolves this domain to a ground domain, using the symbol table provided to resolve
277    /// references.
278    ///
279    /// A domain is ground iff it is not a domain reference, nor contains any domain references.
280    ///
281    /// See also: [`SymbolTable::resolve_domain`](crate::ast::SymbolTable::resolve_domain).
282    ///
283    /// # Panics
284    ///
285    /// + If a reference domain in `self` does not exist in the given symbol table.
286    pub fn resolve(mut self, symbols: &SymbolTable) -> Domain {
287        // FIXME: cannot use Uniplate::transform here due to reference lifetime shenanigans...
288        // dont see any reason why Uniplate::transform requires a closure that only uses borrows
289        // with a 'static lifetime... ~niklasdewally
290        // ..
291        // Also, still want to make the Uniplate variant which uses FnOnce not Fn with methods that
292        // take self instead of &self -- that would come in handy here!
293
294        let mut done_something = true;
295        while done_something {
296            done_something = false;
297            for (domain, ctx) in self.clone().contexts() {
298                if let Domain::DomainReference(name) = domain {
299                    self = ctx(symbols
300                        .resolve_domain(&name)
301                        .expect("domain reference should exist in the symbol table")
302                        .resolve(symbols));
303                    done_something = true;
304                }
305            }
306        }
307        self
308    }
309
310    // simplified domain intersection function. defined for integer domains of sets
311    // TODO: does not consider unbounded domains yet
312    // needs to be tested once comprehension rules are written
313    pub fn intersect(&self, other: &Domain) -> Option<Domain> {
314        match (self, other) {
315            (Domain::DomainSet(_, x), Domain::DomainSet(_, y)) => Some(Domain::DomainSet(
316                SetAttr::None,
317                Box::new((*x).intersect(y)?),
318            )),
319            (Domain::IntDomain(_), Domain::IntDomain(_)) => {
320                let mut v: Vec<i32> = vec![];
321                if self.is_finite()? && other.is_finite()? {
322                    if let (Some(v1), Some(v2)) = (self.values_i32(), other.values_i32()) {
323                        for value1 in v1.iter() {
324                            if v2.contains(value1) && !v.contains(value1) {
325                                v.push(*value1)
326                            }
327                        }
328                    }
329                    self.make_int_domain_from_values_i32(&v)
330                } else {
331                    println!("Unbounded domain");
332                    None
333                }
334            }
335            _ => None,
336        }
337    }
338
339    // simplified domain union function. defined for integer domains of sets
340    // TODO: does not consider unbounded domains yet
341    // needs to be tested once comprehension rules are written
342    pub fn union(&self, other: &Domain) -> Option<Domain> {
343        match (self, other) {
344            (Domain::DomainSet(_, x), Domain::DomainSet(_, y)) => {
345                Some(Domain::DomainSet(SetAttr::None, Box::new((*x).union(y)?)))
346            }
347            (Domain::IntDomain(_), Domain::IntDomain(_)) => {
348                let mut v: Vec<i32> = vec![];
349                if self.is_finite()? && other.is_finite()? {
350                    if let (Some(v1), Some(v2)) = (self.values_i32(), other.values_i32()) {
351                        for value1 in v1.iter() {
352                            v.push(*value1);
353                        }
354                        for value2 in v2.iter() {
355                            if !v.contains(value2) {
356                                v.push(*value2);
357                            }
358                        }
359                    }
360                    self.make_int_domain_from_values_i32(&v)
361                } else {
362                    println!("Unbounded Domain");
363                    None
364                }
365            }
366            _ => None,
367        }
368    }
369}
370
371impl Display for Domain {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        match self {
374            Domain::BoolDomain => {
375                write!(f, "bool")
376            }
377            Domain::IntDomain(vec) => {
378                let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
379
380                if domain_ranges.is_empty() {
381                    write!(f, "int")
382                } else {
383                    write!(f, "int({domain_ranges})")
384                }
385            }
386            Domain::DomainReference(name) => write!(f, "{}", name),
387            Domain::DomainSet(_, domain) => {
388                write!(f, "set of ({})", domain)
389            }
390            Domain::DomainMatrix(value_domain, index_domains) => {
391                write!(
392                    f,
393                    "matrix indexed by [{}] of {value_domain}",
394                    pretty_vec(&index_domains.iter().collect_vec())
395                )
396            }
397            Domain::DomainTuple(domains) => {
398                write!(
399                    f,
400                    "tuple of ({})",
401                    pretty_vec(&domains.iter().collect_vec())
402                )
403            }
404            Domain::DomainRecord(entries) => {
405                write!(
406                    f,
407                    "record of ({})",
408                    pretty_vec(
409                        &entries
410                            .iter()
411                            .map(|entry| format!("{}: {}", entry.name, entry.domain))
412                            .collect_vec()
413                    )
414                )
415            }
416        }
417    }
418}
419
420impl Typeable for Domain {
421    fn return_type(&self) -> Option<ReturnType> {
422        todo!()
423    }
424}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `domain` is an `IntDomain`. For each `(value, expected)` pair, it then checks
    /// whether `Range::Single(value)` appears among the domain's ranges.
    fn assert_singletons(domain: Domain, expectations: &[(i32, bool)]) {
        assert!(matches!(domain, Domain::IntDomain(_)));
        if let Domain::IntDomain(ranges) = domain {
            for &(value, expected) in expectations {
                assert_eq!(ranges.contains(&Range::Single(value)), expected);
            }
        }
    }

    #[test]
    fn test_negative_product() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let product = lhs.apply_i32(|a, b| Some(a * b), &rhs).unwrap();

        assert_singletons(
            product,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, true),
            ],
        );
    }

    #[test]
    fn test_negative_div() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let quotient = lhs
            .apply_i32(|a, b| if b != 0 { Some(a / b) } else { None }, &rhs)
            .unwrap();

        assert_singletons(
            quotient,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, false),
            ],
        );
    }
}