1
use std::fmt::Display;
2

            
3
use itertools::Itertools;
4
use serde::{Deserialize, Serialize};
5

            
6
use crate::ast::pretty::pretty_vec;
7

            
8
use super::{types::Typeable, AbstractLiteral, Literal, Name, ReturnType};
9

            
10
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
11
pub enum Range<A>
12
where
13
    A: Ord,
14
{
15
    Single(A),
16
    Bounded(A, A),
17

            
18
    /// int(i..)
19
    UnboundedR(A),
20

            
21
    /// int(..i)
22
    UnboundedL(A),
23
}
24

            
25
impl<A: Ord> Range<A> {
26
630
    pub fn contains(&self, val: &A) -> bool {
27
630
        match self {
28
            Range::Single(x) => x == val,
29
630
            Range::Bounded(x, y) => x <= val && val <= y,
30
            Range::UnboundedR(x) => x <= val,
31
            Range::UnboundedL(x) => x >= val,
32
        }
33
630
    }
34
}
35

            
36
impl<A: Ord + Display> Display for Range<A> {
37
32490
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38
32490
        match self {
39
864
            Range::Single(i) => write!(f, "{i}"),
40
14184
            Range::Bounded(i, j) => write!(f, "{i}..{j}"),
41
17442
            Range::UnboundedR(i) => write!(f, "{i}.."),
42
            Range::UnboundedL(i) => write!(f, "..{i}"),
43
        }
44
32490
    }
45
}
46

            
47
/// A domain: the set of values a variable or expression may take.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Domain {
    /// The Boolean domain. `Domain::contains` accepts exactly the `Literal::Bool` literals.
    BoolDomain,

    /// An integer domain.
    ///
    /// + If multiple ranges are inside the domain, the values in the domain are the union of these
    ///   ranges.
    ///
    /// + If no ranges are given, the int domain is considered unconstrained, and can take any
    ///   integer value.
    IntDomain(Vec<Range<i32>>),
    /// A domain referred to by name. `Domain::contains` cannot decide membership for it without
    /// resolving the name, and returns `None`.
    DomainReference(Name),
    /// A set domain: its set attributes, and the domain of the set's elements.
    DomainSet(SetAttr, Box<Domain>),
    /// A n-dimensional matrix with a value domain and n-index domains
    /// (one index domain per dimension).
    DomainMatrix(Box<Domain>, Vec<Domain>),
}
64

            
65
/// Attributes attached to a `Domain::DomainSet` (its first field).
///
/// NOTE(review): the per-variant meanings below are inferred from the variant names. Nothing in
/// this file reads them yet: `Display for Domain` ignores them, and `Domain::contains` is
/// `todo!()` for set literals. Confirm against the parser.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SetAttr {
    /// No attribute given.
    None,
    /// Presumably an exact set size — TODO confirm.
    Size(i32),
    /// Presumably a lower bound on the set size — TODO confirm.
    MinSize(i32),
    /// Presumably an upper bound on the set size — TODO confirm.
    MaxSize(i32),
    /// Presumably (min, max) bounds on the set size — TODO confirm the field order.
    MinMaxSize(i32, i32),
}
73
impl Domain {
74
    // Whether the literal is a member of this domain.
75
    //
76
    // Returns `None` if this cannot be determined (e.g. `self` is a `DomainReference`).
77
630
    pub fn contains(&self, lit: &Literal) -> Option<bool> {
78
630
        // not adding a generic wildcard condition for all domains, so that this gives a compile
79
630
        // error when a domain is added.
80
630
        match (self, lit) {
81
630
            (Domain::IntDomain(ranges), Literal::Int(x)) => {
82
630
                // unconstrained int domain
83
630
                if ranges.is_empty() {
84
                    return Some(true);
85
630
                };
86
630

            
87
630
                Some(ranges.iter().any(|range| range.contains(x)))
88
            }
89
            (Domain::IntDomain(_), _) => Some(false),
90
            (Domain::BoolDomain, Literal::Bool(_)) => Some(true),
91
            (Domain::BoolDomain, _) => Some(false),
92
            (Domain::DomainReference(_), _) => None,
93

            
94
            (
95
                Domain::DomainMatrix(elem_domain, index_domains),
96
                Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, idx_domain)),
97
            ) => {
98
                let mut index_domains = index_domains.clone();
99
                if index_domains
100
                    .pop()
101
                    .expect("a matrix should have atleast one index domain")
102
                    != *idx_domain
103
                {
104
                    return Some(false);
105
                };
106

            
107
                // matrix literals are represented as nested 1d matrices, so the elements of
108
                // the matrix literal will be the inner dimensions of the matrix.
109
                let next_elem_domain = if index_domains.is_empty() {
110
                    elem_domain.as_ref().clone()
111
                } else {
112
                    Domain::DomainMatrix(elem_domain.clone(), index_domains)
113
                };
114

            
115
                for elem in elems {
116
                    if !next_elem_domain.contains(elem)? {
117
                        return Some(false);
118
                    }
119
                }
120

            
121
                Some(true)
122
            }
123
            (Domain::DomainMatrix(_, _), _) => Some(false),
124
            (Domain::DomainSet(_, _), Literal::AbstractLiteral(AbstractLiteral::Set(_))) => {
125
                todo!()
126
            }
127
            (Domain::DomainSet(_, _), _) => Some(false),
128
        }
