1
use serde::{Deserialize, Serialize};
2

            
3
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
4
pub enum Range<A>
5
where
6
    A: Ord,
7
{
8
    Single(A),
9
    Bounded(A, A),
10
}
11

            
12
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
13
pub enum Domain {
14
    BoolDomain,
15
    IntDomain(Vec<Range<i32>>),
16
}
17

            
18
impl Domain {
19
    /// Return a list of all possible i32 values in the domain if it is an IntDomain.
20
3990
    pub fn values_i32(&self) -> Option<Vec<i32>> {
21
3990
        match self {
22
3989
            Domain::IntDomain(ranges) => Some(
23
3989
                ranges
24
3989
                    .iter()
25
6420
                    .flat_map(|r| match r {
26
2980
                        Range::Single(i) => vec![*i],
27
3440
                        Range::Bounded(i, j) => (*i..=*j).collect(),
28
6420
                    })
29
3989
                    .collect(),
30
3989
            ),
31
1
            _ => None,
32
        }
33
3990
    }
34

            
35
    /// Return an unoptimised domain that is the result of applying a binary i32 operation to two domains.
36
    ///
37
    /// The given operator may return None if the operation is not defined for its arguments.
38
    /// Undefined values will not be included in the resulting domain.
39
    ///
40
    /// Returns None if the domains are not valid for i32 operations.
41
1995
    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
42
1995
        if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
43
            // TODO: (flm8) Optimise to use smarter, less brute-force methods
44
1994
            let mut new_ranges = vec![];
45
87537
            for (v1, v2) in itertools::iproduct!(vs1, vs2) {
46
87537
                if let Some(v) = op(v1, v2) {
47
82620
                    new_ranges.push(Range::Single(v))
48
4917
                }
49
            }
50
1994
            return Some(Domain::IntDomain(new_ranges));
51
1
        }
52
1
        None
53
1995
    }
54
}
55

            
56
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `res` is an `IntDomain` and that, for every value in the
    /// window -4..=4 (scanned in ascending order), `Range::Single(value)` is
    /// present exactly when `value` appears in `expected`.
    fn assert_singles_over_window(res: Domain, expected: &[i32]) {
        assert!(matches!(res, Domain::IntDomain(_)));
        if let Domain::IntDomain(ranges) = res {
            for value in -4..=4 {
                let present = ranges.contains(&Range::Single(value));
                assert_eq!(present, expected.contains(&value), "value {}", value);
            }
        }
    }

    #[test]
    fn test_negative_product() {
        let d1 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let d2 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let product = d1.apply_i32(|a, b| Some(a * b), &d2).unwrap();

        // {-2,-1,0,1} x {-2,-1,0,1} under multiplication: 3, -3 and -4 are unreachable.
        assert_singles_over_window(product, &[-2, -1, 0, 1, 2, 4]);
    }

    #[test]
    fn test_negative_div() {
        let d1 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let d2 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let quotient = d1
            .apply_i32(|a, b| if b != 0 { Some(a / b) } else { None }, &d2)
            .unwrap();

        // Division by zero is undefined and dropped; truncating division keeps
        // every quotient within -2..=2.
        assert_singles_over_window(quotient, &[-2, -1, 0, 1, 2]);
    }
}