1
use std::fmt::Display;
2

            
3
use serde::{Deserialize, Serialize};
4

            
5
use super::{types::Typeable, Name, ReturnType};
6

            
7
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
8
pub enum Range<A>
9
where
10
    A: Ord,
11
{
12
    Single(A),
13
    Bounded(A, A),
14
}
15

            
16
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
17
pub enum Domain {
18
    BoolDomain,
19
    IntDomain(Vec<Range<i32>>),
20
    DomainReference(Name),
21
}
22

            
23
impl Domain {
24
    /// Return a list of all possible i32 values in the domain if it is an IntDomain.
25
3990
    pub fn values_i32(&self) -> Option<Vec<i32>> {
26
3990
        match self {
27
3989
            Domain::IntDomain(ranges) => Some(
28
3989
                ranges
29
3989
                    .iter()
30
6420
                    .flat_map(|r| match r {
31
2980
                        Range::Single(i) => vec![*i],
32
3440
                        Range::Bounded(i, j) => (*i..=*j).collect(),
33
6420
                    })
34
3989
                    .collect(),
35
3989
            ),
36
1
            _ => None,
37
        }
38
3990
    }
39

            
40
    /// Return an unoptimised domain that is the result of applying a binary i32 operation to two domains.
41
    ///
42
    /// The given operator may return None if the operation is not defined for its arguments.
43
    /// Undefined values will not be included in the resulting domain.
44
    ///
45
    /// Returns None if the domains are not valid for i32 operations.
46
1995
    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
47
1995
        if let (Some(vs1), Some(vs2)) = (self.values_i32(), other.values_i32()) {
48
            // TODO: (flm8) Optimise to use smarter, less brute-force methods
49
1994
            let mut new_ranges = vec![];
50
87537
            for (v1, v2) in itertools::iproduct!(vs1, vs2) {
51
87537
                if let Some(v) = op(v1, v2) {
52
82620
                    new_ranges.push(Range::Single(v))
53
4917
                }
54
            }
55
1994
            return Some(Domain::IntDomain(new_ranges));
56
1
        }
57
1
        None
58
1995
    }
59
}
60

            
61
impl Display for Domain {
62
100453
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63
100453
        match self {
64
            Domain::BoolDomain => {
65
41956
                write!(f, "bool")
66
            }
67
58395
            Domain::IntDomain(vec) => {
68
58395
                let mut domain_ranges: Vec<String> = vec![];
69
116790
                for range in vec {
70
58395
                    domain_ranges.push(match range {
71
391
                        Range::Single(a) => a.to_string(),
72
58004
                        Range::Bounded(a, b) => format!("{}..{}", a, b),
73
                    });
74
                }
75

            
76
58395
                if domain_ranges.is_empty() {
77
                    write!(f, "int")
78
                } else {
79
58395
                    write!(f, "int({})", domain_ranges.join(","))
80
                }
81
            }
82
102
            Domain::DomainReference(name) => write!(f, "{}", name),
83
        }
84
100453
    }
85
}
86

            
87
impl Typeable for Domain {
88
    fn return_type(&self) -> Option<ReturnType> {
89
        todo!()
90
    }
91
}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// The operand domain shared by these tests: every integer from -2 to 1 inclusive.
    fn operand() -> Domain {
        Domain::IntDomain(vec![Range::Bounded(-2, 1)])
    }

    /// Extract the ranges of an `IntDomain`, failing the test for any other variant.
    fn int_ranges(domain: Domain) -> Vec<Range<i32>> {
        match domain {
            Domain::IntDomain(ranges) => ranges,
            other => panic!("expected an IntDomain, got {:?}", other),
        }
    }

    /// For each `(value, present)` pair, check whether `Range::Single(value)` is in `ranges`.
    fn check_singles(ranges: &[Range<i32>], expectations: &[(i32, bool)]) {
        for &(value, present) in expectations {
            assert_eq!(
                ranges.contains(&Range::Single(value)),
                present,
                "membership of {}",
                value
            );
        }
    }

    #[test]
    fn test_negative_product() {
        let res = operand()
            .apply_i32(|a, b| Some(a * b), &operand())
            .unwrap();

        check_singles(
            &int_ranges(res),
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, true),
            ],
        );
    }

    #[test]
    fn test_negative_div() {
        let res = operand()
            .apply_i32(|a, b| if b != 0 { Some(a / b) } else { None }, &operand())
            .unwrap();

        check_singles(
            &int_ranges(res),
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, false),
            ],
        );
    }
}