1
use std::collections::HashMap;
2
use std::env;
3
use std::fmt::Display;
4

            
5
use thiserror::Error;
6

            
7
use crate::bug;
8
use crate::stats::RewriterStats;
9
use tracing::trace;
10
use uniplate::Uniplate;
11

            
12
use crate::rule_engine::{Reduction, Rule, RuleSet};
13
use crate::{
14
    ast::Expression,
15
    rule_engine::resolve_rules::{
16
        get_rule_priorities, get_rules_vec, ResolveRulesError as ResolveError,
17
    },
18
    Model,
19
};
20

            
21
#[derive(Debug)]
22
struct RuleResult<'a> {
23
    rule: &'a Rule<'a>,
24
    reduction: Reduction,
25
}
26

            
27
/// Represents errors that can occur during the model rewriting process.
28
///
29
/// This enum captures errors that occur when trying to resolve or apply rules in the model.
30
#[derive(Debug, Error)]
31
pub enum RewriteError {
32
    ResolveRulesError(ResolveError),
33
}
34

            
35
impl Display for RewriteError {
36
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37
        match self {
38
            RewriteError::ResolveRulesError(e) => write!(f, "Error resolving rules: {}", e),
39
        }
40
    }
41
}
42

            
43
impl From<ResolveError> for RewriteError {
    /// Wraps a rule-resolution error so it can be propagated with `?` as a [`RewriteError`].
    fn from(source: ResolveError) -> Self {
        Self::ResolveRulesError(source)
    }
}
48

            
49
/// Reports whether rewriter optimizations were requested through the environment.
///
/// # Returns
/// `true` only when the `OPTIMIZATIONS` environment variable is set to exactly `"1"`.
/// An unset variable, a value that is not valid Unicode, or any other value yields `false`.
fn optimizations_enabled() -> bool {
    // `env::var` fails both when the variable is absent and when it is not valid Unicode;
    // both cases fall through to `false`.
    matches!(env::var("OPTIMIZATIONS").as_deref(), Ok("1"))
}
60

            
61
/// Rewrites the given model by applying a set of rules to all its constraints.
62
///
63
/// This function iteratively applies transformations to the model's constraints using the specified rule sets.
64
/// It returns a modified version of the model with all applicable rules applied, ensuring that any side-effects
65
/// such as updates to the symbol table and top-level constraints are properly reflected in the returned model.
66
///
67
/// # Parameters
68
/// - `model`: A reference to the [`Model`] to be rewritten. The function will clone this model to produce a modified version.
69
/// - `rule_sets`: A vector of references to [`RuleSet`]s that define the rules to be applied to the model's constraints.
70
///   Each `RuleSet` is expected to contain a collection of rules that can transform one or more constraints
71
///   within the model. The lifetime parameter `'a` ensures that the rules' references are valid for the
72
///   duration of the function execution.
73
///
74
/// # Returns
75
/// - `Ok(Model)`: If successful, it returns a modified copy of the [`Model`] after all applicable rules have been
76
///   applied. This new model includes any side-effects such as updates to the symbol table or modifications
77
///   to the constraints.
78
/// - `Err(RewriteError)`: If an error occurs during rule application (e.g., invalid rules or failed constraints),
79
///   it returns a [`RewriteError`] with details about the failure.
80
///
81
/// # Side-Effects
82
/// - When the model is rewritten, related data structures such as the symbol table (which tracks variable names and types)
83
///   or other top-level constraints may also be updated to reflect these changes. These updates are applied to the returned model,
84
///   ensuring that all related components stay consistent and aligned with the changes made during the rewrite.
85
/// - The function collects statistics about the rewriting process, including the number of rule applications
86
///   and the total runtime of the rewriter. These statistics are then stored in the model's context for
87
///   performance monitoring and analysis.
88
///
89
/// # Example
90
/// - Using `rewrite_model` with the Expression `a + min(x, y)`
91
///
92
///   Initial expression: a + min(x, y)
93
///   A model containing the expression is created. The variables of the model are represented by a SymbolTable and contain a,x,y.
94
///   The contraints of the initail model is the expression itself.
95
///
96
///   After getting the rules by their priorities and getting additional statistics the while loop of single interations is executed.
97
///   Details for this process can be found in [`rewrite_iteration`] documentation.
98
///
99
///   The loop is exited only when no more rules can be applied, when rewrite_iteration returns None and [`while let Some(step) = None`] occurs
100
///
101
///
102
///   Will result in side effects ((d<=x ^ d<=y) being the [`new_top`] and the model will now be a conjuction of that and (a+d)
103
///   Rewritten expression: ((a + d) ^ (d<=x ^ d<=y))
104
///
105
/// # Performance Considerations
106
/// - The function checks if optimizations are enabled before applying rules, which may affect the performance
107
///   of the rewriting process.
108
/// - Depending on the size of the model and the number of rules, the rewriting process might take a significant
109
///   amount of time. Use the statistics collected (`rewriter_run_time` and `rewriter_rule_application_attempts`)
110
///   to monitor and optimize performance.
111
///
112
/// # Panics
113
/// - This function may panic if the model's context is unavailable or if there is an issue with locking the context.
114
///
115
/// # See Also
116
/// - [`get_rule_priorities`]: Retrieves the priorities for the given rules.
117
/// - [`rewrite_iteration`]: Executes a single iteration of rewriting the model using the specified rules.
118
1020
pub fn rewrite_model<'a>(
119
1020
    model: &Model,
