1
use std::env;
2

            
3
use itertools::Itertools;
4
use uniplate::Uniplate;
5

            
6
use crate::ast::{Expression, ReturnType};
7
use crate::bug;
8
use crate::rule_engine::{get_rules, Reduction, RuleSet};
9
use crate::stats::RewriterStats;
10
use crate::Model;
11

            
12
use super::resolve_rules::RuleData;
13
use super::rewriter_common::{log_rule_application, RewriteError, RuleResult};
14

            
15
/// Reports whether rewriter optimizations were requested through the environment.
///
/// Returns `true` only when the `OPTIMIZATIONS` environment variable is set to exactly `"1"`.
/// Any other value, or a failure to read the variable (e.g. it is not set), yields `false`.
fn optimizations_enabled() -> bool {
    matches!(env::var("OPTIMIZATIONS").as_deref(), Ok("1"))
}
22

            
23
/// Rewrites the given model by applying a set of rules to all its constraints, until no more rules can be applied.
24
///
25
/// Rules are applied in order of priority (from highest to lowest)
26
/// Rules can:
27
/// - Apply transformations to the constraints in the model (or their sub-expressions)
28
/// - Add new constraints to the model
29
/// - Modify the symbol table (e.g. add new variables)
30
///
31
/// # Parameters
32
/// - `model`: A reference to the [`Model`] to be rewritten.
33
/// - `rule_sets`: A vector of references to [`RuleSet`]s to be applied.
34
///
35
/// Each `RuleSet` contains a map of rules to priorities.
36
///
37
/// # Returns
38
/// - `Ok(Model)`: If successful, it returns a modified copy of the [`Model`]
39
/// - `Err(RewriteError)`: If an error occurs during rule application (e.g., invalid rules)
40
///
41
/// # Side-Effects
42
/// - Rules can apply side-effects to the model (e.g. adding new constraints or variables).
43
///   The original model is cloned and a modified copy is returned.
44
/// - Rule engine statistics (e.g. number of rule applications, run time) are collected and stored in the new model's context.
45
///
46
/// # Example
47
/// - Using `rewrite_model` with the constraint `(a + min(x, y)) = b`
48
///
49
///   Original model:
50
///   ```text
51
///   model: {
52
///     constraints: [(a + min(x, y) + 42 - 10) = b],
53
///     symbols: [a, b, x, y]
54
///   }
55
///   rule_sets: [{
56
///       name: "MyRuleSet",
57
///       rules: [
58
///         min_to_var: 10,
59
///         const_eval: 20
60
///       ]
61
///     }]
62
///   ```
63
///
64
///   Rules:
65
///   - `min_to_var`: min([a, b]) ~> c ; c <= a & c <= b & (c = a \/ c = b)
66
///   - `const_eval`: c1 + c2 ~> (c1 + c2) ; c1, c2 are constants
67
///
68
///   Result:
69
///   ```text
70
///   model: {
71
///     constraints: [
72
///       (a + aux + 32) = b,
73
///       aux <= x,
74
///       aux <= y,
75
///       aux = x \/ aux = y
76
///     ],
77
///     symbols: [a, b, x, y, aux]
78
///   }
79
///   ```
80
///
81
///   Process:
82
///   1. We traverse the expression tree until a rule can be applied.
83
///   2. If multiple rules can be applied to the same expression, the higher priority one goes first.
84
///      In this case, `const_eval` is applied before `min_to_var`.
85
///   3. The rule `min_to_var` adds a new variable `aux` and new constraints to the model.
86
///   4. When no more rules can be applied, the resulting model is returned.
87
///
88
///   Details for this process can be found in [`rewrite_iteration`] documentation.
89
///
90
/// # Performance Considerations
91
/// - We recursively traverse the tree multiple times to check if any rules can be applied.
92
/// - Expressions are cloned on each rule application
93
///
94
/// This can be expensive for large models
95
///
96
/// # Panics
97
/// - This function may panic if the model's context is unavailable or if there is an issue with locking the context.
98
///
99
/// # See Also
100
/// - [`get_rules`]: Resolves the rules from the provided rule sets and sorts them by priority.
101
/// - [`rewrite_iteration`]: Executes a single iteration of rewriting the model using the specified rules.
102
18
pub fn rewrite_model<'a>(
103
18
    model: &Model,
