1
use std::collections::{HashMap, HashSet};
2
use std::fmt::{Display, Formatter};
3
use std::hash::Hash;
4
use std::sync::OnceLock;
5

            
6
use log::warn;
7

            
8
use crate::rule_engine::{get_all_rules, get_rule_set_by_name, Rule};
9
use crate::solver::SolverFamily;
10

            
11
/// A structure representing a set of rules with a name, priority, and dependencies.
12
///
13
/// `RuleSet` is a way to group related rules together under a single name.
14
/// You can think of it like a list of rules that belong to the same category.
15
/// Each `RuleSet` can also have a number that tells it what order it should run in compared to other `RuleSet` instances.
16
/// Additionally, a `RuleSet` can depend on other `RuleSet` instances, meaning it needs them to run first.
17
///
18
/// To make things efficient, `RuleSet` only figures out its rules and dependencies the first time they're needed,
19
/// and then it remembers them so it doesn't have to do the work again.
20
///
21
/// # Fields
22
/// - `name`: The name of the rule set.
23
/// - `order`: A number that decides the order in which this `RuleSet` should be applied.
24
///     If two `RuleSet` instances have the same rule but with different priorities,
25
///     the one with the higher `order` number will be the one that is used.
26
/// - `rules`: A lazily initialized map of rules to their priorities.
27
/// - `dependency_rs_names`: The names of the rule sets that this rule set depends on.
28
/// - `dependencies`: A lazily initialized set of `RuleSet` dependencies.
29
/// - `solver_families`: The solver families that this rule set applies to.
30
#[derive(Clone, Debug)]
31
pub struct RuleSet<'a> {
32
    /// The name of the rule set.
33
    pub name: &'a str,
34
    /// A map of rules to their priorities. This will be lazily initialized at runtime.
35
    rules: OnceLock<HashMap<&'a Rule<'a>, u16>>,
36
    /// The names of the rule sets that this rule set depends on.
37
    dependency_rs_names: &'a [&'a str],
38
    dependencies: OnceLock<HashSet<&'a RuleSet<'a>>>,
39
    /// The solver families that this rule set applies to.
40
    pub solver_families: &'a [SolverFamily],
41
}
42

            
43
impl<'a> RuleSet<'a> {
44
    pub const fn new(
45
        name: &'a str,
46
        dependencies: &'a [&'a str],
47
        solver_families: &'a [SolverFamily],
48
    ) -> Self {
49
        Self {
50
            name,
51
            dependency_rs_names: dependencies,
52
            solver_families,
53
            rules: OnceLock::new(),
54
            dependencies: OnceLock::new(),
55
        }
56
    }
57

            
58
    /// Get the rules of this rule set, evaluating them lazily if necessary
59
    /// Returns a `&HashMap<&Rule, u16>` where the key is the rule and the value is the priority of the rule.
60
6154
    pub fn get_rules(&self) -> &HashMap<&'a Rule<'a>, u16> {
61
6154
        match self.rules.get() {
62
            None => {
63
170
                let rules = self.resolve_rules();
64
170
                let _ = self.rules.set(rules); // Try to set the rules, but ignore if it fails.
65
170

            
66
170
                // At this point, the rules cell is guaranteed to be set, so we can unwrap safely.
67
170
                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
68
170
                #[allow(clippy::unwrap_used)]
69
170
                self.rules.get().unwrap()
70
            }
71
5984
            Some(rules) => rules,
72
        }
73
6154
    }
74

            
75
    /// Get the dependencies of this rule set, evaluating them lazily if necessary
76
    /// Returns a `&HashSet<&RuleSet>` of the rule sets that this rule set depends on.
77
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
78
6154
    pub fn get_dependencies(&self) -> &HashSet<&'static RuleSet> {
79
6154
        match self.dependencies.get() {
80
            None => {
81
136
                let dependencies = self.resolve_dependencies();
82
136
                let _ = self.dependencies.set(dependencies); // Try to set the dependencies, but ignore if it fails.
83
136

            
84
136
                // At this point, the dependencies cell is guaranteed to be set, so we can unwrap safely.
85
136
                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
86
136
                #[allow(clippy::unwrap_used)]
87
136
                self.dependencies.get().unwrap()
88
            }
89
6018
            Some(dependencies) => dependencies,
90
        }
91
6154
    }
92

            
93
    /// Get the dependencies of this rule set, including itself
94
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
95
1564
    pub fn with_dependencies(&self) -> HashSet<&'static RuleSet> {
96
1564
        let mut deps = self.get_dependencies().clone();
97
1564
        deps.insert(self);
98
1564
        deps
99
1564
    }
100

            
101
    /// Resolve the rules of this rule set ("reverse the arrows")
102
170
    fn resolve_rules(&self) -> HashMap<&'a Rule<'a>, u16> {
103
170
        let mut rules = HashMap::new();
104

            
105
9520
        for rule in get_all_rules() {
106
9350
            let mut found = false;
107
9350
            let mut priority: u16 = 0;
108

            
109
16065
            for (name, p) in rule.rule_sets {
110
9350
                if *name == self.name {
111
2635
                    found = true;
112
2635
                    priority = *p;
113
2635
                    break;
114
6715
                }
115
            }
116

            
117
9350
            if found {
118
2635
                rules.insert(rule, priority);
119
6715
            }
120
        }
121

            
122
170
        rules
123
170
    }
124

            
125
    /// Recursively resolve the dependencies of this rule set.
126
    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
127
204
    fn resolve_dependencies(&self) -> HashSet<&'static RuleSet> {
128
204
        let mut dependencies = HashSet::new();
129

            
130
272
        for dep in self.dependency_rs_names {
131
68
            match get_rule_set_by_name(dep) {
132
                None => {
133
                    warn!(
134
                        "Rule set {} depends on non-existent rule set {}",
135
                        &self.name, dep
136
                    );
137
                }
138
68
                Some(rule_set) => {
139
68
                    if !dependencies.contains(rule_set) {
140
68
                        // Prevent cycles
141
68
                        dependencies.insert(rule_set);
142
68
                        dependencies.extend(rule_set.resolve_dependencies());
143
68
                    }
144
                }
145
            }
146
        }
147

            
148
204
        dependencies
149
204
    }
150
}
151

            
152
impl PartialEq for RuleSet<'_> {
153
3162
    fn eq(&self, other: &Self) -> bool {
154
3162
        self.name == other.name
155
3162
    }
156
}
157

            
158
/// Name equality is reflexive, symmetric and transitive, so `RuleSet` is a full equivalence.
impl<'a> Eq for RuleSet<'a> {}
159

            
160
impl Hash for RuleSet<'_> {
161
23018
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
162
23018
        self.name.hash(state);
163
23018
    }
164
}
165

            
166
impl Display for RuleSet<'_> {
167
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
168
        let n_rules = self.get_rules().len();
169
        let solver_families = self
170
            .solver_families
171
            .iter()
172
            .map(|f| f.to_string())
173
            .collect::<Vec<String>>();
174

            
175
        write!(
176
            f,
177
            "RuleSet {{\n\
178
            \tname: {}\n\
179
            \trules: {}\n\
180
            \tsolver_families: {:?}\n\
181
        }}",
182
            self.name, n_rules, solver_families
183
        )
184
    }
185
}