1
//! Common utilities and types for rewriters.
2
use super::{
3
    Reduction,
4
    resolve_rules::{ResolveRulesError, RuleData},
5
    submodel_zipper::expression_ctx,
6
};
7
use crate::ast::{
8
    DeclarationPtr, Expression, Model, Name, SymbolTable,
9
    pretty::{pretty_variable_declaration, pretty_vec},
10
};
11
use crate::settings::{
12
    default_rule_trace_enabled, rule_trace_aggregates_enabled, rule_trace_enabled,
13
};
14

            
15
use itertools::Itertools;
16
use serde_json::json;
17
use std::collections::BTreeMap;
18
use std::fmt::Debug;
19
use std::sync::Arc;
20
use thiserror::Error;
21
use tracing::{info, trace};
22

            
23
#[derive(Debug, Clone)]
24
pub struct RuleResult<'a> {
25
    pub rule_data: RuleData<'a>,
26
    pub reduction: Reduction,
27
}
28

            
29
pub type VariableDeclarationSnapshot = BTreeMap<Name, String>;
30

            
31
13447268
pub fn snapshot_variable_declarations(symbols: &SymbolTable) -> VariableDeclarationSnapshot {
32
13447268
    symbols
33
13447268
        .clone()
34
13447268
        .into_iter_local()
35
1124891016
        .filter_map(|(name, _)| {
36
1124891016
            pretty_variable_declaration(symbols, &name).map(|declaration| (name, declaration))
37
1124891016
        })
38
13447268
        .collect()
39
13447268
}
40

            
41
/// Logs, to the main log, and the human readable traces used by the integration tester, that the
42
/// rule has been applied to the expression
43
606126
pub fn log_rule_application(
44
606126
    result: &RuleResult,
45
606126
    initial_expression: &Expression,
46
606126
    initial_symbols: &SymbolTable,
47
606126
    variable_declaration_snapshots: Option<(
48
606126
        &VariableDeclarationSnapshot,
49
606126
        &VariableDeclarationSnapshot,
50
606126
    )>,
51
606126
) {
52
606126
    let red = &result.reduction;
53
606126
    let rule = result.rule_data.rule;
54

            
55
    // A reduction can only modify either constraints or clauses, not both. So the the same
56
    // variable is used to hold changes in both (or empty if neither are changed).
57
606126
    let new_top_string = if !red.new_top.is_empty() {
58
52756
        pretty_vec(&red.new_top)
59
    } else {
60
553370
        pretty_vec(&red.new_clauses)
61
    };
62

            
63
606126
    info!(
64
        %new_top_string,
65
        "Applying rule: {} ({:?}), to expression: {}, resulting in: {}",
66
        rule.name,
67
        rule.rule_sets,
68
        initial_expression,
69
        red.new_expression
70
    );
71

            
72
606126
    if rule_trace_enabled() && default_rule_trace_enabled() {
73
312932
        let new_constraints_str = if !red.new_top.is_empty() {
74
26182
            let mut exprs: Vec<String> = vec![];
75
31724
            for expr in &red.new_top {
76
31724
                exprs.push(format!("  {expr}"));
77
31724
            }
78
26182
            let exprs = exprs.iter().join("\n");
79
26182
            format!("new constraints:\n{exprs}\n")
80
286750
        } else if !red.new_clauses.is_empty() {
81
83580
            let mut exprs: Vec<String> = vec![];
82
4608080
            for clause in &red.new_clauses {
83
4608080
                exprs.push(format!("  {clause}"));
84
4608080
            }
85
83580
            let exprs = exprs.iter().join("\n");
86
83580
            format!("new clauses:\n{exprs}\n")
87
        } else {
88
203170
            String::new()
89
        };
90

            
91
312932
        let (new_variables_str, updated_variables_str) =
92
312932
            if let Some((before, after)) = variable_declaration_snapshots {
93
91598
                let mut new_variables = Vec::new();
94
91598
                let mut updated_variables = Vec::new();
95

            
96
10028188
                for (name, declaration_after) in after {
97
10028188
                    match before.get(name) {
98
32568
                        None => new_variables.push(format!("  {declaration_after}")),
99
9995620
                        Some(declaration_before) if declaration_before != declaration_after => {
100
320
                            updated_variables
101
320
                                .push(format!("  {declaration_before} ~~> {declaration_after}"));
102
320
                        }
103
9995300
                        _ => {}
104
                    }
105
                }
106

            
107
91598
                let new_variables_str = if new_variables.is_empty() {
108
86996
                    String::new()
109
                } else {
110
4602
                    format!("new variables:\n{}\n", new_variables.join("\n"))
111
                };
112

            
113
91598
                let updated_variables_str = if updated_variables.is_empty() {
114
91318
                    String::new()
115
                } else {
116
280
                    format!("\nupdated variables:\n{}\n", updated_variables.join("\n"))
117
                };
118

            
119
91598
                (new_variables_str, updated_variables_str)
120
            } else {
121
                // empty if no new variables
122
221334
                let mut vars: Vec<String> = vec![];
123
1863594
                for var_name in red.added_symbols(initial_symbols) {
124
1863404
                    #[allow(clippy::unwrap_used)]
125
1863404
                    vars.push(format!(
126
1863404
                        "  {}",
127
1863404
                        pretty_variable_declaration(&red.symbols, &var_name).unwrap()
128
1863404
                    ));
129
1863404
                }
130
221334
                let new_variables_str = if vars.is_empty() {
131
148532
                    String::new()
132
                } else {
133
72802
                    format!("new variables:\n{}\n", vars.join("\n"))
134
                };
135
221334
                (new_variables_str, String::new())
136
            };
137

            
138
312932
        trace!(
139
            target: "rule_engine_rule_trace",
140
            "{}, \n   ~~> {} ({:?})\n{}\n{}{}{}\n--\n",
141
            initial_expression,
142
            rule.name,
143
            rule.rule_sets,
144
            red.new_expression,
145
            new_variables_str,
146
            updated_variables_str,
147
            new_constraints_str
148
        );
149
293194
    }
150

            
151
606126
    if rule_trace_enabled() && rule_trace_aggregates_enabled() {
152
1622
        trace!(
153
            target: "rule_engine_rule_trace_aggregates",
154
            rule_name = rule.name,
155
            "Applied rule"
156
        );
157
604504
    }
158

            
159
606126
    trace!(
160
        target: "rule_engine",
161
        "{}",
162
2234
    json!({
163
2234
        "rule_name": result.rule_data.rule.name,
164
2234
        "rule_priority": result.rule_data.priority,
165
2234
        "rule_set": {
166
2234
            "name": result.rule_data.rule_set.name,
167
        },
168
2234
        "initial_expression": serde_json::to_value(initial_expression).unwrap(),
169
2234
        "transformed_expression": serde_json::to_value(&red.new_expression).unwrap()
170
    })
171

            
172
    )
173
606126
}
174

            
175
type LettingCtxFn = Arc<dyn Fn(Expression) -> Expression>;
176
type ApplicableLettingRule<'a> = (
177
    RuleResult<'a>,
