1
use std::collections::BTreeSet;
2
use std::fmt::{self, Display, Formatter};
3
use std::hash::Hash;
4

            
5
use thiserror::Error;
6

            
7
use crate::ast::{CnfClause, DeclarationPtr, Expression, Name, SubModel, SymbolTable};
8
use tree_morph::prelude::Commands;
9
use tree_morph::prelude::Rule as MorphRule;
10

            
11
#[derive(Debug, Error)]
12
pub enum ApplicationError {
13
    #[error("Rule is not applicable")]
14
    RuleNotApplicable,
15

            
16
    #[error("Could not calculate the expression domain")]
17
    DomainError,
18
}
19

            
20
/// Represents the result of applying a rule to an expression within a model.
21
///
22
/// A `Reduction` encapsulates the changes made to a model during a rule application.
23
/// It includes a new expression to replace the original one, an optional top-level constraint
24
/// to be added to the model, and any updates to the model's symbol table.
25
///
26
/// This struct allows for representing side-effects of rule applications, ensuring that
27
/// all modifications, including symbol table expansions and additional constraints, are
28
/// accounted for and can be applied to the model consistently.
29
///
30
/// # Fields
31
/// - `new_expression`: The updated [`Expression`] that replaces the original one after applying the rule.
32
/// - `new_top`: An additional top-level [`Vec<Expression>`] constraint that should be added to the model. If no top-level
33
///   constraint is needed, this field can be set to an empty vector [`Vec::new()`].
34
/// - `symbols`: A [`SymbolTable`] containing any new symbol definitions or modifications to be added to the model's
35
///   symbol table. If no symbols are modified, this field can be set to an empty symbol table.
36
///
37
/// # Usage
38
/// A `Reduction` can be created using one of the provided constructors:
39
/// - [`Reduction::new`]: Creates a reduction with a new expression, top-level constraint, and symbol modifications.
40
/// - [`Reduction::pure`]: Creates a reduction with only a new expression and no side-effects on the symbol table or constraints.
41
/// - [`Reduction::with_symbols`]: Creates a reduction with a new expression and symbol table modifications, but no top-level constraint.
42
/// - [`Reduction::with_top`]: Creates a reduction with a new expression and a top-level constraint, but no symbol table modifications.
43
/// - [`Reduction::cnf`]: Creates a reduction with a new expression, cnf clauses and symbol modifications, but no no top-level constraints.
44
///
45
/// The `apply` method allows for applying the changes represented by the `Reduction` to a [`Model`].
46
///
47
/// # Example
48
/// ```
49
/// // Need to add an example
50
/// ```
51
///
52
/// # See Also
53
/// - [`ApplicationResult`]: Represents the result of applying a rule, which may either be a `Reduction` or an `ApplicationError`.
54
/// - [`Model`]: The structure to which the `Reduction` changes are applied.
55
#[non_exhaustive]
56
#[derive(Clone, Debug)]
57
pub struct Reduction {
58
    pub new_expression: Expression,
59
    pub new_top: Vec<Expression>,
60
    pub symbols: SymbolTable,
61
    pub new_clauses: Vec<CnfClause>,
62
}
63

            
64
/// The result of applying a rule to an expression.
65
/// Contains either a set of reduction instructions or an error.
66
pub type ApplicationResult = Result<Reduction, ApplicationError>;
67

            
68
impl Reduction {
69
    pub fn new(new_expression: Expression, new_top: Vec<Expression>, symbols: SymbolTable) -> Self {
70
        Self {
71
            new_expression,
72
            new_top,
73
            symbols,
74
            new_clauses: Vec::new(),
75
        }
76
    }
77

            
78
    /// Represents a reduction with no side effects on the model.
79
    pub fn pure(new_expression: Expression) -> Self {
80
        Self {
81
            new_expression,
82
            new_top: Vec::new(),
83
            symbols: SymbolTable::new(),
84
            new_clauses: Vec::new(),
85
        }
86
    }
87

            
88
    /// Represents a reduction that also modifies the symbol table.
89
    pub fn with_symbols(new_expression: Expression, symbols: SymbolTable) -> Self {
90
        Self {
91
            new_expression,
92
            new_top: Vec::new(),
93
            symbols,
94
            new_clauses: Vec::new(),
95
        }
96
    }
97

            
98
    /// Represents a reduction that also adds a top-level constraint to the model.
99
    pub fn with_top(new_expression: Expression, new_top: Vec<Expression>) -> Self {
100
        Self {
101
            new_expression,
102
            new_top,
103
            symbols: SymbolTable::new(),
104
            new_clauses: Vec::new(),
105
        }
106
    }
107

            
108
    /// Represents a reduction that also adds clauses to the model.
109
    pub fn cnf(
110
        new_expression: Expression,
111
        new_clauses: Vec<CnfClause>,
112
        symbols: SymbolTable,
113
    ) -> Self {
114
        Self {
115
            new_expression,
116
            new_top: Vec::new(),
117
            symbols,
118
            new_clauses,
119
        }
120
    }
121

            
122
    /// Applies side-effects (e.g. symbol table updates)
123
    pub fn apply(self, model: &mut SubModel) {
124
        model.symbols_mut().extend(self.symbols); // Add new assignments to the symbol table
125
        model.add_constraints(self.new_top.clone());
126
        model.add_clauses(self.new_clauses);
127
    }
128

            
129
    /// Gets symbols added by this reduction
130
    pub fn added_symbols(&self, initial_symbols: &SymbolTable) -> BTreeSet<Name> {
131
        let initial_symbols_set: BTreeSet<Name> = initial_symbols
132
            .clone()
133
            .into_iter_local()
134
            .map(|x| x.0)
135
            .collect();
136
        let new_symbols_set: BTreeSet<Name> = self
137
            .symbols
138
            .clone()
139
            .into_iter_local()
140
            .map(|x| x.0)
141
            .collect();
142

            
143
        new_symbols_set
144
            .difference(&initial_symbols_set)
145
            .cloned()
146
            .collect()
147
    }
148

            
149
    /// Gets symbols changed by this reduction
150
    ///
151
    /// Returns a list of tuples of (name, domain before reduction, domain after reduction)
152
    pub fn changed_symbols(
153
        &self,
154
        initial_symbols: &SymbolTable,
155
    ) -> Vec<(Name, DeclarationPtr, DeclarationPtr)> {
156
        let mut changes: Vec<(Name, DeclarationPtr, DeclarationPtr)> = vec![];
157

            
158
        for (var_name, initial_value) in initial_symbols.clone().into_iter_local() {
159
            let Some(new_value) = self.symbols.lookup(&var_name) else {
160
                continue;
161
            };
162

            
163
            if new_value != initial_value {
164
                changes.push((var_name.clone(), initial_value.clone(), new_value.clone()));
165
            }
166
        }
167
        changes
168
    }
169
}
170

            
171
/// The function type used in a [`Rule`].
172
pub type RuleFn = fn(&Expression, &SymbolTable) -> ApplicationResult;
173

            
174
/**
175
 * A rule with a name, application function, and rule sets.
176
 *
177
 * # Fields
178
 * - `name` The name of the rule.
179
 * - `application` The function to apply the rule.
180
 * - `rule_sets` A list of rule set names and priorities that this rule is a part of. This is used to populate rulesets at runtime.
181
 */
