1
use std::env;
2
use std::fmt::Display;
3

            
4
use thiserror::Error;
5

            
6
use crate::stats::RewriterStats;
7
use uniplate::Uniplate;
8

            
9
use crate::rule_engine::{Reduction, Rule, RuleSet};
10
use crate::{
11
    ast::Expression,
12
    rule_engine::resolve_rules::{
13
        get_rule_priorities, get_rules_vec, ResolveRulesError as ResolveError,
14
    },
15
    Model,
16
};
17

            
18
#[derive(Debug)]
19
struct RuleResult<'a> {
20
    rule: &'a Rule<'a>,
21
    reduction: Reduction,
22
}
23

            
24
#[derive(Debug, Error)]
25
pub enum RewriteError {
26
    ResolveRulesError(ResolveError),
27
}
28

            
29
impl Display for RewriteError {
30
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31
        match self {
32
            RewriteError::ResolveRulesError(e) => write!(f, "Error resolving rules: {}", e),
33
        }
34
    }
35
}
36

            
37
impl From<ResolveError> for RewriteError {
38
    fn from(error: ResolveError) -> Self {
39
        RewriteError::ResolveRulesError(error)
40
    }
41
}
42

            
43
/// Checks if the OPTIMIZATIONS environment variable is set to "1".
///
/// # Returns
/// - true if the environment variable is set to exactly "1".
/// - false if the environment variable cannot be read (e.g. it is not set) or holds any other value.
fn optimizations_enabled() -> bool {
    // Any lookup error is treated the same as "disabled".
    env::var("OPTIMIZATIONS").map_or(false, |value| value == "1")
}
54

            
55
/// Rewrites the model by applying the rules to all constraints.
56
///
57
/// Any side-effects such as symbol table updates and top-level constraints are applied to the returned model.
58
///
59
/// # Returns
60
/// A copy of the model after all, if any, possible rules are applied to its constraints.
61
pub fn rewrite_model<'a>(
62
    model: &Model,
63
    rule_sets: &Vec<&'a RuleSet<'a>>,
64
) -> Result<Model, RewriteError> {
65
    let rule_priorities = get_rule_priorities(rule_sets)?;
66
    let rules = get_rules_vec(&rule_priorities);
67
    let mut new_model = model.clone();
68
    let mut stats = RewriterStats {
69
        is_optimization_enabled: Some(optimizations_enabled()),
70
        rewriter_run_time: None,
71
        rewriter_rule_application_attempts: Some(0),
72
        rewriter_rule_applications: Some(0),
73
    };
74

            
75
    // Check if optimizations are enabled
76
    let apply_optimizations = optimizations_enabled();
77

            
78
    let start = std::time::Instant::now();
79

            
80
    while let Some(step) = rewrite_iteration(
81
        &new_model.constraints,
82
        &new_model,
83
        &rules,
84
        apply_optimizations,
85
        &mut stats,
86
    ) {
87
        step.apply(&mut new_model); // Apply side-effects (e.g. symbol table updates)
88
    }
89
    stats.rewriter_run_time = Some(start.elapsed());
90
    model.context.write().unwrap().stats.add_rewriter_run(stats);
91
    Ok(new_model)
92
}
93

            
94
/// # Returns
95
/// - Some(<new_expression>) after applying the first applicable rule to `expr` or a sub-expression.
96
/// - None if no rule is applicable to the expression or any sub-expression.
97
fn rewrite_iteration<'a>(
98
    expression: &'a Expression,
99
    model: &'a Model,
100
    rules: &'a Vec<&'a Rule<'a>>,
101
    apply_optimizations: bool,
102
    stats: &mut RewriterStats,
103
) -> Option<Reduction> {
104
    if apply_optimizations && expression.is_clean() {
105
        // Skip processing this expression if it's clean
106
        return None;
107
    }
108

            
109
    // Mark the expression as clean - will be marked dirty if any rule is applied
110
    let mut expression = expression.clone();
111

            
112
    let rule_results = apply_all_rules(&expression, model, rules, stats);
113
    if let Some(new) = choose_rewrite(&rule_results) {
114
        // If a rule is applied, mark the expression as dirty
115
        return Some(new);
116
    }
117

            
118
    let mut sub = expression.children();
119
    for i in 0..sub.len() {
120
        if let Some(red) = rewrite_iteration(&sub[i], model, rules, apply_optimizations, stats) {
121
            sub[i] = red.new_expression;
122
            let res = expression.with_children(sub.clone());
123
            return Some(Reduction::new(res, red.new_top, red.symbols));
124
        }
125
    }
126
    // If all children are clean, mark this expression as clean
127
    if apply_optimizations {
128
        assert!(expression.children().iter().all(|c| c.is_clean()));
129
        expression.set_clean(true);
130
        return Some(Reduction::pure(expression));
131
    }
132
    None
133
}
134

            
135
/// # Returns
136
/// - A list of RuleResults after applying all rules to `expression`.
137
/// - An empty list if no rules are applicable.
138
fn apply_all_rules<'a>(
139
    expression: &'a Expression,
140
    model: &'a Model,
141
    rules: &'a Vec<&'a Rule<'a>>,
142
    stats: &mut RewriterStats,
143
) -> Vec<RuleResult<'a>> {
144
    let mut results = Vec::new();
145
    for rule in rules {
146
        match rule.apply(expression, model) {
147
            Ok(red) => {
148
                log::trace!(target: "file", "Rule applicable: {:?}, to Expression: {:?}, resulting in: {:?}", rule, expression, red.new_expression);
149
                stats.rewriter_rule_application_attempts =
150
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
151
                stats.rewriter_rule_applications =
152
                    Some(stats.rewriter_rule_applications.unwrap() + 1);
153
                // Assert no clean children
154
                // assert!(!red.new_expression.children().iter().any(|c| c.is_clean()), "Rule that caused assertion to fail: {:?}", rule.name);
155
                // assert!(!red.new_expression.children().iter().any(|c| c.children().iter().any(|c| c.is_clean())));
156
                results.push(RuleResult {
157
                    rule,
158
                    reduction: red,
159
                });
160
            }
161
            Err(_) => {
162
                log::trace!(target: "file", "Rule attempted but not applied: {:?}, to Expression: {:?}", rule, expression);
163
                stats.rewriter_rule_application_attempts =
164
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
165
                continue;
166
            }
167
        }
168
    }
169
    results
170
}
171

            
172
/// # Returns
173
/// - Some(<reduction>) after applying the first rule in `results`.
174
/// - None if `results` is empty.
175
fn choose_rewrite(results: &[RuleResult]) -> Option<Reduction> {
176
    if results.is_empty() {
177
        return None;
178
    }
179
    // Return the first result for now
180
    Some(results[0].reduction.clone())
181
}