178
    u16,
179
    Expression,
180
    DeclarationPtr,
181
    LettingCtxFn,
182
);
183

            
184
434749
pub(crate) fn try_rewrite_value_letting_once(
185
434749
    model: &mut Model,
186
434749
    rules_grouped: &Vec<(u16, Vec<RuleData<'_>>)>,
187
434749
    prop_multiple_equally_applicable: bool,
188
434749
) -> Option<()> {
189
434749
    let symbols = model.symbols().clone();
190
434749
    let mut results: Vec<ApplicableLettingRule<'_>> = vec![];
191

            
192
7947589
    'top: for (priority, rules) in rules_grouped.iter() {
193
620128972
        for (_, decl) in symbols.clone().into_iter_local() {
194
620128972
            let Some(letting_expr) = decl.as_value_letting().map(|expr| expr.clone()) else {
195
618513978
                continue;
196
            };
197

            
198
2554880
            for (expr, ctx) in expression_ctx(letting_expr) {
199
2554880
                let expr = expr.clone();
200
2554880
                let ctx = ctx.clone();
201

            
202
10717988
                for rd in rules {
203
10717988
                    let Ok(reduction) = (rd.rule.application)(&expr, &symbols) else {
204
10717150
                        continue;
205
                    };
206

            
207
838
                    results.push((
208
838
                        RuleResult {
209
838
                            rule_data: rd.clone(),
210
838
                            reduction,
211
838
                        },
212
838
                        *priority,
213
838
                        expr.clone(),
214
838
                        decl.clone(),
215
838
                        ctx.clone(),
216
838
                    ));
217
                }
218

            
219
2554880
                if !results.is_empty() {
220
838
                    break 'top;
221
2554042
                }
222
            }
223
        }
224
    }
225

            
226
434749
    let (result, _, expr, decl, ctx) = match results.as_slice() {
227
434749
        [] => return None,
228
838
        [single, ..] => single,
229
    };
230

            
231
838
    if prop_multiple_equally_applicable && results.len() > 1 {
232
        let names: Vec<_> = results
233
            .iter()
234
            .map(|(result, _, _, _, _)| result.rule_data.rule.name)
235
            .collect();
236
        panic!("Multiple equally applicable rules for value letting expression {expr}: {names:?}");
237
838
    }
238

            
239
838
    log_rule_application(result, expr, &symbols, None);
240

            
241
838
    let rewritten_expr = ctx(result.reduction.new_expression.clone());
242
838
    result.reduction.clone().apply(model);
243

            
244
838
    let mut decl = decl.clone();
245
838
    *decl
246
838
        .as_value_letting_mut()
247
838
        .expect("declaration should still be a value letting") = rewritten_expr;
248

            
249
838
    Some(())
250
434749
}
251

            
252
/// Represents errors that can occur during the model rewriting process.
253
#[derive(Debug, Error)]
254
pub enum RewriteError {
255
    #[error("Error resolving rules {0}")]
256
    ResolveRulesError(ResolveRulesError),
257
}
258

            
259
impl From<ResolveRulesError> for RewriteError {
260
    fn from(error: ResolveRulesError) -> Self {
261
        RewriteError::ResolveRulesError(error)
262
    }
263
}