120
1020
    rule_sets: &Vec<&'a RuleSet<'a>>,
121
1020
) -> Result<Model, RewriteError> {
122
1020
    let rule_priorities = get_rule_priorities(rule_sets)?;
123
1020
    let rules = get_rules_vec(&rule_priorities);
124
1020
    let mut new_model = model.clone();
125
1020
    let mut stats = RewriterStats {
126
1020
        is_optimization_enabled: Some(optimizations_enabled()),
127
1020
        rewriter_run_time: None,
128
1020
        rewriter_rule_application_attempts: Some(0),
129
1020
        rewriter_rule_applications: Some(0),
130
1020
    };
131
1020

            
132
1020
    // Check if optimizations are enabled
133
1020
    let apply_optimizations = optimizations_enabled();
134
1020

            
135
1020
    let start = std::time::Instant::now();
136

            
137
    //the while loop is exited when None is returned implying the sub-expression is clean
138
23290
    while let Some(step) = rewrite_iteration(
139
23290
        &new_model.constraints,
140
23290
        &new_model,
141
23290
        &rules,
142
23290
        apply_optimizations,
143
23290
        &mut stats,
144
23290
    ) {
145
22270
        step.apply(&mut new_model); // Apply side-effects (e.g. symbol table updates)
146
22270
    }
147
1020
    stats.rewriter_run_time = Some(start.elapsed());
148
1020
    model.context.write().unwrap().stats.add_rewriter_run(stats);
149
1020
    Ok(new_model)
150
1020
}
151

            
152
/// Attempts to apply a set of rules to the given expression and its sub-expressions in the model.
153
///
154
/// This function recursively traverses the provided expression, applying any applicable rules from the given set.
155
/// If a rule is successfully applied to the expression or any of its sub-expressions, it returns a `Reduction`
156
/// containing the new expression, modified top-level constraints, and any changes to symbols. If no rules can be
157
/// applied at any level, it returns `None`.
158
///
159
/// # Parameters
160
/// - `expression`: A reference to the [`Expression`] to be rewritten. This is the main expression that the function
161
///   attempts to modify using the given rules.
162
/// - `model`: A reference to the [`Model`] that provides context and additional constraints for evaluating the rules.
163
/// - `rules`: A vector of references to [`Rule`]s that define the transformations to apply to the expression.
164
/// - `apply_optimizations`: A boolean flag that indicates whether optimization checks should be applied during the rewriting process.
165
///   If `true`, the function skips already "clean" (fully optimized or processed) expressions and marks them accordingly
166
///   to avoid redundant work.
167
/// - `stats`: A mutable reference to [`RewriterStats`] to collect statistics about the rule application process, such as
168
///   the number of rules applied and the time taken for each iteration.
169
///
170
/// # Returns
171
/// - `Some(<Reduction>)`: A [`Reduction`] containing the new expression and any associated modifications if a rule was applied
172
///   to `expr` or one of its sub-expressions.
173
/// - `None`: If no rule is applicable to the expression or any of its sub-expressions.
174
///
175
/// # Side-Effects
176
/// - If `apply_optimizations` is enabled, the function will skip "clean" expressions and mark successfully rewritten
177
///   expressions as "dirty". This is done to avoid unnecessary recomputation of expressions that have already been
178
///   optimized or processed.
179
///
180
/// # Example
181
/// - Recursively applying [`rewrite_iteration`]  to [`a + min(x, y)`]
182
///
183
///   Initially [`if apply_optimizations && expression.is_clean()`] is not true yet since intially our expression is dirty.
184
///
185
///   [`apply_results`] returns a null vector since no rules can be applied at the top level.
186
///   After calling function [`children`] on the expression a vector of sub-expression [`[a, min(x, y)]`] is returned.
187
///
188
///   The function iterates through the vector of the children from the top expression and calls itself.
189
///
190
///   [rewrite_iteration] on on the child [`a`] returns None, but on [`min(x, y)`] returns a [`Reduction`] object [`red`].
191
///   In this case, a rule (min simplification) can apply:
192
///   - d is added to the SymbolTable and the variables field is updated in the model. new_top is the side effects: (d<=x ^ d<=y)
193
///   - [`red = Reduction::new(new_expression = d, new_top, symbols)`];
194
///   - [`sub[1] = red.new_expression`] - Updates the second element in the vector of sub-expressions from [`min(x, y)`] to [`d`]
195
///
196
///   Since a child expression [`min(x, y)`] was rewritten to d, the parent expression [`a + min(x, y)`] is updated with the new child [`a+d`].
197
///   New [`Reduction`] is returned containing the modifications
198
///
199
///   The condition [`Some(step) = Some(new reduction)`] in the while loop in [`rewrite_model`] is met -> side effects are applied.
200
///
201
///   No more rules in our example can apply to the modified model -> mark all the children as clean and return a pure [`Reduction`].
202
///   [`return Some(Reduction::pure(expression))`]
203
///
204
///   On the last execution of rewrite_iteration condition [`apply_optimizations && expression.is_clean()`] is met, [`None`] is returned.
205
///
206
///
207
/// # Notes
208
/// - This function works recursively, meaning it traverses all sub-expressions within the given `expression` to find the
209
///   first rule that can be applied. If a rule is applied, it immediately returns the modified expression and stops
210
///   further traversal for that branch.
211
8157416
fn rewrite_iteration<'a>(
212
8157416
    expression: &'a Expression,
