1
use conjure_core::ast::{Expression, ReturnType};
2
use conjure_core::metadata::Metadata;
3
use conjure_core::rule_engine::{
4
    register_rule, register_rule_set, ApplicationError, ApplicationError::*, ApplicationResult,
5
    Reduction,
6
};
7
use uniplate::Uniplate;
8

            
9
use crate::ast::{Atom, Literal, SymbolTable};
10
use crate::{into_matrix_expr, matrix_expr};
11

            
12
use super::utils::is_all_constant;
13

            
14
// Declare the "Bubble" rule set that the rules below register into.
// NOTE(review): the second argument ("Base") presumably lists the rule sets this one depends on —
// confirm against the `register_rule_set!` macro definition.
register_rule_set!("Bubble", ("Base"));
15

            
16
// Bubble reduction rules
17

            
18
/*
19
    Reduce bubbles with a boolean expression to a conjunction with their condition.
20

            
21
    e.g. (a / b = c) @ (b != 0) => (a / b = c) & (b != 0)
22
*/
23
#[register_rule(("Bubble", 8900))]
24
510120
fn expand_bubble(expr: &Expression, _: &SymbolTable) -> ApplicationResult {
25
1476
    match expr {
26
1476
        Expression::Bubble(_, a, b) if a.return_type() == Some(ReturnType::Bool) => {
27
1476
            Ok(Reduction::pure(Expression::And(
28
1476
                Metadata::new(),
29
1476
                Box::new(matrix_expr![*a.clone(), *b.clone()]),
30
1476
            )))
31
        }
32
508644
        _ => Err(ApplicationError::RuleNotApplicable),
33
    }
34
510120
}
35

            
36
/*
37
    Bring bubbles with a non-boolean expression higher up the tree.
38

            
39
    E.g. ((a / b) @ (b != 0)) = c => (a / b = c) @ (b != 0)
40
*/
41
#[register_rule(("Bubble", 8900))]
42
510120
fn bubble_up(expr: &Expression, _: &SymbolTable) -> ApplicationResult {
43
510120
    let mut sub = expr.children();
44
510120
    let mut bubbled_conditions = vec![];
45
510120
    for e in sub.iter_mut() {
46
508212
        if let Expression::Bubble(_, a, b) = e {
47
3294
            if a.return_type() != Some(ReturnType::Bool) {
48
1818
                bubbled_conditions.push(*b.clone());
49
1818
                *e = *a.clone();
50
1818
            }
51
504918
        }
52
    }
53
510120
    if bubbled_conditions.is_empty() {
54
508302
        return Err(ApplicationError::RuleNotApplicable);
55
1818
    }
56
1818
    Ok(Reduction::pure(Expression::Bubble(
57
1818
        Metadata::new(),
58
1818
        Box::new(expr.with_children(sub)),
59
1818
        Box::new(Expression::And(
60
1818
            Metadata::new(),
61
1818
            Box::new(into_matrix_expr![bubbled_conditions]),
62
1818
        )),
63
1818
    )))
64
510120
}
65

            
66
// Bubble applications
67

            
68
/// Converts an unsafe division to a safe division with a bubble condition.
69
///
70
/// ```text
71
///     a / b => (a / b) @ (b != 0)
72
/// ```
73
///
74
/// Division by zero is undefined and therefore not allowed, so we add a condition to check for it.
75
/// This condition is brought up the tree and expanded into a conjunction with the first
76
/// boolean-type expression it is paired with.
77

            
78
#[register_rule(("Bubble", 6000))]
79
316098
fn div_to_bubble(expr: &Expression, _: &SymbolTable) -> ApplicationResult {
80
316098
    if is_all_constant(expr) {
81
51696
        return Err(RuleNotApplicable);
82
264402
    }
83
264402
    if let Expression::UnsafeDiv(_, a, b) = expr {
84
        // bubble bottom up
85
576
        if !a.is_safe() || !b.is_safe() {
86
72
            return Err(RuleNotApplicable);
87
504
        }
88
504

            
89
504
        return Ok(Reduction::pure(Expression::Bubble(
90
504
            Metadata::new(),
91
504
            Box::new(Expression::SafeDiv(Metadata::new(), a.clone(), b.clone())),
92
504
            Box::new(Expression::Neq(
93
504
                Metadata::new(),
94
504
                b.clone(),
95
504
                Box::new(Expression::from(0)),
96
504
            )),
97
504
        )));
98
263826
    }
99
263826
    Err(ApplicationError::RuleNotApplicable)
100
316098
}
101

            
102
/// Converts an unsafe mod to a safe mod with a bubble condition.
103
///
104
/// ```text
105
/// a % b => (a % b) @ (b != 0)
106
/// ```
107
///
108
/// Mod zero is undefined and therefore not allowed, so we add a condition to check for it.
109
/// This condition is brought up the tree and expanded into a conjunction with the first
110
/// boolean-type expression it is paired with.
111
///
112
#[register_rule(("Bubble", 6000))]
113
316098
fn mod_to_bubble(expr: &Expression, _: &SymbolTable) -> ApplicationResult {
114
316098
    if is_all_constant(expr) {
115
51696
        return Err(RuleNotApplicable);
116
264402
    }
117
264402
    if let Expression::UnsafeMod(_, a, b) = expr {
118
        // bubble bottom up
119
414
        if !a.is_safe() || !b.is_safe() {
120
54
            return Err(RuleNotApplicable);
121
360
        }
122
360

            
123
360
        return Ok(Reduction::pure(Expression::Bubble(
124
360
            Metadata::new(),
125
360
            Box::new(Expression::SafeMod(Metadata::new(), a.clone(), b.clone())),
126
360
            Box::new(Expression::Neq(
127
360
                Metadata::new(),
128
360
                b.clone(),
129
360
                Box::new(Expression::from(0)),
130
360
            )),
131
360
        )));
132
263988
    }
133
263988
    Err(ApplicationError::RuleNotApplicable)
134
316098
}
135

            
136
/// Converts an unsafe pow to a safe pow with a bubble condition.
137
///
138
/// ```text
139
/// a**b => (a ** b) @ ((a!=0 \/ b!=0) /\ b>=0
140
/// ```
141
///
142
/// Pow is only defined when `(a!=0 \/ b!=0) /\ b>=0`, so we add a condition to check for it.
143
/// This condition is brought up the tree and expanded into a conjunction with the first
144
/// boolean-type expression it is paired with.
145
///
146
#[register_rule(("Bubble", 6000))]
147
316098
fn pow_to_bubble(expr: &Expression, _: &SymbolTable) -> ApplicationResult {
148
316098
    if is_all_constant(expr) {
149
51696
        return Err(RuleNotApplicable);
150
264402
    }
151
264402
    if let Expression::UnsafePow(_, a, b) = expr.clone() {
152
        // bubble bottom up
153
108
        if !a.is_safe() || !b.is_safe() {
154
18
            return Err(RuleNotApplicable);
155
90
        }
156
90

            
157
90
        return Ok(Reduction::pure(Expression::Bubble(
158
90
            Metadata::new(),
159
90
            Box::new(Expression::SafePow(Metadata::new(), a.clone(), b.clone())),
160
90
            Box::new(Expression::And(
161
90
                Metadata::new(),
162
90
                Box::new(matrix_expr![
163
90
                    Expression::Or(
164
90
                        Metadata::new(),
165
90
                        Box::new(matrix_expr![
166
90
                            Expression::Neq(
167
90
                                Metadata::new(),
168
90
                                a,
169
90
                                Box::new(Atom::Literal(Literal::Int(0)).into()),
170
90
                            ),
171
90
                            Expression::Neq(
172
90
                                Metadata::new(),
173
90
                                b.clone(),
174
90
                                Box::new(Atom::Literal(Literal::Int(0)).into()),
175
90
                            ),
176
90
                        ]),
177
90
                    ),
178
90
                    Expression::Geq(
179
90
                        Metadata::new(),
180
90
                        b,
181
90
                        Box::new(Atom::Literal(Literal::Int(0)).into()),
182
90
                    ),
183
90
                ]),
184
90
            )),
185
90
        )));
186
264294
    }
187
264294
    Err(ApplicationError::RuleNotApplicable)
188
316098
}