1
use std::env;
2

            
3
use itertools::Itertools;
4
use uniplate::Uniplate;
5

            
6
use crate::ast::{Expression, ReturnType};
7
use crate::bug;
8
use crate::rule_engine::{get_rules, Reduction, RuleSet};
9
use crate::stats::RewriterStats;
10
use crate::Model;
11

            
12
use super::resolve_rules::RuleData;
13
use super::rewriter_common::{log_rule_application, RewriteError, RuleResult};
14

            
15
/// Reports whether rewriter optimizations are switched on.
///
/// Optimizations are on only when the `OPTIMIZATIONS` environment variable holds exactly `"1"`.
/// Any other value turns them off, as does the variable being unset or not valid Unicode
/// (both of which make `env::var` return an error).
fn optimizations_enabled() -> bool {
    matches!(env::var("OPTIMIZATIONS").as_deref(), Ok("1"))
}
22

            
23
/// Rewrites the given model by applying a set of rules to all its constraints, until no more rules can be applied.
24
///
25
/// Rules are applied in order of priority (from highest to lowest)
26
/// Rules can:
27
/// - Apply transformations to the constraints in the model (or their sub-expressions)
28
/// - Add new constraints to the model
29
/// - Modify the symbol table (e.g. add new variables)
30
///
31
/// # Parameters
32
/// - `model`: A reference to the [`Model`] to be rewritten.
33
/// - `rule_sets`: A vector of references to [`RuleSet`]s to be applied.
34
///
35
/// Each `RuleSet` contains a map of rules to priorities.
36
///
37
/// # Returns
38
/// - `Ok(Model)`: If successful, it returns a modified copy of the [`Model`]
39
/// - `Err(RewriteError)`: If an error occurs during rule application (e.g., invalid rules)
40
///
41
/// # Side-Effects
42
/// - Rules can apply side-effects to the model (e.g. adding new constraints or variables).
43
///   The original model is cloned and a modified copy is returned.
44
/// - Rule engine statistics (e.g. number of rule applications, run time) are collected and stored in the new model's context.
45
///
46
/// # Example
47
/// - Using `rewrite_model` with the constraint `(a + min(x, y)) = b`
48
///
49
///   Original model:
50
///   ```text
51
///   model: {
52
///     constraints: [(a + min(x, y) + 42 - 10) = b],
53
///     symbols: [a, b, x, y]
54
///   }
55
///   rule_sets: [{
56
///       name: "MyRuleSet",
57
///       rules: [
58
///         min_to_var: 10,
59
///         const_eval: 20
60
///       ]
61
///     }]
62
///   ```
63
///
64
///   Rules:
65
///   - `min_to_var`: min([a, b]) ~> c ; c <= a & c <= b & (c = a \/ c = b)
66
///   - `const_eval`: c1 + c2 ~> (c1 + c2) ; c1, c2 are constants
67
///
68
///   Result:
69
///   ```text
70
///   model: {
71
///     constraints: [
72
///       (a + aux + 32) = b,
73
///       aux <= x,
74
///       aux <= y,
75
///       aux = x \/ aux = y
76
///     ],
77
///     symbols: [a, b, x, y, aux]
78
///   }
79
///   ```
80
///
81
///   Process:
82
///   1. We traverse the expression tree until a rule can be applied.
83
///   2. If multiple rules can be applied to the same expression, the higher priority one goes first.
84
///      In this case, `const_eval` is applied before `min_to_var`.
85
///   3. The rule `min_to_var` adds a new variable `aux` and new constraints to the model.
86
///   4. When no more rules can be applied, the resulting model is returned.
87
///
88
///   Details for this process can be found in [`rewrite_iteration`] documentation.
89
///
90
/// # Performance Considerations
91
/// - We recursively traverse the tree multiple times to check if any rules can be applied.
92
/// - Expressions are cloned on each rule application
93
///
94
/// This can be expensive for large models
95
///
96
/// # Panics
97
/// - This function may panic if the model's context is unavailable or if there is an issue with locking the context.
98
///
99
/// # See Also
100
/// - [`get_rules`]: Resolves the rules from the provided rule sets and sorts them by priority.
101
/// - [`rewrite_iteration`]: Executes a single iteration of rewriting the model using the specified rules.
102
17
pub fn rewrite_model<'a>(
103
17
    model: &Model,