213
8157416
    model: &'a Model,
214
8157416
    rules: &'a Vec<&'a Rule<'a>>,
215
8157416
    apply_optimizations: bool,
216
8157416
    stats: &mut RewriterStats,
217
8157416
) -> Option<Reduction> {
218
8157416
    if apply_optimizations && expression.is_clean() {
219
        // Skip processing this expression if it's clean
220
        return None;
221
8157416
    }
222
8157416

            
223
8157416
    // Mark the expression as clean - will be marked dirty if any rule is applied
224
8157416
    let mut expression = expression.clone();
225
8157416

            
226
8157416
    let rule_results = apply_all_rules(&expression, model, rules, stats);
227
8157416
    trace_rules(&rule_results, expression.clone());
228
8157416
    if let Some(new) = choose_rewrite(&rule_results, &expression) {
229
        // If a rule is applied, mark the expression as dirty
230
22270
        return Some(new);
231
8135146
    }
232
8135146

            
233
8135146
    let mut sub = expression.children();
234
8135146
    for i in 0..sub.len() {
235
8134126
        if let Some(red) = rewrite_iteration(&sub[i], model, rules, apply_optimizations, stats) {
236
26435
            sub[i] = red.new_expression;
237
26435
            let res = expression.with_children(sub.clone());
238
26435
            return Some(Reduction::new(res, red.new_top, red.symbols));
239
8107691
        }
240
    }
241
    // If all children are clean, mark this expression as clean
242
8108711
    if apply_optimizations {
243
        assert!(expression.children().iter().all(|c| c.is_clean()));
244
        expression.set_clean(true);
245
        return Some(Reduction::pure(expression));
246
8108711
    }
247
8108711
    None
248
8157416
}
249

            
250
/// Applies all the given rules to a specific expression within the model.
251
///
252
/// This function iterates through the provided rules and attempts to apply each rule to the given `expression`.
253
/// If a rule is successfully applied, it creates a [`RuleResult`] containing the original rule and the resulting
254
/// [`Reduction`]. The statistics (`stats`) are updated to reflect the number of rule application attempts and successful
255
/// applications.
256
///
257
/// The function does not modify the provided `expression` directly. Instead, it collects all applicable rule results
258
/// into a vector, which can then be used for further processing or selection (e.g., with [`choose_rewrite`]).
259
///
260
/// # Parameters
261
/// - `expression`: A reference to the [`Expression`] that will be evaluated against the given rules. This is the main
262
///   target for rule transformations and is expected to remain unchanged during the function execution.
263
/// - `model`: A reference to the [`Model`] that provides context for rule evaluation, such as constraints and symbols.
264
///   Rules may depend on information in the model to determine if they can be applied.
265
/// - `rules`: A vector of references to [`Rule`]s that define the transformations to be applied to the expression.
266
///   Each rule is applied independently, and all applicable rules are collected.
267
/// - `stats`: A mutable reference to [`RewriterStats`] used to track statistics about rule application, such as
268
///   the number of attempts and successful applications.
269
///
270
/// # Returns
271
/// - A `Vec<RuleResult>` containing all rule applications that were successful. Each element in the vector represents
272
///   a rule that was applied to the given `expression` along with the resulting transformation.
273
/// - An empty vector if no rules were applicable to the expression.
274
///
275
/// # Side-Effects
276
/// - The function updates the provided `stats` with the number of rule application attempts and successful applications.
277
/// - Debug or trace logging may be performed to track which rules were applicable or not for a given expression.
278
///
279
/// # Example
280
///
281
/// let applicable_rules = apply_all_rules(&expr, &model, &rules, &mut stats);
282
/// if !applicable_rules.is_empty() {
283
///     for result in applicable_rules {
284
///         println!("Rule applied: {:?}", result.rule);
285
///     }
286
/// }
287
///
288
///
289
/// # Notes
290
/// - This function does not modify the input `expression` or `model` directly. The returned `RuleResult` vector
291
///   provides information about successful transformations, allowing the caller to decide how to process them.
292
/// - The function performs independent rule applications. If rules have dependencies or should be applied in a
293
///   specific order, consider handling that logic outside of this function.
294
///
295
/// # See Also
296
/// - [`choose_rewrite`]: Chooses a single reduction from the rule results provided by `apply_all_rules`.
297
8157416
fn apply_all_rules<'a>(
298
8157416
    expression: &'a Expression,
