1
use serde::{Deserialize, Serialize};
2
// use std::iter::Ste
3

            
4
650
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
5
pub enum Range<A>
6
where
7
    A: Ord,
8
{
9
    Single(A),
10
    Bounded(A, A),
11
}
12

            
13
820
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
14
pub enum Domain {
15
    BoolDomain,
16
    IntDomain(Vec<Range<i32>>),
17
}
18

            
19
impl Domain {
20
    /// Return a list of all possible i32 values in the domain if it is an IntDomain.
21
522
    pub fn values_i32(&self) -> Option<Vec<i32>> {
22
522
        match self {
23
521
            Domain::IntDomain(ranges) => Some(
24
521
                ranges
25
521
                    .iter()
26
521
                    .flat_map(|r| match r {
27
65
                        Range::Single(i) => vec![*i],
28
456
                        Range::Bounded(i, j) => (*i..=*j).collect(),
29
521
                    })
30
521
                    .collect(),
31
521
            ),
32
1
            _ => None,
33
        }
34
522
    }
35

            
36
    /// Return an unoptimised domain that is the result of applying a binary i32 operation to two domains.
37
    ///
38
    /// The given operator may return None if the operation is not defined for its arguments.
39
    /// Undefined values will not be included in the resulting domain.
40
    ///
41
    /// Returns None if the domains are not valid for i32 operations.
42
261
    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
43
261
        if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
44
            // TODO: (flm8) Optimise to use smarter, less brute-force methods
45
260
            let mut new_ranges = vec![];
46
3863
            for (v1, v2) in itertools::iproduct!(vs1, vs2) {
47
3863
                if let Some(v) = op(v1, v2) {
48
3589
                    new_ranges.push(Range::Single(v))
49
274
                }
50
            }
51
260
            return Some(Domain::IntDomain(new_ranges));
52
1
        }
53
1
        None
54
261
    }
55
}
56

            
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwrap the ranges of an `IntDomain`, failing the test for any other variant.
    fn int_ranges(domain: Domain) -> Vec<Range<i32>> {
        match domain {
            Domain::IntDomain(ranges) => ranges,
            other => panic!("expected an IntDomain, got {:?}", other),
        }
    }

    /// Assert that every value in `present` appears as a `Single` in `ranges`
    /// and that no value in `absent` does.
    fn assert_singles(ranges: &[Range<i32>], present: &[i32], absent: &[i32]) {
        for &v in present {
            assert!(
                ranges.contains(&Range::Single(v)),
                "expected {} in {:?}",
                v,
                ranges
            );
        }
        for &v in absent {
            assert!(
                !ranges.contains(&Range::Single(v)),
                "did not expect {} in {:?}",
                v,
                ranges
            );
        }
    }

    #[test]
    fn test_negative_product() {
        let d1 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let d2 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let res = d1.apply_i32(|a, b| Some(a * b), &d2).unwrap();

        let ranges = int_ranges(res);
        assert_singles(&ranges, &[-2, -1, 0, 1, 2, 4], &[-4, -3, 3]);
    }

    #[test]
    fn test_negative_div() {
        let d1 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let d2 = Domain::IntDomain(vec![Range::Bounded(-2, 1)]);
        let res = d1
            .apply_i32(|a, b| if b != 0 { Some(a / b) } else { None }, &d2)
            .unwrap();

        let ranges = int_ranges(res);
        assert_singles(&ranges, &[-2, -1, 0, 1, 2], &[-4, -3, 3, 4]);
    }

    #[test]
    fn test_values_i32() {
        let d = Domain::IntDomain(vec![
            Range::Single(5),
            Range::Bounded(1, 3),
            Range::Bounded(3, 1),
        ]);
        assert_eq!(d.values_i32(), Some(vec![5, 1, 2, 3]));
        assert_eq!(Domain::BoolDomain.values_i32(), None);
    }

    #[test]
    fn test_apply_rejects_bool_domain() {
        let ints = Domain::IntDomain(vec![Range::Single(1)]);
        assert_eq!(ints.apply_i32(|a, b| Some(a + b), &Domain::BoolDomain), None);
        assert_eq!(Domain::BoolDomain.apply_i32(|a, b| Some(a + b), &ints), None);
    }
}