104
17
    rule_sets: &Vec<&'a RuleSet<'a>>,
105
17
) -> Result<Model, RewriteError> {
106
17
    let rules = get_rules(rule_sets)?.into_iter().collect();
107
17
    let mut new_model = model.clone();
108
17
    let mut stats = RewriterStats {
109
17
        is_optimization_enabled: Some(optimizations_enabled()),
110
17
        rewriter_run_time: None,
111
17
        rewriter_rule_application_attempts: Some(0),
112
17
        rewriter_rule_applications: Some(0),
113
17
    };
114
17

            
115
17
    // Check if optimizations are enabled
116
17
    let apply_optimizations = optimizations_enabled();
117
17

            
118
17
    let start = std::time::Instant::now();
119
17

            
120
17
    //the while loop is exited when None is returned implying the sub-expression is clean
121
17
    let mut i: usize = 0;
122
34
    while i < new_model.get_constraints_vec().len() {
123
17
        while let Some(step) = rewrite_iteration(
124
17
            &new_model.get_constraints_vec()[i],
125
17
            &new_model,
126
17
            &rules,
127
17
            apply_optimizations,
128
17
            &mut stats,
129
17
        ) {
130
            debug_assert!(is_vec_bool(&step.new_top)); // All new_top expressions should be boolean
131
            new_model.get_constraints_vec()[i] = step.new_expression.clone();
132
            step.apply(&mut new_model); // Apply side-effects (e.g., symbol table updates)
133
        }
134

            
135
        // If new constraints are added, continue processing them in the next iterations.
136
17
        i += 1;
137
    }
138

            
139
17
    stats.rewriter_run_time = Some(start.elapsed());
140
17
    model.context.write().unwrap().stats.add_rewriter_run(stats);
141
17
    Ok(new_model)
142
17
}
143

            
144
/// Checks if all expressions in `Vec<Expr>` are booleans.
145
/// All top-level constraints in a model should be boolean expressions.
146
fn is_vec_bool(exprs: &[Expression]) -> bool {
147
    exprs
148
        .iter()
149
        .all(|expr| expr.return_type() == Some(ReturnType::Bool))
150
}
151

            
152
/// Attempts to apply a set of rules to the given expression and its sub-expressions in the model.
153
///
154
/// 1. Checks if the expression is "clean" (all possible rules have been applied).
155
/// 2. Tries to apply rules to the top-level expression, in oprder of priority.
156
/// 3. If no rules can be applied to the top-level expression, recurses into its sub-expressions.
157
///
158
/// When a successful rule application is found, immediately returns a `Reduction` and stops.
159
/// The `Reduction` contains the new expression and any side-effects (e.g., new constraints, variables).
160
/// If no rule applications are possible in this expression tree, returns `None`.
161
///
162
/// # Parameters
163
/// - `expression`: The [`Expression`] to be rewritten.
164
/// - `model`: The root [`Model`] for access to the context and symbol table.
165
/// - `rules`: A max-heap of [`RuleData`] containing rules, priorities, and metadata. Ordered by rule priority.
166
/// - `apply_optimizations`: If `true`, skip already "clean" expressions to avoid redundant work.
167
/// - `stats`: A mutable reference to [`RewriterStats`] to collect statistics
168
///
169
/// # Returns
170
/// - `Some(<Reduction>)`: If a rule is successfully applied to the expression or any of its sub-expressions.
171
///                        Contains the new expression and any side-effects to apply to the model.
172
/// - `None`: If no rule is applicable to the expression or any of its sub-expressions.
173
///
174
/// # Example
175
///
176
/// - Rewriting the expression `a + min(x, y)`:
177
///
178
///   Input:
179
///   ```text
180
///   expression: a + min(x, y)
181
///   rules: [min_to_var]
182
///   model: {
183
///     constraints: [(a + min(x, y)) = b],
184
///     symbols: [a, b, x, y]
185
///   }
186
///   apply_optimizations: true
187
///   ```
188
///
189
///   Process:
190
///   1. Initially, the expression is dirty, so we proceed with the rewrite.
191
///   2. No rules can be applied to the top-level expression `a + min(x, y)`.
192
///      Try its children: `a` and `min(x, y)`.
193
///   3. No rules can be applied to `a`. Mark it as clean and return None.
194
///   4. The rule `min_to_var` can be applied to `min(x, y)`. Return the `Reduction`.
195
///      ```text
196
///      Reduction {
197
///        new_expression: aux,
198
///        new_top: [aux <= x, aux <= y, aux = x \/ aux = y],
199
///        symbols: [a, b, x, y, aux]
200
///      }
201
///      ```
202
///   5. Update the parent expression `a + min(x, y)` with the new child `a + aux`.
203
///      Add new constraints and variables to the model.
204
///   6. No more rules can be applied to this expression. Mark it as clean and return a pure `Reduction`.
205
51
fn rewrite_iteration(
206
51
    expression: &Expression,
207
51
    model: &Model,
208
51
    rules: &Vec<RuleData<'_>>,
209
51
    apply_optimizations: bool,
210
51
    stats: &mut RewriterStats,
211
51
) -> Option<Reduction> {
212
51
    if apply_optimizations && expression.is_clean() {
213
        // Skip processing this expression if it's clean
214
        return None;
215
51
    }
216
51

            
217
51
    // Mark the expression as clean - will be marked dirty if any rule is applied
218
51
    let mut expression = expression.clone();
219
51

            
220
51
    let rule_results = apply_all_rules(&expression, model, rules, stats);
221
51
    if let Some(result) = choose_rewrite(&rule_results, &expression) {
222
        // If a rule is applied, mark the expression as dirty
223
        log_rule_application(&result, &expression, model);
224
        return Some(result.reduction);
225
51
    }
226
51

            
227
51
    let mut sub = expression.children();
228
51
    for i in 0..sub.len() {
229
34
        if let Some(red) = rewrite_iteration(&sub[i], model, rules, apply_optimizations, stats) {
230
            sub[i] = red.new_expression;
231
            let res = expression.with_children(sub.clone());
232
            return Some(Reduction::new(res, red.new_top, red.symbols));
233
34
        }
234
    }
235
    // If all children are clean, mark this expression as clean
236
51
    if apply_optimizations {
237
        assert!(expression.children().iter().all(|c| c.is_clean()));
238
        expression.set_clean(true);
239
        return Some(Reduction::pure(expression));
240
51
    }
241
51
    None
242
51
}
243

            
244
/// Tries to apply rules to an expression and returns a list of successful applications.
245
///
246
/// The expression or model is NOT modified directly.
247
/// We create a list of `RuleResult`s containing the reductions and pass it to `choose_rewrite` to select one to apply.
248
///
249
/// # Parameters
250
/// - `expression`: A reference to the [`Expression`] to evaluate.
251
/// - `model`: A reference to the [`Model`] for access to the symbol table and context.
252
/// - `rules`: A vector of references to [`Rule`]s to try.
253
/// - `stats`: A mutable reference to [`RewriterStats`] used to track the number of rule applications and other statistics.
254
///
255
/// # Returns
256
/// - A `Vec<RuleResult>` containing all successful rule applications to the expression.
257
///   Each `RuleResult` contains the rule that was applied and the resulting `Reduction`.
258
///
259
/// # Side-Effects
260
/// - The function updates the provided `stats` with the number of rule application attempts and successful applications.
261
/// - Debug or trace logging may be performed to track which rules were applicable or not for a given expression.
262
///
263
/// # Example
264
/// let applicable_rules = apply_all_rules(&expr, &model, &rules, &mut stats);
265
/// if !applicable_rules.is_empty() {
266
///     for result in applicable_rules {
267
///         println!("Rule applied: {:?}", result.rule_data.rule);
268
///     }
269
/// }
270
///
271
/// ## Note
272
/// - Rules are applied only to the given expression, not its children.
273
///
274
/// # See Also
275
/// - [`choose_rewrite`]: Chooses a single reduction from the rule results provided by `apply_all_rules`.
276
51
fn apply_all_rules<'a>(
277
51
    expression: &Expression,