299
8157416
    model: &'a Model,
300
8157416
    rules: &'a Vec<&'a Rule<'a>>,
301
8157416
    stats: &mut RewriterStats,
302
8157416
) -> Vec<RuleResult<'a>> {
303
8157416
    let mut results = Vec::new();
304
301936796
    for rule in rules {
305
293779380
        match rule.apply(expression, model) {
306
22848
            Ok(red) => {
307
22848
                stats.rewriter_rule_application_attempts =
308
22848
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
309
22848
                stats.rewriter_rule_applications =
310
22848
                    Some(stats.rewriter_rule_applications.unwrap() + 1);
311
22848
                // Assert no clean children
312
22848
                // assert!(!red.new_expression.children().iter().any(|c| c.is_clean()), "Rule that caused assertion to fail: {:?}", rule.name);
313
22848
                // assert!(!red.new_expression.children().iter().any(|c| c.children().iter().any(|c| c.is_clean())));
314
22848
                results.push(RuleResult {
315
22848
                    rule,
316
22848
                    reduction: red,
317
22848
                });
318
22848
            }
319
            Err(_) => {
320
293756532
                log::trace!(
321
                    "Rule attempted but not applied: {} ({:?}), to expression: {}",
322
                    rule.name,
323
                    rule.rule_sets,
324
                    expression
325
                );
326
293756532
                stats.rewriter_rule_application_attempts =
327
293756532
                    Some(stats.rewriter_rule_application_attempts.unwrap() + 1);
328
293756532
                continue;
329
            }
330
        }
331
    }
332
8157416
    results