104
18
    rule_sets: &Vec<&'a RuleSet<'a>>,
105
18
) -> Result<Model, RewriteError> {
106
18
    let rules = get_rules(rule_sets)?.into_iter().collect();
107
18
    let mut new_model = model.clone();
108
18
    let mut stats = RewriterStats {
109
18
        is_optimization_enabled: Some(optimizations_enabled()),
110
18
        rewriter_run_time: None,
111
18
        rewriter_rule_application_attempts: Some(0),
112
18
        rewriter_rule_applications: Some(0),
113
18
    };
114
18

            
115
18
    // Check if optimizations are enabled
116
18
    let apply_optimizations = optimizations_enabled();
117
18

            
118
18
    let start = std::time::Instant::now();
119
18

            
120
18
    //the while loop is exited when None is returned implying the sub-expression is clean
121
18
    let mut i: usize = 0;
122
36
    while i < new_model.as_submodel().constraints().len() {
123
18
        while let Some(step) = rewrite_iteration(
124
18
            &new_model.as_submodel().constraints()[i],
125
18
            &new_model,
126
18
            &rules,
127
18
            apply_optimizations,
128
18
            &mut stats,
129
18
        ) {
130
            debug_assert!(is_vec_bool(&step.new_top)); // All new_top expressions should be boolean
131
            new_model.as_submodel_mut().constraints_mut()[i] = step.new_expression.clone();
132
            step.apply(new_model.as_submodel_mut()); // Apply side-effects (e.g., symbol table updates)
133
        }
134

            
135
        // If new constraints are added, continue processing them in the next iterations.
136
18
        i += 1;
137
    }
138

            
139
18
    stats.rewriter_run_time = Some(start.elapsed());
140
18
    model.context.write().unwrap().stats.add_rewriter_run(stats);
141
18
    Ok(new_model)
142
18
}
143

            
144
/// Checks if all expressions in `Vec<Expr>` are booleans.
145
/// All top-level constraints in a model should be boolean expressions.
146
fn is_vec_bool(exprs: &[Expression]) -> bool {
147
    exprs
148
        .iter()
149
        .all(|expr| expr.return_type() == Some(ReturnType::Bool))
150
}
151

            
152
/// Attempts to apply a set of rules to the given expression and its sub-expressions in the model.
153
///
154
/// 1. Checks if the expression is "clean" (all possible rules have been applied).
155
/// 2. Tries to apply rules to the top-level expression, in oprder of priority.
156
/// 3. If no rules can be applied to the top-level expression, recurses into its sub-expressions.
157
///
158
/// When a successful rule application is found, immediately returns a `Reduction` and stops.
159
/// The `Reduction` contains the new expression and any side-effects (e.g., new constraints, variables).
160
/// If no rule applications are possible in this expression tree, returns `None`.
161
///
162
/// # Parameters
163
/// - `expression`: The [`Expression`] to be rewritten.
164
/// - `model`: The root [`Model`] for access to the context and symbol table.
165
/// - `rules`: A max-heap of [`RuleData`] containing rules, priorities, and metadata. Ordered by rule priority.
166
/// - `apply_optimizations`: If `true`, skip already "clean" expressions to avoid redundant work.
167
/// - `stats`: A mutable reference to [`RewriterStats`] to collect statistics
168
///
169
/// # Returns
170
/// - `Some(<Reduction>)`: If a rule is successfully applied to the expression or any of its sub-expressions.
171
///                        Contains the new expression and any side-effects to apply to the model.
172
/// - `None`: If no rule is applicable to the expression or any of its sub-expressions.
173
///
174
/// # Example
175
///
176
/// - Rewriting the expression `a + min(x, y)`:
177
///
178
///   Input:
179
///   ```text
180
///   expression: a + min(x, y)
181
///   rules: [min_to_var]
182
///   model: {
183
///     constraints: [(a + min(x, y)) = b],
184
///     symbols: [a, b, x, y]
185
///   }
186
///   apply_optimizations: true
187
///   ```
188
///
189
///   Process:
190
///   1. Initially, the expression is dirty, so we proceed with the rewrite.
191
///   2. No rules can be applied to the top-level expression `a + min(x, y)`.
192
///      Try its children: `a` and `min(x, y)`.
193
///   3. No rules can be applied to `a`. Mark it as clean and return None.
194
///   4. The rule `min_to_var` can be applied to `min(x, y)`. Return the `Reduction`.
195
///      ```text
196
///      Reduction {
197
///        new_expression: aux,
198
///        new_top: [aux <= x, aux <= y, aux = x \/ aux = y],
199
///        symbols: [a, b, x, y, aux]
200
///      }
201
///      ```
202
///   5. Update the parent expression `a + min(x, y)` with the new child `a + aux`.
203
///      Add new constraints and variables to the model.
204
///   6. No more rules can be applied to this expression. Mark it as clean and return a pure `Reduction`.
205
54
fn rewrite_iteration(
206
54
    expression: &Expression,
207
54
    model: &Model,
208
54
    rules: &Vec<RuleData<'_>>,
209
54
    apply_optimizations: bool,
210
54
    stats: &mut RewriterStats,
211
54
) -> Option<Reduction> {
212
54
    if apply_optimizations && expression.is_clean() {
213
        // Skip processing this expression if it's clean
214
        return None;
215
54
    }
216
54

            
217
54
    // Mark the expression as clean - will be marked dirty if any rule is applied
218
54
    let mut expression = expression.clone();
219
54

            
220
54
    let rule_results = apply_all_rules(&expression, model, rules, stats);
221
54
    if let Some(result) = choose_rewrite(&rule_results, &expression) {
222
        // If a rule is applied, mark the expression as dirty
223
        log_rule_application(&result, &expression, model.as_submodel());
224
        return Some(result.reduction);
225
54
    }
226
54

            
227
54
    let mut sub = expression.children();
228
54
    for i in 0..sub.len() {
229
36
        if let Some(red) = rewrite_iteration(&sub[i], model, rules, apply_optimizations, stats) {
230
            sub[i] = red.new_expression;
231
            let res = expression.with_children(sub.clone());
232
            return Some(Reduction::new(res, red.new_top, red.symbols));
233
36
        }
234
    }
235
    // If all children are clean, mark this expression as clean
236
54
    if apply_optimizations {
237
        assert!(expression.children().iter().all(|c| c.is_clean()));
238
        expression.set_clean(true);
239
        return Some(Reduction::pure(expression));
240
54
    }
241
54
    None
242
54
}
243

            
244
/// Tries to apply rules to an expression and returns a list of successful applications.
245
///
246
/// The expression or model is NOT modified directly.
247
/// We create a list of `RuleResult`s containing the reductions and pass it to `choose_rewrite` to select one to apply.
248
///
249
/// # Parameters
250
/// - `expression`: A reference to the [`Expression`] to evaluate.
251
/// - `model`: A reference to the [`Model`] for access to the symbol table and context.
252
/// - `rules`: A vector of references to [`Rule`]s to try.
253
/// - `stats`: A mutable reference to [`RewriterStats`] used to track the number of rule applications and other statistics.
254
///
255
/// # Returns
256
/// - A `Vec<RuleResult>` containing all successful rule applications to the expression.
257
///   Each `RuleResult` contains the rule that was applied and the resulting `Reduction`.
258
///
259
/// # Side-Effects
260
/// - The function updates the provided `stats` with the number of rule application attempts and successful applications.
261
/// - Debug or trace logging may be performed to track which rules were applicable or not for a given expression.
262
///
263
/// # Example
264
/// let applicable_rules = apply_all_rules(&expr, &model, &rules, &mut stats);
265
/// if !applicable_rules.is_empty() {
266
///     for result in applicable_rules {
267
///         println!("Rule applied: {:?}", result.rule_data.rule);
268
///     }
269
/// }
270
///
271
/// ## Note
272
/// - Rules are applied only to the given expression, not its children.
273
///
274
/// # See Also
275
/// - [`choose_rewrite`]: Chooses a single reduction from the rule results provided by `apply_all_rules`.
276
54
fn apply_all_rules<'a>(
277
54
    expression: &Expression,