278
51
    model: &Model,
279
51
    rules: &Vec<RuleData<'a>>,
280
51
    stats: &mut RewriterStats,
281
51
) -> Vec<RuleResult<'a>> {
282
51
    let mut results = Vec::new();
283
2601
    for rule_data in rules {
284
2550
        match rule_data.rule.apply(expression, &model.symbols()) {
285
            Ok(red) => {
286
                stats.rewriter_rule_application_attempts =
287
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
288
                stats.rewriter_rule_applications =
289
                    Some(stats.rewriter_rule_applications.unwrap() + 1);
290
                // Assert no clean children
291
                // assert!(!red.new_expression.children().iter().any(|c| c.is_clean()), "Rule that caused assertion to fail: {:?}", rule.name);
292
                // assert!(!red.new_expression.children().iter().any(|c| c.children().iter().any(|c| c.is_clean())));
293
                results.push(RuleResult {
294
                    rule_data: rule_data.clone(),
295
                    reduction: red,
296
                });
297
            }
298
            Err(_) => {
299
2550
                log::trace!(
300
                    "Rule attempted but not applied: {}, to expression: {} ({:?})",
301
                    rule_data.rule,
302
                    expression,
303
                    rule_data
304
                );
305
2550
                stats.rewriter_rule_application_attempts =
306
2550
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
307
2550
                continue;
308
            }
309
        }
310
    }