333
8157416
}
334

            
335
/// Chooses the first applicable rule result from a list of rule applications.
336
///
337
/// This function selects a reduction from the provided `RuleResult` list, prioritizing the first rule
338
/// that successfully transforms the expression. This strategy can be modified in the future to incorporate
339
/// more complex selection criteria, such as prioritizing rules based on cost, complexity, or other heuristic metrics.
340
///
341
/// The function also checks the priorities of all the applicable rules and detects if there are multiple rules of the same proirity
342
///
343
/// # Parameters
344
/// - `results`: A slice of [`RuleResult`] containing potential rule applications to be considered. Each element
345
///   represents a rule that was successfully applied to the expression, along with the resulting transformation.
346
/// -  `initial_expression`: [`Expression`] before the rule tranformation.
347
///
348
/// # Returns
349
/// - `Some(<Reduction>)`: Returns a [`Reduction`] representing the first rule's application if there is at least one
350
///   rule that produced a successful transformation.
351
/// - `None`: If no rule applications are available in the `results` slice (i.e., it is empty), it returns `None`.
352
///
353
/// # Example
354
///
355
/// let rule_results = vec![rule1_result, rule2_result];
356
/// if let Some(reduction) = choose_rewrite(&rule_results) {
357
/// Process the chosen reduction
358
/// }
359
///
360
8157416
fn choose_rewrite(results: &[RuleResult], initial_expression: &Expression) -> Option<Reduction> {
361
8157416
    //in the case where multiple rules are applicable
362
8157416
    if results.len() > 1 {
363
493
        let expr = results[0].reduction.new_expression.clone();
364
1071
        let rules: Vec<_> = results.iter().map(|result| &result.rule).collect();
365
493

            
366
493
        check_priority(rules.clone(), initial_expression, &expr);
367
8156923
    }
368

            
369
8157416
    if results.is_empty() {
370
8135146
        return None;
371
22270
    }
372
22270
    let red = results[0].reduction.clone();
373
22270
    let rule = results[0].rule;
374
22270
    tracing::info!(
375
        new_top=%red.new_top,
376
        "Rule applicable: {} ({:?}), to expression: {}, resulting in: {}",
377
        rule.name,
378
        rule.rule_sets,
379
        initial_expression,
380
        red.new_expression
381
    );
382
    // Return the first result for now
383
22270
    Some(red)
384
8157416
}
385

            
386
/// Function filters all the applicable rules based on their priority.
387
/// In the case where there are multiple rules of the same prioriy, a bug! is thrown listing all those duplicates.
388
/// Otherwise, if there are multiple rules applicable but they all have different priorities, a warning message is dispalyed.
389
///
390
/// # Parameters
391
/// - `rules`: a vector of [`Rule`] containing all the applicable rules and their metadata for a specific expression.
392
/// - `initial_expression`: [`Expression`] before rule the tranformation.
393
/// - `new_expr`: [`Expression`] after the rule transformation.
394
///
395
493
fn check_priority<'a>(
396
493
    rules: Vec<&&Rule<'_>>,
397
493
    initial_expr: &'a Expression,
398
493
    new_expr: &'a Expression,
399
493
) {
400
493
    //getting the rule sets from the applicable rules
401
1071
    let rule_sets: Vec<_> = rules.iter().map(|rule| &rule.rule_sets).collect();
402
493

            
403
493
    //a map with keys being rule priorities and their values neing all the rules of that priority found in the rule_sets
404
493
    let mut rules_by_priorities: HashMap<u16, Vec<&str>> = HashMap::new();
405

            
406
    //iterates over each rule_set and groups by the rule priority
407
1564
    for rule_set in &rule_sets {
408
1071
        if let Some((name, priority)) = rule_set.first() {
409
1071
            rules_by_priorities
410
1071
                .entry(*priority)
411
1071
                .or_default()
412
1071
                .push(*name);
413
1071
        }
414
    }
415

            
416
    //filters the map, retaining only entries where there is more than 1 rule of the same priority
417
493
    let duplicate_rules: HashMap<u16, Vec<&str>> = rules_by_priorities
418
493
        .into_iter()
419
1071
        .filter(|(_, group)| group.len() > 1)
420
493
        .collect();
421
493

            
422
493
    if !duplicate_rules.is_empty() {
423
        //accumulates all duplicates into a formatted message
424
        let mut message = format!("Found multiple rules of the same priority applicable to to expression: {:?} \n resulting in expression: {:?}", initial_expr, new_expr);
425
        for (priority, rules) in &duplicate_rules {
426
            message.push_str(&format!("Priority {:?} \n Rules: {:?}", priority, rules));
427
        }
428
        bug!("{}", message);
429

            
430
    //no duplicate rules of the same priorities were found in the set of applicable rules
431
    } else {
432
493
        log::warn!("Multiple rules of different priorities are applicable to expression {:?} \n resulting in expression: {:?}
433
        \n Rules{:?}", initial_expr, new_expr, rules)
434
    }
435
493
}
436

            
437
8157416
fn trace_rules(results: &[RuleResult], expression: Expression) {
438
8157416
    if !results.is_empty() {
439
22270
        let rule = results[0].rule;
440
22270
        let new_expression = results[0].reduction.new_expression.clone();
441
22270

            
442
22270
        trace!(
443
            target: "rule_engine",
444
5304
            "Rule applicable: {} ({:?}), to expression: {}, resulting in: {}",
445
            rule.name,
446
            rule.rule_sets,
447
            expression,
448
            new_expression,
449
        );
450
8135146
    }
451
8157416
}