1
use std::cell::RefCell;
2
use std::collections::BTreeSet;
3
use std::fmt::{self, Display, Formatter};
4
use std::hash::Hash;
5
use std::rc::Rc;
6

            
7
use thiserror::Error;
8

            
9
use crate::Model;
10
use crate::ast::{CnfClause, DeclarationPtr, Expression, Name, SymbolTable};
11
use crate::rule_engine::RuleData;
12
use crate::rule_engine::rewriter_common::{RuleResult, log_rule_application};
13
use tree_morph::prelude::Commands;
14
use tree_morph::prelude::Rule as MorphRule;
15

            
16
#[derive(Clone, Debug, Default)]
17
pub(crate) struct MorphState {
18
    pub symbols: SymbolTable,
19
    pub clauses: Vec<CnfClause>,
20
}
21

            
22
#[derive(Debug, Error)]
23
pub enum ApplicationError {
24
    #[error("Rule is not applicable")]
25
    RuleNotApplicable,
26

            
27
    #[error("Could not calculate the expression domain")]
28
    DomainError,
29
}
30

            
31
/// Represents the result of applying a rule to an expression within a model.
32
///
33
/// A `Reduction` encapsulates the changes made to a model during a rule application.
34
/// It includes a new expression to replace the original one, an optional top-level constraint
35
/// to be added to the model, and any updates to the model's symbol table.
36
///
37
/// This struct allows for representing side-effects of rule applications, ensuring that
38
/// all modifications, including symbol table expansions and additional constraints, are
39
/// accounted for and can be applied to the model consistently.
40
///
41
/// # Fields
42
/// - `new_expression`: The updated [`Expression`] that replaces the original one after applying the rule.
43
/// - `new_top`: An additional top-level [`Vec<Expression>`] constraint that should be added to the model. If no top-level
44
///   constraint is needed, this field can be set to an empty vector [`Vec::new()`].
45
/// - `symbols`: A [`SymbolTable`] containing any new symbol definitions or modifications to be added to the model's
46
///   symbol table. If no symbols are modified, this field can be set to an empty symbol table.
47
///
48
/// # Usage
49
/// A `Reduction` can be created using one of the provided constructors:
50
/// - [`Reduction::new`]: Creates a reduction with a new expression, top-level constraint, and symbol modifications.
51
/// - [`Reduction::pure`]: Creates a reduction with only a new expression and no side-effects on the symbol table or constraints.
52
/// - [`Reduction::with_symbols`]: Creates a reduction with a new expression and symbol table modifications, but no top-level constraint.
53
/// - [`Reduction::with_top`]: Creates a reduction with a new expression and a top-level constraint, but no symbol table modifications.
54
/// - [`Reduction::cnf`]: Creates a reduction with a new expression, cnf clauses and symbol modifications, but no no top-level constraints.
55
///
56
/// The `apply` method allows for applying the changes represented by the `Reduction` to a [`Model`].
57
///
58
/// # Example
59
/// ```
60
/// // Need to add an example
61
/// ```
62
///
63
/// # See Also
64
/// - [`ApplicationResult`]: Represents the result of applying a rule, which may either be a `Reduction` or an `ApplicationError`.
65
/// - [`Model`]: The structure to which the `Reduction` changes are applied.
66
#[non_exhaustive]
67
#[derive(Clone, Debug)]
68
pub struct Reduction {
69
    pub new_expression: Expression,
70
    pub new_top: Vec<Expression>,
71
    pub symbols: SymbolTable,
72
    pub new_clauses: Vec<CnfClause>,
73
}
74

            
75
/// The result of applying a rule to an expression.
76
/// Contains either a set of reduction instructions or an error.
77
pub type ApplicationResult = Result<Reduction, ApplicationError>;
78

            
79
impl Reduction {
80
57914
    pub fn new(new_expression: Expression, new_top: Vec<Expression>, symbols: SymbolTable) -> Self {
81
57914
        Self {
82
57914
            new_expression,
83
57914
            new_top,
84
57914
            symbols,
85
57914
            new_clauses: Vec::new(),
86
57914
        }
87
57914
    }
88

            
89
    /// Represents a reduction with no side effects on the model.
90
629316
    pub fn pure(new_expression: Expression) -> Self {
91
629316
        Self {
92
629316
            new_expression,
93
629316
            new_top: Vec::new(),
94
629316
            symbols: SymbolTable::new(),
95
629316
            new_clauses: Vec::new(),
96
629316
        }
97
629316
    }
98

            
99
    /// Represents a reduction that also modifies the symbol table.
100
16498
    pub fn with_symbols(new_expression: Expression, symbols: SymbolTable) -> Self {
101
16498
        Self {
102
16498
            new_expression,
103
16498
            new_top: Vec::new(),
104
16498
            symbols,
105
16498
            new_clauses: Vec::new(),
106
16498
        }
107
16498
    }
108

            
109
    /// Represents a reduction that also adds a top-level constraint to the model.
110
    pub fn with_top(new_expression: Expression, new_top: Vec<Expression>) -> Self {
111
        Self {
112
            new_expression,
113
            new_top,
114
            symbols: SymbolTable::new(),
115
            new_clauses: Vec::new(),
116
        }
117
    }
118

            
119
    /// Represents a reduction that also adds clauses to the model.
120
181560
    pub fn cnf(
121
181560
        new_expression: Expression,
122
181560
        new_clauses: Vec<CnfClause>,
123
181560
        symbols: SymbolTable,
124
181560
    ) -> Self {
125
181560
        Self {
126
181560
            new_expression,
127
181560
            new_top: Vec::new(),
128
181560
            symbols,
129
181560
            new_clauses,
130
181560
        }
131
181560
    }
132

            
133
    /// Applies side-effects (e.g. symbol table updates)
134
315194
    pub fn apply(self, model: &mut Model) {
135
315194
        model.symbols_mut().extend(self.symbols); // Add new assignments to the symbol table
136
315194
        model.add_constraints(self.new_top.clone());
137
315194
        model.add_clauses(self.new_clauses);
138
315194
    }
139

            
140
    /// Gets symbols added by this reduction
141
221334
    pub fn added_symbols(&self, initial_symbols: &SymbolTable) -> BTreeSet<Name> {
142
221334
        let initial_symbols_set: BTreeSet<Name> = initial_symbols
143
221334
            .clone()
144
221334
            .into_iter_local()
145
221334
            .map(|x| x.0)
146
221334
            .collect();
147
221334
        let new_symbols_set: BTreeSet<Name> = self
148
221334
            .symbols
149
221334
            .clone()
150
221334
            .into_iter_local()
151
221334
            .map(|x| x.0)
152
221334
            .collect();
153

            
154
221334
        new_symbols_set
155
221334
            .difference(&initial_symbols_set)
156
221334
            .cloned()
157
221334
            .collect()
158
221334
    }
159

            
160
    /// Gets symbols changed by this reduction
161
    ///
162
    /// Returns a list of tuples of (name, domain before reduction, domain after reduction)
163
    pub fn changed_symbols(
164
        &self,
165
        initial_symbols: &SymbolTable,
166
    ) -> Vec<(Name, DeclarationPtr, DeclarationPtr)> {
167
        let mut changes: Vec<(Name, DeclarationPtr, DeclarationPtr)> = vec![];
168

            
169
        for (var_name, initial_value) in initial_symbols.clone().into_iter_local() {
170
            let Some(new_value) = self.symbols.lookup(&var_name) else {
171
                continue;
172
            };
173

            
174
            if new_value != initial_value {
175
                changes.push((var_name.clone(), initial_value.clone(), new_value.clone()));
176
            }
177
        }
178
        changes
179
    }
180
}
181

            
182
/// The function type used in a [`Rule`].
183
pub type RuleFn = fn(&Expression, &SymbolTable) -> ApplicationResult;
184

            
185
/**
186
 * A rule with a name, application function, and rule sets.
187
 *
188
 * # Fields
189
 * - `name` The name of the rule.
190
 * - `application` The function to apply the rule.
191
 * - `rule_sets` A list of rule set names and priorities that this rule is a part of. This is used to populate rulesets at runtime.
192
 */
