1
use conjure_cp::ast::{Atom, DomainPtr, GroundDomain, Metadata, Range, eval_constant};
2
use conjure_cp::ast::{Expression, Moo, SymbolTable};
3
use conjure_cp::rule_engine::{
4
    ApplicationError, ApplicationError::RuleNotApplicable, ApplicationResult, Reduction,
5
    register_rule,
6
};
7
use conjure_cp::{bug, into_matrix_expr};
8
use itertools::{Itertools as _, izip};
9

            
10
/// Outcome of trying to prove, without solving, whether an index expression
/// lies inside or outside a target domain.
#[derive(Debug, PartialEq, Eq)]
enum MembershipProof {
    /// The index is provably inside the target domain.
    AlwaysIn,
    /// The index is provably outside the target domain.
    AlwaysOut,
    /// Neither of the above could be established.
    Unknown,
}
17

            
18
/// Tries to prove whether `index` is always in, always out, or unknown for `domain`.
19
2466
fn expression_membership_proof(
20
2466
    domain: &GroundDomain,
21
2466
    index: &Expression,
22
2466
) -> Result<MembershipProof, ApplicationError> {
23
2466
    let Some(index_domain) = index.domain_of().and_then(|domain| domain.resolve()) else {
24
6
        return Ok(MembershipProof::Unknown);
25
    };
26

            
27
2460
    let Ok(intersection) = index_domain.as_ref().intersect(domain) else {
28
26
        return Ok(MembershipProof::Unknown);
29
    };
30

            
31
2434
    if normalise_int_domain(&intersection) == normalise_int_domain(index_domain.as_ref()) {
32
2374
        return Ok(MembershipProof::AlwaysIn);
33
60
    }
34

            
35
60
    if let Ok(values) = intersection.values_i32()
36
60
        && values.is_empty()
37
    {
38
        return Ok(MembershipProof::AlwaysOut);
39
60
    }
40

            
41
60
    Ok(MembershipProof::Unknown)
42
2466
}
43

            
44
/// Normalises integer domains so equivalent range partitions compare equal.
45
4868
fn normalise_int_domain(domain: &GroundDomain) -> GroundDomain {
46
4868
    match domain {
47
4868
        GroundDomain::Int(ranges) => GroundDomain::Int(Range::squeeze(
48
4868
            &ranges
49
4868
                .iter()
50
4868
                .map(|range| Range::new(range.low().copied(), range.high().copied()))
51
4868
                .collect_vec(),
52
        )),
53
        _ => domain.clone(),
54
    }
55
4868
}
56

            
57
/// Builds the bubble condition needed to make an unsafe index operation safe.
58
5010
fn index_bubble_condition(
59
5010
    index_domains: &[Moo<GroundDomain>],
60
5010
    indices: &[Expression],
61
5010
) -> Result<Option<Expression>, ApplicationError> {
62
5010
    let mut bubble_constraints = vec![];
63

            
64
6896
    for (domain, index) in izip!(index_domains, indices) {
65
6896
        match eval_constant(index) {
66
4430
            Some(lit) => match domain
67
4430
                .contains(&lit)
68
4430
                .map_err(|_| ApplicationError::DomainError)?
69
            {
70
4346
                true => {}
71
                false => {
72
84
                    return Ok(Some(Expression::Atomic(Metadata::new(), Atom::from(false))));
73
                }
74
            },
75
2466
            None => match expression_membership_proof(domain.as_ref(), index)? {
76
2374
                MembershipProof::AlwaysIn => {}
77
                MembershipProof::AlwaysOut => {
78
                    return Ok(Some(Expression::Atomic(Metadata::new(), Atom::from(false))));
79
                }
80
92
                MembershipProof::Unknown => bubble_constraints.push(Expression::InDomain(
81
92
                    Metadata::new(),
82
92
                    Moo::new(index.clone()),
83
92
                    DomainPtr::from(domain.clone()),
84
92
                )),
85
            },
86
        }
87
    }
88

            
89
4926
    match bubble_constraints.len() {
90
4842
        0 => Ok(None),
91
76
        1 => Ok(Some(
92
76
            bubble_constraints.pop().expect("length checked above"),
93
76
        )),
94
8
        _ => Ok(Some(Expression::And(
95
8
            Metadata::new(),
96
8
            Moo::new(into_matrix_expr![bubble_constraints]),
97
8
        ))),
98
    }
99
5010
}
100

            
101
/// Converts an unsafe index to a safe index using a bubble expression.
102
#[register_rule("Bubble", 6000, [UnsafeIndex])]
103
1277528
fn index_to_bubble(expr: &Expression, _: &SymbolTable) -> ApplicationResult {
104
1277528
    let Expression::UnsafeIndex(_, subject, indices) = expr else {
105
1273070
        return Err(RuleNotApplicable);
106
    };
107

            
108
4458
    let domain = subject
109
4458
        .domain_of()
110
4458
        .ok_or(ApplicationError::DomainError)?
111
4458
        .resolve()
112
4458
        .ok_or(RuleNotApplicable)?;
113

            
114
    // TODO: tuple, this is a hack right now just to avoid the rule being applied to tuples, but could we safely modify the rule to
115
    // handle tuples as well?
116
4458
    if matches!(domain.as_ref(), GroundDomain::Tuple(_))
117
4458
        || matches!(domain.as_ref(), GroundDomain::Record(_))
118
    {
119
        return Err(RuleNotApplicable);
120
4458
    }
121

            
122
4458
    let GroundDomain::Matrix(_, index_domains) = domain.as_ref() else {
123
        bug!(
124
            "subject of an index expression should have a matrix domain. subject: {:?}, with domain: {:?}",
125
            subject,
126
            domain.as_ref()
127
        );
128
    };
129

            
130
4458
    assert_eq!(
131
4458
        index_domains.len(),
132
4458
        indices.len(),
133
        "in an index expression, there should be the same number of indices as the subject has index domains"
134
    );
135

            
136
4458
    let new_expr = Moo::new(Expression::SafeIndex(
137
4458
        Metadata::new(),
138
4458
        subject.clone(),
139
4458
        indices.clone(),
140
4458
    ));
141

            
142
4458
    match index_bubble_condition(index_domains, indices)? {
143
4302
        None => Ok(Reduction::pure(Moo::unwrap_or_clone(new_expr))),
144
156
        Some(condition) => Ok(Reduction::pure(Expression::Bubble(
145
156
            Metadata::new(),
146
156
            new_expr,
147
156
            Moo::new(condition),
148
156
        ))),
149
    }
150
1277528
}
151

            
152
/// Converts an unsafe slice to a safe slice using a bubble expression.
153
#[register_rule("Bubble", 6000, [UnsafeSlice])]
154
1277528
fn slice_to_bubble(expr: &Expression, _: &SymbolTable) -> ApplicationResult {
155
1277528
    let Expression::UnsafeSlice(_, subject, indices) = expr else {
156
1276976
        return Err(RuleNotApplicable);
157
    };
158

            
159
552
    let domain = subject
160
552
        .domain_of()
161
552
        .ok_or(ApplicationError::DomainError)?
162
552
        .resolve()
163
552
        .ok_or(RuleNotApplicable)?;
164

            
165
552
    let GroundDomain::Matrix(_, index_domains) = domain.as_ref() else {
166
        bug!(
167
            "subject of a slice expression should have a matrix domain. subject: {:?}, with domain: {:?}",
168
            subject,
169
            domain
170
        );
171
    };
172

            
173
552
    assert_eq!(
174
552
        index_domains.len(),
175
552
        indices.len(),
176
        "in a slice expression, there should be the same number of indices as the subject has index domains"
177
    );
178

            
179
552
    let constrained_index_domains = izip!(index_domains, indices)
180
1080
        .filter_map(|(domain, index)| index.clone().map(|index| (domain.clone(), index)))
181
552
        .collect_vec();
182
552
    let (filtered_index_domains, filtered_indices): (Vec<_>, Vec<_>) =
183
552
        constrained_index_domains.into_iter().unzip();
184

            
185
552
    let new_expr = Moo::new(Expression::SafeSlice(
186
552
        Metadata::new(),
187
552
        subject.clone(),
188
552
        indices.clone(),
189
552
    ));
190

            
191
552
    match index_bubble_condition(&filtered_index_domains, &filtered_indices)? {
192
540
        None => Ok(Reduction::pure(Moo::unwrap_or_clone(new_expr))),
193
12
        Some(condition) => Ok(Reduction::pure(Expression::Bubble(
194
12
            Metadata::new(),
195
12
            new_expr,
196
12
            Moo::new(condition),
197
12
        ))),
198
    }
199
1277528
}