conjure_core/ast/
domains.rs

1use std::fmt::Display;
2
3use conjure_core::ast::SymbolTable;
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
7use crate::ast::pretty::pretty_vec;
8use uniplate::{derive::Uniplate, Uniplate};
9
10use super::{types::Typeable, AbstractLiteral, Literal, Name, ReturnType};
11
12#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
13pub enum Range<A>
14where
15    A: Ord,
16{
17    Single(A),
18    Bounded(A, A),
19
20    /// int(i..)
21    UnboundedR(A),
22
23    /// int(..i)
24    UnboundedL(A),
25}
26
27impl<A: Ord> Range<A> {
28    pub fn contains(&self, val: &A) -> bool {
29        match self {
30            Range::Single(x) => x == val,
31            Range::Bounded(x, y) => x <= val && val <= y,
32            Range::UnboundedR(x) => x <= val,
33            Range::UnboundedL(x) => x >= val,
34        }
35    }
36}
37
38impl<A: Ord + Display> Display for Range<A> {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Range::Single(i) => write!(f, "{i}"),
42            Range::Bounded(i, j) => write!(f, "{i}..{j}"),
43            Range::UnboundedR(i) => write!(f, "{i}.."),
44            Range::UnboundedL(i) => write!(f, "..{i}"),
45        }
46    }
47}
48
49#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Uniplate)]
50#[uniplate()]
51pub enum Domain {
52    BoolDomain,
53
54    /// An integer domain.
55    ///
56    /// + If multiple ranges are inside the domain, the values in the domain are the union of these
57    ///   ranges.
58    ///
59    /// + If no ranges are given, the int domain is considered unconstrained, and can take any
60    ///   integer value.
61    IntDomain(Vec<Range<i32>>),
62    DomainReference(Name),
63    DomainSet(SetAttr, Box<Domain>),
64    /// A n-dimensional matrix with a value domain and n-index domains
65    DomainMatrix(Box<Domain>, Vec<Domain>),
66}
67
68#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
69pub enum SetAttr {
70    None,
71    Size(i32),
72    MinSize(i32),
73    MaxSize(i32),
74    MinMaxSize(i32, i32),
75}
76impl Domain {
77    // Whether the literal is a member of this domain.
78    //
79    // Returns `None` if this cannot be determined (e.g. `self` is a `DomainReference`).
80    pub fn contains(&self, lit: &Literal) -> Option<bool> {
81        // not adding a generic wildcard condition for all domains, so that this gives a compile
82        // error when a domain is added.
83        match (self, lit) {
84            (Domain::IntDomain(ranges), Literal::Int(x)) => {
85                // unconstrained int domain
86                if ranges.is_empty() {
87                    return Some(true);
88                };
89
90                Some(ranges.iter().any(|range| range.contains(x)))
91            }
92            (Domain::IntDomain(_), _) => Some(false),
93            (Domain::BoolDomain, Literal::Bool(_)) => Some(true),
94            (Domain::BoolDomain, _) => Some(false),
95            (Domain::DomainReference(_), _) => None,
96
97            (
98                Domain::DomainMatrix(elem_domain, index_domains),
99                Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
100            ) => {
101                let mut index_domains = index_domains.clone();
102                if index_domains
103                    .pop()
104                    .expect("a matrix should have atleast one index domain")
105                    != *idx_domain
106                {
107                    return Some(false);
108                };
109
110                // matrix literals are represented as nested 1d matrices, so the elements of
111                // the matrix literal will be the inner dimensions of the matrix.
112                let next_elem_domain = if index_domains.is_empty() {
113                    elem_domain.as_ref().clone()
114                } else {
115                    Domain::DomainMatrix(elem_domain.clone(), index_domains)
116                };
117
118                for elem in elems {
119                    if !next_elem_domain.contains(elem)? {
120                        return Some(false);
121                    }
122                }
123
124                Some(true)
125            }
126            (Domain::DomainMatrix(_, _), _) => Some(false),
127            (Domain::DomainSet(_, _), Literal::AbstractLiteral(AbstractLiteral::Set(_))) => {
128                todo!()
129            }
130            (Domain::DomainSet(_, _), _) => Some(false),
131        }
132    }
133
134    /// Return a list of all possible i32 values in the domain if it is an IntDomain and is
135    /// bounded.
136    pub fn values_i32(&self) -> Option<Vec<i32>> {
137        match self {
138            Domain::IntDomain(ranges) => Some(
139                ranges
140                    .iter()
141                    .map(|r| match r {
142                        Range::Single(i) => Some(vec![*i]),
143                        Range::Bounded(i, j) => Some((*i..=*j).collect()),
144                        Range::UnboundedR(_) => None,
145                        Range::UnboundedL(_) => None,
146                    })
147                    .while_some()
148                    .flatten()
149                    .collect_vec(),
150            ),
151            _ => None,
152        }
153    }
154
155    /// Gets all the values inside this domain, as a [`Literal`]. Returns `None` if the domain is not
156    /// finite.
157    pub fn values(&self) -> Option<Vec<Literal>> {
158        match self {
159            Domain::BoolDomain => Some(vec![false.into(), true.into()]),
160            Domain::IntDomain(_) => self
161                .values_i32()
162                .map(|xs| xs.iter().map(|x| Literal::Int(*x)).collect_vec()),
163
164            // ~niklasdewally: don't know how to define this for collections, so leaving it for
165            // now... However, it definitely can be done, as matrices can be indexed by matrices.
166            Domain::DomainSet(_, _) => todo!(),
167            Domain::DomainMatrix(_, _) => todo!(),
168            Domain::DomainReference(_) => None,
169        }
170    }
171
172    /// Gets the length of this domain.
173    ///
174    /// Returns `None` if it is not finite.
175    pub fn length(&self) -> Option<usize> {
176        self.values().map(|x| x.len())
177    }
178
179    /// Return an unoptimised domain that is the result of applying a binary i32 operation to two domains.
180    ///
181    /// The given operator may return None if the operation is not defined for its arguments.
182    /// Undefined values will not be included in the resulting domain.
183    ///
184    /// Returns None if the domains are not valid for i32 operations.
185    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
186        if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
187            // TODO: (flm8) Optimise to use smarter, less brute-force methods
188            let mut new_ranges = vec![];
189            for (v1, v2) in itertools::iproduct!(vs1, vs2) {
190                if let Some(v) = op(v1, v2) {
191                    new_ranges.push(Range::Single(v))
192                }
193            }
194            return Some(Domain::IntDomain(new_ranges));
195        }
196        None
197    }
198
199    /// Whether this domain has a finite number of values.
200    ///
201    /// Returns `None` if this cannot be determined, e.g. if `self` is a domain reference.
202    pub fn is_finite(&self) -> Option<bool> {
203        for domain in self.universe() {
204            if let Domain::IntDomain(ranges) = domain {
205                if ranges.is_empty() {
206                    return Some(false);
207                }
208
209                if ranges
210                    .iter()
211                    .any(|range| matches!(range, Range::UnboundedL(_) | Range::UnboundedR(_)))
212                {
213                    return Some(false);
214                }
215            } else if let Domain::DomainReference(_) = domain {
216                return None;
217            }
218        }
219        Some(true)
220    }
221
222    /// Resolves this domain to a ground domain, using the symbol table provided to resolve
223    /// references.
224    ///
225    /// A domain is ground iff it is not a domain reference, nor contains any domain references.
226    ///
227    /// See also: [`SymbolTable::resolve_domain`](crate::ast::SymbolTable::resolve_domain).
228    ///
229    /// # Panics
230    ///
231    /// + If a reference domain in `self` does not exist in the given symbol table.
232    pub fn resolve(mut self, symbols: &SymbolTable) -> Domain {
233        // FIXME: cannot use Uniplate::transform here due to reference lifetime shenanigans...
234        // dont see any reason why Uniplate::transform requires a closure that only uses borrows
235        // with a 'static lifetime... ~niklasdewally
236        // ..
237        // Also, still want to make the Uniplate variant which uses FnOnce not Fn with methods that
238        // take self instead of &self -- that would come in handy here!
239
240        let mut done_something = true;
241        while done_something {
242            done_something = false;
243            for (domain, ctx) in self.clone().contexts() {
244                if let Domain::DomainReference(name) = domain {
245                    self = ctx(symbols
246                        .resolve_domain(&name)
247                        .expect("domain reference should exist in the symbol table")
248                        .resolve(symbols));
249                    done_something = true;
250                }
251            }
252        }
253        self
254    }
255}
256
257impl Display for Domain {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        match self {
260            Domain::BoolDomain => {
261                write!(f, "bool")
262            }
263            Domain::IntDomain(vec) => {
264                let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
265
266                if domain_ranges.is_empty() {
267                    write!(f, "int")
268                } else {
269                    write!(f, "int({domain_ranges})")
270                }
271            }
272            Domain::DomainReference(name) => write!(f, "{}", name),
273            Domain::DomainSet(_, domain) => {
274                write!(f, "set of ({})", domain)
275            }
276            Domain::DomainMatrix(value_domain, index_domains) => {
277                write!(
278                    f,
279                    "matrix indexed by [{}] of {value_domain}",
280                    pretty_vec(&index_domains.iter().collect_vec())
281                )
282            }
283        }
284    }
285}
286
287impl Typeable for Domain {
288    fn return_type(&self) -> Option<ReturnType> {
289        todo!()
290    }
291}
#[cfg(test)]
mod tests {
    use super::*;