193
#[derive(Clone, Debug)]
194
pub struct Rule<'a> {
195
    pub name: &'a str,
196
    pub application: RuleFn,
197
    pub rule_sets: &'a [(&'a str, u16)], // (name, priority). At runtime, we add the rule to rulesets
198
    /// Discriminant ids of Expression variants this rule applies to, or None for universal rules.
199
    pub applicable_to: Option<&'static [usize]>,
200
}
201

            
202
impl<'a> Rule<'a> {
203
    pub const fn new(
204
        name: &'a str,
205
        application: RuleFn,
206
        rule_sets: &'a [(&'static str, u16)],
207
    ) -> Self {
208
        Self {
209
            name,
210
            application,
211
            rule_sets,
212
            applicable_to: None,
213
        }
214
    }
215

            
216
403238378
    pub fn apply(&self, expr: &Expression, symbols: &SymbolTable) -> ApplicationResult {
217
403238378
        (self.application)(expr, symbols)
218
403238378
    }
219
}
220

            
221
impl Display for Rule<'_> {
222
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
223
        write!(f, "{}", self.name)
224
    }
225
}
226

            
227
impl PartialEq for Rule<'_> {
228
924
    fn eq(&self, other: &Self) -> bool {
229
924
        self.name == other.name
230
924
    }
231
}
232

            
233
impl Eq for Rule<'_> {}
234

            
235
impl Hash for Rule<'_> {
236
27950
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
237
27950
        self.name.hash(state);
238
27950
    }
