conjure_core/ast/
domains.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use serde::{Deserialize, Serialize};
// use std::iter::Ste

/// One piece of a domain over an ordered value type `A`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Range<A: Ord> {
    /// Exactly one value.
    Single(A),
    /// A pair of bounds. `Domain::values_i32` expands `Bounded(lo, hi)` as the
    /// inclusive range `lo..=hi`.
    Bounded(A, A),
}

/// A domain of values: either the booleans or a collection of `i32` ranges.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Domain {
    /// The boolean domain. It carries no data, and `values_i32` returns `None` for it.
    BoolDomain,
    /// An integer domain described by a list of ranges.
    ///
    /// Nothing in this type requires the ranges to be sorted or disjoint.
    IntDomain(Vec<Range<i32>>),
}

impl Domain {
    /// Returns every `i32` value described by this domain, or `None` if the domain is not an
    /// `IntDomain`.
    ///
    /// Values are listed range by range, in the order the ranges are stored. Overlapping
    /// ranges are not merged, so a value may appear more than once. A `Range::Bounded(lo, hi)`
    /// with `lo > hi` contributes no values.
    pub fn values_i32(&self) -> Option<Vec<i32>> {
        match self {
            Domain::IntDomain(ranges) => Some(
                ranges
                    .iter()
                    // A single value is the one-element range `i..=i`, so both arms yield
                    // the same iterator type and no temporary `Vec` is allocated per range.
                    .flat_map(|r| match r {
                        Range::Single(i) => *i..=*i,
                        Range::Bounded(lo, hi) => *lo..=*hi,
                    })
                    .collect(),
            ),
            Domain::BoolDomain => None,
        }
    }

    /// Returns the domain obtained by applying a binary `i32` operation to every pair of
    /// values drawn from `self` and `other`.
    ///
    /// The given operator may return `None` if the operation is not defined for its arguments.
    /// Undefined values will not be included in the resulting domain.
    ///
    /// The result is an `IntDomain` made only of `Range::Single` entries, one per distinct
    /// result value, in ascending order. Adjacent values are not merged into bounded ranges.
    ///
    /// Returns `None` if either domain is not valid for `i32` operations.
    pub fn apply_i32(&self, op: fn(i32, i32) -> Option<i32>, other: &Domain) -> Option<Domain> {
        let lhs = self.values_i32()?;
        let rhs = other.values_i32()?;

        // TODO: (flm8) Optimise to use smarter, less brute-force methods
        // A set keeps each result once; without it the output would hold one entry per
        // operand pair and grow multiplicatively when operations are chained.
        let mut results = std::collections::BTreeSet::new();
        for &a in &lhs {
            for &b in &rhs {
                if let Some(v) = op(a, b) {
                    results.insert(v);
                }
            }
        }

        Some(Domain::IntDomain(
            results.into_iter().map(Range::Single).collect(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the operand shared by both tests: an integer domain with the single
    /// range `Bounded(-2, 1)`.
    fn operand() -> Domain {
        Domain::IntDomain(vec![Range::Bounded(-2, 1)])
    }

    /// Asserts that `result` is an `IntDomain`, then checks each `(value, expected)` pair in
    /// order: `Range::Single(value)` must be present when `expected` is `true` and absent
    /// when it is `false`.
    fn check_singles(result: Domain, expectations: &[(i32, bool)]) {
        assert!(matches!(result, Domain::IntDomain(_)));
        if let Domain::IntDomain(ranges) = result {
            for &(value, expected) in expectations {
                assert!(ranges.contains(&Range::Single(value)) == expected);
            }
        }
    }

    #[test]
    fn test_negative_product() {
        let lhs = operand();
        let rhs = operand();
        let product = lhs.apply_i32(|a, b| Some(a * b), &rhs).unwrap();

        check_singles(
            product,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, true),
            ],
        );
    }

    #[test]
    fn test_negative_div() {
        let lhs = operand();
        let rhs = operand();
        // Division by zero is reported as undefined so it is left out of the result.
        let quotient = lhs
            .apply_i32(
                |a, b| match b {
                    0 => None,
                    _ => Some(a / b),
                },
                &rhs,
            )
            .unwrap();

        check_singles(
            quotient,
            &[
                (-4, false),
                (-3, false),
                (-2, true),
                (-1, true),
                (0, true),
                (1, true),
                (2, true),
                (3, false),
                (4, false),
            ],
        );
    }
}