129
630
    }
130
    /// Return a list of all possible i32 values in the domain if it is an IntDomain and is
131
    /// bounded.
132
4260
    pub fn values_i32(&self) -> Option<Vec<i32>> {
133
4260
        match self {
134
4259
            Domain::IntDomain(ranges) => Some(
135
4259
                ranges
136
4259
                    .iter()
137
6833
                    .map(|r| match r {
138
3173
                        Range::Single(i) => Some(vec![*i]),
139
3660
                        Range::Bounded(i, j) => Some((*i..=*j).collect()),
140
                        Range::UnboundedR(_) => None,
141
                        Range::UnboundedL(_) => None,
142
6833
                    })
143
4259
                    .while_some()
144
4259
                    .flatten()
145
4259
                    .collect_vec(),
146
4259
            ),
147
1
            _ => None,
148
        }
149
4260
    }
150

            
151
    /// Return an unoptimised domain that is the result of applying a binary i32 operation to two domains.
152
    ///
153
    /// The given operator may return None if the operation is not defined for its arguments.
154
    /// Undefined values will not be included in the resulting domain.
155
    ///
156
    /// Returns None if the domains are not valid for i32 operations.
157
2130
    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
158
2130
        if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
159
            // TODO: (flm8) Optimise to use smarter, less brute-force methods
160
2129
            let mut new_ranges = vec![];
161
93026
            for (v1, v2) in itertools::iproduct!(vs1, vs2) {
162
93026
                if let Some(v) = op(v1, v2) {
163
87802
                    new_ranges.push(Range::Single(v))
164
5224
                }
165
            }
166
2129
            return Some(Domain::IntDomain(new_ranges));
167
1
        }
168
1
        None
169
2130
    }
170
}
171

            
172
impl Display for Domain {
173
34884
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174
34884
        match self {
175
            Domain::BoolDomain => {
176
2502
                write!(f, "bool")
177
            }
178
32058
            Domain::IntDomain(vec) => {
179
32490
                let domain_ranges: String = vec.iter().map(|x| format!("{x}")).join(",");
180
32058

            
181
32058
                if domain_ranges.is_empty() {
182
                    write!(f, "int")
183
                } else {
184
32058
                    write!(f, "int({domain_ranges})")
185
                }
186
            }
187
90
            Domain::DomainReference(name) => write!(f, "{}", name),
188
            Domain::DomainSet(_, domain) => {
189
                write!(f, "set of ({})", domain)
190
            }
191
234
            Domain::DomainMatrix(value_domain, index_domains) => {
192
234
                write!(
193
234
                    f,
194
234
                    "matrix indexed by [{}] of {value_domain}",
195
234
                    pretty_vec(&index_domains.iter().collect_vec())
196
234
                )
197
            }
198
        }
199
34884
    }
200
}
201

            
202
impl Typeable for Domain {
    // NOTE(review): not implemented yet — any call panics via `todo!()`.
    fn return_type(&self) -> Option<ReturnType> {
        todo!()
    }
}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `res` is an int domain. Then, for each `(value, present)` pair, checks
    /// whether `Range::Single(value)` appears among its ranges.
    fn check_singles(res: Domain, expectations: &[(i32, bool)]) {
        assert!(matches!(res, Domain::IntDomain(_)));
        if let Domain::IntDomain(ranges) = res {
            for &(value, present) in expectations {
                assert_eq!(ranges.contains(&Range::Single(value)), present);
            }
        }
    }

    #[test]
    fn test_negative_product() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let res = lhs.apply_i32(|a, b| Some(a * b), &rhs).unwrap();

        check_singles(
            res,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, true),
            ],
        );
    }

    #[test]
    fn test_negative_div() {
        let lhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let rhs = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        // Division by zero is undefined, so those pairs are dropped from the result.
        let safe_div = |a: i32, b: i32| if b != 0 { Some(a / b) } else { None };
        let res = lhs.apply_i32(safe_div, &rhs).unwrap();

        check_singles(
            res,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, false),
            ],
        );
    }
}