239
}
240

            
241
impl MorphRule<Expression, MorphState> for Rule<'_> {
242
    fn apply(
243
        &self,
244
        commands: &mut Commands<Expression, MorphState>,
245
        subtree: &Expression,
246
        meta: &MorphState,
247
    ) -> Option<Expression> {
248
        let reduction = self.apply(subtree, &meta.symbols).ok()?;
249
        let new_expression = reduction.new_expression;
250
        let new_top = reduction.new_top;
251
        let added_symbols = reduction.symbols;
252
        let added_clauses = reduction.new_clauses;
253
        commands.mut_meta(Box::new(move |m: &mut MorphState| {
254
            m.symbols.extend(added_symbols);
255
            m.clauses.extend(added_clauses);
256
        }));
257
        if !new_top.is_empty() {
258
            commands.transform(Box::new(move |m| m.extend_root(new_top)));
259
        }
260
        Some(new_expression)
261
    }
262

            
263
    fn name(&self) -> &str {
264
        self.name
265
    }
266

            
267
    fn applicable_to(&self) -> Option<Vec<usize>> {
268
        self.applicable_to.map(|s| s.to_vec())
269
    }
270
}
271

            
272
impl MorphRule<Expression, MorphState> for &Rule<'_> {
273
    fn apply(
274
        &self,
275
        commands: &mut Commands<Expression, MorphState>,
276
        subtree: &Expression,
277
        meta: &MorphState,
278
    ) -> Option<Expression> {
279
        let reduction = Rule::apply(self, subtree, &meta.symbols).ok()?;
280
        let new_expression = reduction.new_expression;
281
        let new_top = reduction.new_top;
282
        let added_symbols = reduction.symbols;
283
        let added_clauses = reduction.new_clauses;
284
        commands.mut_meta(Box::new(move |m: &mut MorphState| {
285
            m.symbols.extend(added_symbols);
286
            m.clauses.extend(added_clauses);
287
        }));
288
        if !new_top.is_empty() {
289
            commands.transform(Box::new(move |m| m.extend_root(new_top)));
290
        }
291
        Some(new_expression)
292
    }
293

            
294
    fn name(&self) -> &str {
295
        self.name
296
    }
297
}
298

            
299
impl MorphRule<Expression, Rc<RefCell<MorphState>>> for Rule<'_> {
300
    fn apply(
301
        &self,
302
        commands: &mut Commands<Expression, Rc<RefCell<MorphState>>>,
303
        subtree: &Expression,
304
        meta: &Rc<RefCell<MorphState>>,
305
    ) -> Option<Expression> {
306
        let state = meta.borrow();
307
        let reduction = self.apply(subtree, &state.symbols).ok()?;
308
        let new_expression = reduction.new_expression;
309
        let new_top = reduction.new_top;
310
        let added_symbols = reduction.symbols;
311
        let added_clauses = reduction.new_clauses;
312
        commands.mut_meta(Box::new(move |m| {
313
            let mut state = m.borrow_mut();
314
            state.symbols.extend(added_symbols);
315
            state.clauses.extend(added_clauses);
316
        }));
317

            
318
        if !new_top.is_empty() {
319
            commands.transform(Box::new(move |m| m.extend_root(new_top)));
320
        }
321

            
322
        Some(new_expression)
323
    }
324

            
325
    fn name(&self) -> &str {
326
        self.name
327
    }
328
}
329

            
330
impl MorphRule<Expression, MorphState> for RuleData<'_> {
331
403229658
    fn apply(
332
403229658
        &self,
333
403229658
        commands: &mut Commands<Expression, MorphState>,
334
403229658
        subtree: &Expression,
335
403229658
        meta: &MorphState,
336
403229658
    ) -> Option<Expression> {
337
403229658
        let reduction = self.rule.apply(subtree, &meta.symbols).ok()?;
338
290932
        let result = RuleResult {
339
290932
            rule_data: self.clone(),
340
290932
            reduction: reduction.clone(),
341
290932
        };
342

            
343
290932
        log_rule_application(&result, subtree, &meta.symbols, None);
344

            
345
290932
        let new_expression = reduction.new_expression;
346
290932
        let new_top = reduction.new_top;
347
290932
        let added_symbols = reduction.symbols;
348
290932
        let added_clauses = reduction.new_clauses;
349
290932
        commands.mut_meta(Box::new(move |m: &mut MorphState| {
350
290852
            m.symbols.extend(added_symbols);
351
290852
            m.clauses.extend(added_clauses);
352
290852
        }));
353

            
354
290932
        if !new_top.is_empty() {
355
26420
            commands.transform(Box::new(move |m| m.extend_root(new_top)));
356
264512
        }
357
290932
        Some(new_expression)
358
403229658
    }
359

            
360
98
    fn name(&self) -> &str {
361
98
        self.rule.name
362
98
    }
363

            
364
    fn applicable_to(&self) -> Option<Vec<usize>> {
365
        self.rule.applicable_to.map(|s| s.to_vec())
366
    }
367
}