311
51
    results
312
51
}
313

            
314
/// Chooses the first applicable rule result from a list of rule applications.
315
///
316
/// Currently, applies the rule with the highest priority.
317
/// If multiple rules have the same priority, logs an error message and panics.
318
///
319
/// # Parameters
320
/// - `results`: A slice of [`RuleResult`]s to consider.
321
/// -  `initial_expression`: [`Expression`] before the rule application.
322
///
323
/// # Returns
324
/// - `Some(<Reduction>)`: If there is at least one successful rule application, returns a [`Reduction`] to apply.
325
/// - `None`: If there are no successful rule applications (i.e. `results` is empty).
326
///
327
/// # Example
328
///
329
/// let rule_results = vec![rule1_result, rule2_result];
330
/// if let Some(reduction) = choose_rewrite(&rule_results) {
331
///   // Process the chosen reduction
332
/// }
333
///
334
51
fn choose_rewrite<'a>(
335
51
    results: &[RuleResult<'a>],
336
51
    initial_expression: &Expression,
337
51
) -> Option<RuleResult<'a>> {
338
51
    //in the case where multiple rules are applicable
339
51
    if !results.is_empty() {
340
        let mut rewrite_options: Vec<RuleResult> = Vec::new();
341
        for (priority, group) in &results.iter().chunk_by(|result| result.rule_data.priority) {
342
            let options: Vec<&RuleResult> = group.collect();
343
            if options.len() > 1 {
344
                // Multiple rules with the same priority
345
                let mut message = format!(
346
                    "Found multiple rules of the same priority {} applicable to expression: {}\n",
347
                    priority, initial_expression
348
                );
349
                for option in options {
350
                    message.push_str(&format!(
351
                        "- Rule: {} (from {})\n",
352
                        option.rule_data.rule.name, option.rule_data.rule_set.name
353
                    ));
354
                }
355
                bug!("{}", message);
356
            } else {
357
                // Only one rule with this priority, add it to the list
358
                rewrite_options.push(options[0].clone());
359
            }
360
        }
361

            
362
        if rewrite_options.len() > 1 {
363
            // Keep old behaviour: log a message and apply the highest priority rule
364
            let mut message = format!(
365
                "Found multiple rules of different priorities applicable to expression: {}\n",
366
                initial_expression
367
            );
368
            for option in &rewrite_options {
369
                message.push_str(&format!(
370
                    "- Rule: {} (priority {}, from {})\n",
371
                    option.rule_data.rule.name,
372
                    option.rule_data.priority,
373
                    option.rule_data.rule_set.name
374
                ));
375
            }
376
            log::warn!("{}", message);
377
        }
378

            
379
        return Some(rewrite_options[0].clone());
380
51
    }
381
51

            
382
51
    None
383
51
}