1
use super::{RewriteError, RuleSet};
2
use crate::{
3
    ast::Expression as Expr,
4
    bug,
5
    rule_engine::{
6
        get_rule_priorities,
7
        rewriter_common::{log_rule_application, RuleResult},
8
        Rule,
9
    },
10
    Model,
11
};
12
use std::collections::{BTreeMap, HashSet};
13
use std::sync::Arc;
14
use uniplate::Biplate;
15

            
16
/// A naive, exhaustive rewriter for development purposes. Applies rules in priority order,
17
/// favouring expressions found earlier during preorder traversal of the tree.
18
1428
pub fn rewrite_naive<'a>(
19
1428
    model: &Model,
20
1428
    rule_sets: &Vec<&'a RuleSet<'a>>,
21
1428
    prop_multiple_equally_applicable: bool,
22
1428
) -> Result<Model, RewriteError> {
23
1428
    let priorities =
24
        get_rule_priorities(rule_sets).unwrap_or_else(|_| bug!("get_rule_priorities() failed!"));
25
1428

            
26
1428
    // Group rules by priority in descending order.
27
1428
    let mut grouped: BTreeMap<u16, HashSet<&'a Rule<'a>>> = BTreeMap::new();
28
75684
    for (rule, priority) in priorities {
29
74256
        grouped.entry(priority).or_default().insert(rule);
30
74256
    }
31
1428
    let rules_by_priority: Vec<(u16, HashSet<&'a Rule<'a>>)> = grouped.into_iter().collect();
32

            
33
    type CtxFn = Arc<dyn Fn(Expr) -> Vec<Expr>>;
34
1428
    let mut model = model.clone();
35

            
36
    loop {
37
13175
        let mut results: Vec<(RuleResult<'_>, u16, Expr, CtxFn)> = vec![];
38

            
39
        // Iterate over rules by priority in descending order.
40
89216
        'top: for (priority, rule_set) in rules_by_priority.iter().rev() {
41
2855745
            for (expr, ctx) in <_ as Biplate<Expr>>::contexts_bi(&model.get_constraints_vec()) {
42
                // Clone expr and ctx so they can be reused
43
2855745
                let expr = expr.clone();
44
2855745
                let ctx = ctx.clone();
45
14006300
                for rule in rule_set {
46
11150555
                    match (rule.application)(&expr, &model) {
47
11747
                        Ok(red) => {
48
11747
                            // Collect applicable rules
49
11747
                            results.push((
50
11747
                                RuleResult {
51
11747
                                    rule,
52
11747
                                    reduction: red,
53
11747
                                },
54
11747
                                *priority,
55
11747
                                expr.clone(),
56
11747
                                ctx.clone(),
57
11747
                            ));
58
11747
                        }
59
                        Err(_) => {
60
                            // when called a lot, this becomes very expensive!
61
                            #[cfg(debug_assertions)]
62
11138808
                            tracing::trace!(
63
17
                                "Rule attempted but not applied: {} (priority {}), to expression: {}",
64
                                rule.name,
65
                                priority,
66
                                expr
67
                            );
68
                        }
69
                    }
70
                }
71
                // This expression has the highest rule priority so far, so this is what we want to
72
                // rewrite.
73
2855745
                if !results.is_empty() {
74
11747
                    break 'top;
75
2843998
                }
76
            }
77
        }
78

            
79
13175
        match results.as_slice() {
80
13175
            [] => break, // Exit if no rules are applicable.
81
11747
            [(result, _priority, expr, ctx), ..] => {
82
11747
                // Extract the single applicable rule and apply it
83
11747

            
84
11747
                log_rule_application(result, expr, &model);
85
11747

            
86
11747
                // Replace expr with new_expression
87
11747
                model.set_constraints(ctx(result.reduction.new_expression.clone()));
88
11747

            
89
11747
                // Apply new symbols and top level
90
11747
                result.reduction.clone().apply(&mut model);
91
11747

            
92
11747
                if results.len() > 1 && prop_multiple_equally_applicable {
93
                    let names: Vec<_> = results
94
                        .iter()
95
                        .map(|(result, _, _, _)| result.rule.name)
96
                        .collect();
97

            
98
                    // Extract the expression from the first result
99
                    let expr = results[0].2.clone();
100

            
101
                    // Construct a single string to display the names of the rules grouped by priority
102
                    let mut rules_by_priority_string = String::new();
103
                    rules_by_priority_string.push_str("Rules grouped by priority:\n");
104
                    for (priority, rule_set) in rules_by_priority.iter().rev() {
105
                        rules_by_priority_string.push_str(&format!("Priority {}:\n", priority));
106
                        for rule in rule_set {
107
                            rules_by_priority_string.push_str(&format!("  - {}\n", rule.name));
108
                        }
109
                    }
110
                    bug!("Multiple equally applicable rules for {expr}: {names:#?}\n\n{rules_by_priority_string}");
111
11747
                }
112
            }
113
        }
114
    }
115

            
116
1428
    Ok(model)
117
1428
}