1
use std::collections::{HashMap, HashSet};
2
use std::fmt::{Display, Formatter};
3
use std::hash::Hash;
4
use std::sync::OnceLock;
5

            
6
use log::warn;
7

            
8
use crate::rule_engine::{get_rule_set_by_name, get_rules, Rule};
9
use crate::solver::SolverFamily;
10

            
11
/// A structure representing a set of rules with a name, priority, and dependencies.
12
///
13
/// `RuleSet` is a way to group related rules together under a single name.
14
/// You can think of it like a list of rules that belong to the same category.
15
/// Each `RuleSet` can also have a number that tells it what order it should run in compared to other `RuleSet` instances.
16
/// Additionally, a `RuleSet` can depend on other `RuleSet` instances, meaning it needs them to run first.
17
///
18
/// To make things efficient, `RuleSet` only figures out its rules and dependencies the first time they're needed,
19
/// and then it remembers them so it doesn't have to do the work again.
20
///
21
/// # Fields
22
/// - `name`: The name of the rule set.
23
/// - `order`: A number that decides the order in which this `RuleSet` should be applied.
24
/// If two `RuleSet` instances have the same rule but with different priorities,
25
/// the one with the higher `order` number will be the one that is used.
26
/// - `rules`: A lazily initialized map of rules to their priorities.
27
/// - `dependency_rs_names`: The names of the rule sets that this rule set depends on.
28
/// - `dependencies`: A lazily initialized set of `RuleSet` dependencies.
29
/// - `solver_families`: The solver families that this rule set applies to.
30
#[derive(Clone, Debug)]
31
pub struct RuleSet<'a> {
32
    /// The name of the rule set.
33
    pub name: &'a str,
34
    /// Order of the RuleSet. Used to establish a consistent order of operations when resolving rules.
35
    /// If two RuleSets overlap (contain the same rule but with different priorities), the RuleSet with the higher order will be used as the source of truth.
36
    pub order: u16,
37
    /// A map of rules to their priorities. This will be lazily initialized at runtime.
38
    rules: OnceLock<HashMap<&'a Rule<'a>, u16>>,
39
    /// The names of the rule sets that this rule set depends on.
40
    dependency_rs_names: &'a [&'a str],
41
    dependencies: OnceLock<HashSet<&'a RuleSet<'a>>>,
42
    /// The solver families that this rule set applies to.
43
    pub solver_families: &'a [SolverFamily],
44
}
45

            
46
impl<'a> RuleSet<'a> {
47
    pub const fn new(
48
        name: &'a str,
49
        order: u16,
50
        dependencies: &'a [&'a str],
51
        solver_families: &'a [SolverFamily],
52
    ) -> Self {
53
        Self {
54
            name,
55
            order,
56
            dependency_rs_names: dependencies,
57
            solver_families,
58
            rules: OnceLock::new(),
59
            dependencies: OnceLock::new(),
60
        }
61
    }
62

            
63
    /// Get the rules of this rule set, evaluating them lazily if necessary
64
    /// Returns a `&HashMap<&Rule, u16>` where the key is the rule and the value is the priority of the rule.
65
3162
    pub fn get_rules(&self) -> &HashMap<&'a Rule<'a>, u16> {
66
3162
        match self.rules.get() {
67
            None => {
68
238
                let rules = self.resolve_rules();
69
238
                let _ = self.rules.set(rules); // Try to set the rules, but ignore if it fails.
70
238

            
71
238
                // At this point, the rules cell is guaranteed to be set, so we can unwrap safely.
72
238
                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
73
238
                #[allow(clippy::unwrap_used)]
74
238
                self.rules.get().unwrap()
75
            }
76
2924
            Some(rules) => rules,
77
        }
78
3162
    }
79

            
80
    /// Get the dependencies of this rule set, evaluating them lazily if necessary
81
    /// Returns a `&HashSet<&RuleSet>` of the rule sets that this rule set depends on.
82
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
83
2788
    pub fn get_dependencies(&self) -> &HashSet<&'static RuleSet> {
84
2788
        match self.dependencies.get() {
85
            None => {
86
170
                let dependencies = self.resolve_dependencies();
87
170
                let _ = self.dependencies.set(dependencies); // Try to set the dependencies, but ignore if it fails.
88
170

            
89
170
                // At this point, the dependencies cell is guaranteed to be set, so we can unwrap safely.
90
170
                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
91
170
                #[allow(clippy::unwrap_used)]
92
170
                self.dependencies.get().unwrap()
93
            }
94
2618
            Some(dependencies) => dependencies,
95
        }
96
2788
    }
97

            
98
    /// Get the dependencies of this rule set, including itself
99
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
100
731
    pub fn with_dependencies(&self) -> HashSet<&'static RuleSet> {
101
731
        let mut deps = self.get_dependencies().clone();
102
731
        deps.insert(self);
103
731
        deps
104
731
    }
105

            
106
    /// Resolve the rules of this rule set ("reverse the arrows")
107
238
    fn resolve_rules(&self) -> HashMap<&'a Rule<'a>, u16> {
108
238
        let mut rules = HashMap::new();
109

            
110
9282
        for rule in get_rules() {
111
9044
            let mut found = false;
112
9044
            let mut priority: u16 = 0;
113

            
114
15096
            for (name, p) in rule.rule_sets {
115
9044
                if *name == self.name {
116
2992
                    found = true;
117
2992
                    priority = *p;
118
2992
                    break;
119
6052
                }
120
            }
121

            
122
9044
            if found {
123
2992
                rules.insert(rule, priority);
124
6052
            }
125
        }
126

            
127
238
        rules
128
238
    }
129

            
130
    /// Recursively resolve the dependencies of this rule set.
131
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
132
272
    fn resolve_dependencies(&self) -> HashSet<&'static RuleSet> {
133
272
        let mut dependencies = HashSet::new();
134

            
135
374
        for dep in self.dependency_rs_names {
136
102
            match get_rule_set_by_name(dep) {
137
                None => {
138
                    warn!(
139
                        "Rule set {} depends on non-existent rule set {}",
140
                        &self.name, dep
141
                    );
142
                }
143
102
                Some(rule_set) => {
144
102
                    if !dependencies.contains(rule_set) {
145
102
                        // Prevent cycles
146
102
                        dependencies.insert(rule_set);
147
102
                        dependencies.extend(rule_set.resolve_dependencies());
148
102
                    }
149
                }
150
            }
151
        }
152

            
153
272
        dependencies
154
272
    }
155
}
156

            
157
impl PartialEq for RuleSet<'_> {
158
1377
    fn eq(&self, other: &Self) -> bool {
159
1377
        self.name == other.name
160
1377
    }
161
}
162

            
163
impl Eq for RuleSet<'_> {}
164

            
165
impl Hash for RuleSet<'_> {
166
10387
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
167
10387
        self.name.hash(state);
168
10387
    }
169
}
170

            
171
impl Display for RuleSet<'_> {
172
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
173
        let n_rules = self.get_rules().len();
174
        let solver_families = self
175
            .solver_families
176
            .iter()
177
            .map(|f| f.to_string())
178
            .collect::<Vec<String>>();
179

            
180
        write!(
181
            f,
182
            "RuleSet {{\n\
183
            \tname: {}\n\
184
            \torder: {}\n\
185
            \trules: {}\n\
186
            \tsolver_families: {:?}\n\
187
        }}",
188
            self.name, self.order, n_rules, solver_families
189
        )
190
    }
191
}