    /// The domain `int(-2..1)`, used as both operands below.
    fn small_int_domain() -> Domain {
        Domain::IntDomain(vec![Range::Bounded(-2, 1)])
    }

    /// Extracts the ranges of an `IntDomain`, failing the test for any other domain.
    fn int_ranges(domain: Domain) -> Vec<Range<i32>> {
        match domain {
            Domain::IntDomain(ranges) => ranges,
            other => panic!("expected an IntDomain, got {other:?}"),
        }
    }

    /// Asserts which single values do, and do not, appear in `ranges`.
    fn check_singles(ranges: &[Range<i32>], present: &[i32], absent: &[i32]) {
        for &v in present {
            assert!(ranges.contains(&Range::Single(v)), "expected {v} in {ranges:?}");
        }
        for &v in absent {
            assert!(!ranges.contains(&Range::Single(v)), "did not expect {v} in {ranges:?}");
        }
    }

    #[test]
    fn test_negative_product() {
        let product = small_int_domain()
            .apply_i32(|a, b| Some(a * b), &small_int_domain())
            .unwrap();
        check_singles(&int_ranges(product), &[-2, -1, 0, 1, 2, 4], &[-4, -3, 3]);
    }

    #[test]
    fn test_negative_div() {
        let quotient = small_int_domain()
            .apply_i32(|a, b| if b != 0 { Some(a / b) } else { None }, &small_int_domain())
            .unwrap();
        check_singles(&int_ranges(quotient), &[-2, -1, 0, 1, 2], &[-4, -3, 3, 4]);
    }
}