182
#[derive(Clone, Debug)]
183
pub struct Rule<'a> {
184
    pub name: &'a str,
185
    pub application: RuleFn,
186
    pub rule_sets: &'a [(&'a str, u16)], // (name, priority). At runtime, we add the rule to rulesets
187
}
188

            
189
impl<'a> Rule<'a> {
190
    pub const fn new(
191
        name: &'a str,
192
        application: RuleFn,
193
        rule_sets: &'a [(&'static str, u16)],
194
    ) -> Self {
195
        Self {
196
            name,
197
            application,
198
            rule_sets,
199
        }
200
    }
201

            
202
    pub fn apply(&self, expr: &Expression, symbols: &SymbolTable) -> ApplicationResult {
203
        (self.application)(expr, symbols)
204
    }
205
}
206

            
207
impl Display for Rule<'_> {
208
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
209
        write!(f, "{}", self.name)
210
    }
211
}
212

            
213
impl PartialEq for Rule<'_> {
214
    fn eq(&self, other: &Self) -> bool {
215
        self.name == other.name
216
    }
217
}
218

            
219
/// Name-based equality is a full equivalence relation, so `Rule` can be `Eq`.
impl<'a> Eq for Rule<'a> {}
220

            
221
impl Hash for Rule<'_> {
222
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
223
        self.name.hash(state);
224
    }
225
}
226

            
227
impl MorphRule<Expression, SymbolTable> for Rule<'_> {
228
    fn apply(
229
        &self,
230
        commands: &mut Commands<Expression, SymbolTable>,
231
        subtree: &Expression,
232
        meta: &SymbolTable,
233
    ) -> Option<Expression> {
234
        let reduction = self.apply(subtree, meta).ok()?;
235
        commands.mut_meta(Box::new(|m: &mut SymbolTable| m.extend(reduction.symbols)));
236
        if !reduction.new_top.is_empty() {
237
            commands.transform(Box::new(|m| m.extend_root(reduction.new_top)));
238
        }
239
        Some(reduction.new_expression)
240
    }
241
}
242

            
243
impl MorphRule<Expression, SymbolTable> for &Rule<'_> {
244
    fn apply(
245
        &self,
246
        commands: &mut Commands<Expression, SymbolTable>,
247
        subtree: &Expression,
248
        meta: &SymbolTable,
249
    ) -> Option<Expression> {
250
        let reduction = Rule::apply(self, subtree, meta).ok()?;
251
        commands.mut_meta(Box::new(|m: &mut SymbolTable| m.extend(reduction.symbols)));
252
        if !reduction.new_top.is_empty() {
253
            commands.transform(Box::new(|m| m.extend_root(reduction.new_top)));
254
        }
255
        Some(reduction.new_expression)
256
    }
257
}