// conjure_core/ast/domains.rs

1use std::fmt::Display;
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5
6use crate::ast::pretty::pretty_vec;
7
8use super::{types::Typeable, AbstractLiteral, Literal, Name, ReturnType};
9
10#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub enum Range<A>
12where
13    A: Ord,
14{
15    Single(A),
16    Bounded(A, A),
17
18    /// int(i..)
19    UnboundedR(A),
20
21    /// int(..i)
22    UnboundedL(A),
23}
24
25impl<A: Ord> Range<A> {
26    pub fn contains(&self, val: &A) -> bool {
27        match self {
28            Range::Single(x) => x == val,
29            Range::Bounded(x, y) => x <= val && val <= y,
30            Range::UnboundedR(x) => x <= val,
31            Range::UnboundedL(x) => x >= val,
32        }
33    }
34}
35
36impl<A: Ord + Display> Display for Range<A> {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Range::Single(i) => write!(f, "{i}"),
40            Range::Bounded(i, j) => write!(f, "{i}..{j}"),
41            Range::UnboundedR(i) => write!(f, "{i}.."),
42            Range::UnboundedL(i) => write!(f, "..{i}"),
43        }
44    }
45}
46
47#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
48pub enum Domain {
49    BoolDomain,
50
51    /// An integer domain.
52    ///
53    /// + If multiple ranges are inside the domain, the values in the domain are the union of these
54    ///   ranges.
55    ///
56    /// + If no ranges are given, the int domain is considered unconstrained, and can take any
57    ///   integer value.
58    IntDomain(Vec<Range<i32>>),
59    DomainReference(Name),
60    DomainSet(SetAttr, Box<Domain>),
61    /// A n-dimensional matrix with a value domain and n-index domains
62    DomainMatrix(Box<Domain>, Vec<Domain>),
63}
64
65#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
66pub enum SetAttr {
67    None,
68    Size(i32),
69    MinSize(i32),
70    MaxSize(i32),
71    MinMaxSize(i32, i32),
72}
73impl Domain {
74    // Whether the literal is a member of this domain.
75    //
76    // Returns `None` if this cannot be determined (e.g. `self` is a `DomainReference`).
77    pub fn contains(&self, lit: &Literal) -> Option<bool> {
78        // not adding a generic wildcard condition for all domains, so that this gives a compile
79        // error when a domain is added.
80        match (self, lit) {
81            (Domain::IntDomain(ranges), Literal::Int(x)) => {
82                // unconstrained int domain
83                if ranges.is_empty() {
84                    return Some(true);
85                };
86
87                Some(ranges.iter().any(|range| range.contains(x)))
88            }
89            (Domain::IntDomain(_), _) => Some(false),
90            (Domain::BoolDomain, Literal::Bool(_)) => Some(true),
91            (Domain::BoolDomain, _) => Some(false),
92            (Domain::DomainReference(_), _) => None,
93
94            (
95                Domain::DomainMatrix(elem_domain, index_domains),
96                Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
97            ) => {
98                let mut index_domains = index_domains.clone();
99                if index_domains
100                    .pop()
101                    .expect("a matrix should have atleast one index domain")
102                    != *idx_domain
103                {
104                    return Some(false);
105                };
106
107                // matrix literals are represented as nested 1d matrices, so the elements of
108                // the matrix literal will be the inner dimensions of the matrix.
109                let next_elem_domain = if index_domains.is_empty() {
110                    elem_domain.as_ref().clone()
111                } else {
112                    Domain::DomainMatrix(elem_domain.clone(), index_domains)
113                };
114
115                for elem in elems {
116                    if !next_elem_domain.contains(elem)? {
117                        return Some(false);
118                    }
119                }
120
121                Some(true)
122            }
123            (Domain::DomainMatrix(_, _), _) => Some(false),
124            (Domain::DomainSet(_, _), Literal::AbstractLiteral(AbstractLiteral::Set(_))) => {
125                todo!()
126            }
127            (Domain::DomainSet(_, _), _) => Some(false),
128        }
129    }
130    /// Return a list of all possible i32 values in the domain if it is an IntDomain and is
131    /// bounded.
132    pub fn values_i32(&self) -> Option<Vec<i32>> {
133        match self {
134            Domain::IntDomain(ranges) => Some(
135                ranges
136                    .iter()
137                    .map(|r| match r {
138                        Range::Single(i) => Some(vec![*i]),
139                        Range::Bounded(i, j) => Some((*i..=*j).collect()),
140                        Range::UnboundedR(_) => None,
141                        Range::UnboundedL(_) => None,
142                    })
143                    .while_some()
144                    .flatten()
145                    .collect_vec(),
146            ),
147            _ => None,
148        }
149    }
150
151    /// Return an unoptimised domain that is the result of applying a binary i32 operation to two domains.
152    ///
153    /// The given operator may return None if the operation is not defined for its arguments.
154    /// Undefined values will not be included in the resulting domain.
155    ///
156    /// Returns None if the domains are not valid for i32 operations.
157    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
158        if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
159            // TODO: (flm8) Optimise to use smarter, less brute-force methods
160            let mut new_ranges = vec![];
161            for (v1, v2) in itertools::iproduct!(vs1, vs2) {
162                if let Some(v) = op(v1, v2) {
163                    new_ranges.push(Range::Single(v))
164                }
165            }
166            return Some(Domain::IntDomain(new_ranges));
167        }
168        None
169    }
170}
171
172impl Display for Domain {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        match self {
175            Domain::BoolDomain => {
176                write!(f, "bool")
177            }
178            Domain::IntDomain(vec) => {
179                let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
180
181                if domain_ranges.is_empty() {
182                    write!(f, "int")
183                } else {
184                    write!(f, "int({domain_ranges})")
185                }
186            }
187            Domain::DomainReference(name) => write!(f, "{}", name),
188            Domain::DomainSet(_, domain) => {
189                write!(f, "set of ({})", domain)
190            }
191            Domain::DomainMatrix(value_domain, index_domains) => {
192                write!(
193                    f,
194                    "matrix indexed by [{}] of {value_domain}",
195                    pretty_vec(&index_domains.iter().collect_vec())
196                )
197            }
198        }
199    }
200}
201
impl Typeable for Domain {
    /// Not implemented yet: calling this always panics via `todo!()`.
    fn return_type(&self) -> Option<ReturnType> {
        todo!()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `op` to two copies of `int(-2..1)` and returns the ranges of the result,
    /// failing the test if the result is not an int domain.
    fn apply_to_minus_two_to_one(op: fn(i32, i32) -> Option<i32>) -> Vec<Range<i32>> {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        match lhs.apply_i32(op, &rhs).unwrap() {
            Domain::IntDomain(ranges) => ranges,
            other => panic!("expected an int domain, got {other:?}"),
        }
    }

    /// Checks, for each `(value, present)` pair, whether `Range::Single(value)` is in `ranges`.
    fn assert_singles(ranges: &[Range<i32>], expectations: &[(i32, bool)]) {
        for &(value, present) in expectations {
            assert_eq!(
                ranges.contains(&Range::Single(value)),
                present,
                "unexpected membership for {value}"
            );
        }
    }

    #[test]
    fn test_negative_product() {
        let ranges = apply_to_minus_two_to_one(|a, b| Some(a * b));
        assert_singles(
            &ranges,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, true),
            ],
        );
    }

    #[test]
    fn test_negative_div() {
        let ranges = apply_to_minus_two_to_one(|a, b| (b != 0).then(|| a / b));
        assert_singles(
            &ranges,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, false),
            ],
        );
    }
}