278
54
    model: &Model,
279
54
    rules: &Vec<RuleData<'a>>,
280
54
    stats: &mut RewriterStats,
281
54
) -> Vec<RuleResult<'a>> {
282
54
    let mut results = Vec::new();
283
2862
    for rule_data in rules {
284
2808
        match rule_data
285
2808
            .rule
286
2808
            .apply(expression, &model.as_submodel().symbols())
287
        {
288
            Ok(red) => {
289
                stats.rewriter_rule_application_attempts =
290
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
291
                stats.rewriter_rule_applications =
292
                    Some(stats.rewriter_rule_applications.unwrap() + 1);
293
                // Assert no clean children
294
                // assert!(!red.new_expression.children().iter().any(|c| c.is_clean()), "Rule that caused assertion to fail: {:?}", rule.name);
295
                // assert!(!red.new_expression.children().iter().any(|c| c.children().iter().any(|c| c.is_clean())));
296
                results.push(RuleResult {
297
                    rule_data: rule_data.clone(),
298
                    reduction: red,
299
                });
300
            }
301
            Err(_) => {
302
2808
                log::trace!(
303
                    "Rule attempted but not applied: {}, to expression: {} ({:?})",
304
                    rule_data.rule,
305
                    expression,
306
                    rule_data
307
                );
308
2808
                stats.rewriter_rule_application_attempts =
309
2808
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
310
2808
                continue;
311
            }
312
        }
313
    }
