1
use std::collections::BTreeSet;
2
use std::fmt::{self, Display, Formatter};
3
use std::hash::Hash;
4
use std::rc::Rc;
5

            
6
use thiserror::Error;
7

            
8
use crate::ast::Declaration;
9
use crate::ast::{Expression, Name, SubModel, SymbolTable};
10

            
11
#[derive(Debug, Error)]
12
pub enum ApplicationError {
13
    #[error("Rule is not applicable")]
14
    RuleNotApplicable,
15

            
16
    #[error("Could not calculate the expression domain")]
17
    DomainError,
18
}
19

            
20
/// Represents the result of applying a rule to an expression within a model.
21
///
22
/// A `Reduction` encapsulates the changes made to a model during a rule application.
23
/// It includes a new expression to replace the original one, an optional top-level constraint
24
/// to be added to the model, and any updates to the model's symbol table.
25
///
26
/// This struct allows for representing side-effects of rule applications, ensuring that
27
/// all modifications, including symbol table expansions and additional constraints, are
28
/// accounted for and can be applied to the model consistently.
29
///
30
/// # Fields
31
/// - `new_expression`: The updated [`Expression`] that replaces the original one after applying the rule.
32
/// - `new_top`: An additional top-level [`Vec<Expression>`] constraint that should be added to the model. If no top-level
33
///   constraint is needed, this field can be set to an empty vector [`Vec::new()`].
34
/// - `symbols`: A [`SymbolTable`] containing any new symbol definitions or modifications to be added to the model's
35
///   symbol table. If no symbols are modified, this field can be set to an empty symbol table.
36
///
37
/// # Usage
38
/// A `Reduction` can be created using one of the provided constructors:
39
/// - [`Reduction::new`]: Creates a reduction with a new expression, top-level constraint, and symbol modifications.
40
/// - [`Reduction::pure`]: Creates a reduction with only a new expression and no side-effects on the symbol table or constraints.
41
/// - [`Reduction::with_symbols`]: Creates a reduction with a new expression and symbol table modifications, but no top-level constraint.
42
/// - [`Reduction::with_top`]: Creates a reduction with a new expression and a top-level constraint, but no symbol table modifications.
43
///
44
/// The `apply` method allows for applying the changes represented by the `Reduction` to a [`Model`].
45
///
46
/// # Example
47
/// ```
48
1
/// // Need to add an example
49
1
/// ```
50
1
///
51
/// # See Also
52
/// - [`ApplicationResult`]: Represents the result of applying a rule, which may either be a `Reduction` or an `ApplicationError`.
53
/// - [`Model`]: The structure to which the `Reduction` changes are applied.
54
#[non_exhaustive]
55
#[derive(Clone, Debug)]
56
pub struct Reduction {
57
    pub new_expression: Expression,
58
    pub new_top: Vec<Expression>,
59
    pub symbols: SymbolTable,
60
}
61

            
62
/// The result of applying a rule to an expression.
63
/// Contains either a set of reduction instructions or an error.
64
pub type ApplicationResult = Result<Reduction, ApplicationError>;
65

            
66
impl Reduction {
67
1734
    pub fn new(new_expression: Expression, new_top: Vec<Expression>, symbols: SymbolTable) -> Self {
68
1734
        Self {
69
1734
            new_expression,
70
1734
            new_top,
71
1734
            symbols,
72
1734
        }
73
1734
    }
74

            
75
    /// Represents a reduction with no side effects on the model.
76
10795
    pub fn pure(new_expression: Expression) -> Self {
77
10795
        Self {
78
10795
            new_expression,
79
10795
            new_top: Vec::new(),
80
10795
            symbols: SymbolTable::new(),
81
10795
        }
82
10795
    }
83

            
84
    /// Represents a reduction that also modifies the symbol table.
85
17
    pub fn with_symbols(new_expression: Expression, symbols: SymbolTable) -> Self {
86
17
        Self {
87
17
            new_expression,
88
17
            new_top: Vec::new(),
89
17
            symbols,
90
17
        }
91
17
    }
92

            
93
    /// Represents a reduction that also adds a top-level constraint to the model.
94
    pub fn with_top(new_expression: Expression, new_top: Vec<Expression>) -> Self {
95
        Self {
96
            new_expression,
97
            new_top,
98
            symbols: SymbolTable::new(),
99
        }
100
    }
101

            
102
    /// Applies side-effects (e.g. symbol table updates)
103
12308
    pub fn apply(self, model: &mut SubModel) {
104
12308
        model.symbols_mut().extend(self.symbols); // Add new assignments to the symbol table
105
12308
        model.add_constraints(self.new_top.clone());
106
12308
    }
107

            
108
    /// Gets symbols added by this reduction
109
12308
    pub fn added_symbols(&self, initial_symbols: &SymbolTable) -> BTreeSet<Name> {
110
12308
        let initial_symbols_set: BTreeSet<Name> = initial_symbols
111
12308
            .clone()
112
12308
            .into_iter_local()
113
88876
            .map(|x| x.0)
114
12308
            .collect();
115
12308
        let new_symbols_set: BTreeSet<Name> = self
116
12308
            .symbols
117
12308
            .clone()
118
12308
            .into_iter_local()
119
12308
            .map(|x| x.0)
120
12308
            .collect();
121
12308

            
122
12308
        new_symbols_set
123
12308
            .difference(&initial_symbols_set)
124
12308
            .cloned()
125
12308
            .collect()
126
12308
    }
127

            
128
    /// Gets symbols changed by this reduction
129
    ///
130
    /// Returns a list of tuples of (name, domain before reduction, domain after reduction)
131
    pub fn changed_symbols(
132
        &self,
133
        initial_symbols: &SymbolTable,
134
    ) -> Vec<(Name, Rc<Declaration>, Rc<Declaration>)> {
135
        let mut changes: Vec<(Name, Rc<Declaration>, Rc<Declaration>)> = vec![];
136

            
137
        for (var_name, initial_value) in initial_symbols.clone().into_iter_local() {
138
            let Some(new_value) = self.symbols.lookup(&var_name) else {
139
                continue;
140
            };
141

            
142
            if new_value != initial_value {
143
                changes.push((var_name.clone(), initial_value.clone(), new_value.clone()));
144
            }
145
        }
146
        changes
147
    }
148
}
149

            
150
/// The function type used in a [`Rule`].
151
pub type RuleFn = fn(&Expression, &SymbolTable) -> ApplicationResult;
152

            
153
/**
154
 * A rule with a name, application function, and rule sets.
155
 *
156
 * # Fields
157
 * - `name` The name of the rule.
158
 * - `application` The function to apply the rule.
159
 * - `rule_sets` A list of rule set names and priorities that this rule is a part of. This is used to populate rulesets at runtime.
160
 */