314
54
    results
315
54
}
316

            
317
/// Chooses the first applicable rule result from a list of rule applications.
318
///
319
/// Currently, applies the rule with the highest priority.
320
/// If multiple rules have the same priority, logs an error message and panics.
321
///
322
/// # Parameters
323
/// - `results`: A slice of [`RuleResult`]s to consider.
324
/// -  `initial_expression`: [`Expression`] before the rule application.
325
///
326
/// # Returns
327
/// - `Some(<Reduction>)`: If there is at least one successful rule application, returns a [`Reduction`] to apply.
328
/// - `None`: If there are no successful rule applications (i.e. `results` is empty).
329
///
330
/// # Example
331
///
332
/// let rule_results = vec![rule1_result, rule2_result];
333
/// if let Some(reduction) = choose_rewrite(&rule_results) {
334
///   // Process the chosen reduction
335
/// }
336
///
337
54
fn choose_rewrite<'a>(
338
54
    results: &[RuleResult<'a>],
339
54
    initial_expression: &Expression,
340
54
) -> Option<RuleResult<'a>> {
341
54
    //in the case where multiple rules are applicable
342
54
    if !results.is_empty() {
343
        let mut rewrite_options: Vec<RuleResult> = Vec::new();
344
        for (priority, group) in &results.iter().chunk_by(|result| result.rule_data.priority) {
345
            let options: Vec<&RuleResult> = group.collect();
346
            if options.len() > 1 {
347
                // Multiple rules with the same priority
348
                let mut message = format!(
349
                    "Found multiple rules of the same priority {} applicable to expression: {}\n",
350
                    priority, initial_expression
351
                );
352
                for option in options {
353
                    message.push_str(&format!(
354
                        "- Rule: {} (from {})\n",
355
                        option.rule_data.rule.name, option.rule_data.rule_set.name
356
                    ));
357
                }
358
                bug!("{}", message);
359
            } else {
360
                // Only one rule with this priority, add it to the list
361
                rewrite_options.push(options[0].clone());
362
            }
363
        }
364

            
365
        if rewrite_options.len() > 1 {
366
            // Keep old behaviour: log a message and apply the highest priority rule
367
            let mut message = format!(
368
                "Found multiple rules of different priorities applicable to expression: {}\n",
369
                initial_expression
370
            );
371
            for option in &rewrite_options {
372
                message.push_str(&format!(
373
                    "- Rule: {} (priority {}, from {})\n",
374
                    option.rule_data.rule.name,
375
                    option.rule_data.priority,
376
                    option.rule_data.rule_set.name
377
                ));
378
            }
379
            log::warn!("{}", message);
380
        }
381

            
382
        return Some(rewrite_options[0].clone());
383
54
    }
384
54

            
385
54
    None
386
54
}