161
#[derive(Clone, Debug)]
162
pub struct Rule<'a> {
163
    pub name: &'a str,
164
    pub application: RuleFn,
165
    pub rule_sets: &'a [(&'a str, u16)], // (name, priority). At runtime, we add the rule to rulesets
166
}
167

            
168
impl<'a> Rule<'a> {
169
    pub const fn new(
170
        name: &'a str,
171
        application: RuleFn,
172
        rule_sets: &'a [(&'static str, u16)],
173
    ) -> Self {
174
        Self {
175
            name,
176
            application,
177
            rule_sets,
178
        }
179
    }
180

            
181
6562
    pub fn apply(&self, expr: &Expression, symbols: &SymbolTable) -> ApplicationResult {
182
6562
        (self.application)(expr, symbols)
183
6562
    }
184
}
185

            
186
impl Display for Rule<'_> {
187
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
188
        write!(f, "{}", self.name)
189
    }
190
}
191

            
192
impl PartialEq for Rule<'_> {
193
153
    fn eq(&self, other: &Self) -> bool {
194
153
        self.name == other.name
195
153
    }
196
}
197

            
198
impl Eq for Rule<'_> {}
199

            
200
impl Hash for Rule<'_> {
201
5899
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
202
5899
        self.name.hash(state);
203
5899
    }
204
}