conjure_core/rule_engine/rule_set.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::{Display, Formatter};
3use std::hash::Hash;
4use std::sync::OnceLock;
5
6use log::warn;
7
8use crate::rule_engine::{get_all_rules, get_rule_set_by_name, Rule};
9use crate::solver::SolverFamily;
10
11/// A structure representing a set of rules with a name, priority, and dependencies.
12///
13/// `RuleSet` is a way to group related rules together under a single name.
14/// You can think of it like a list of rules that belong to the same category.
15/// Each `RuleSet` can also have a number that tells it what order it should run in compared to other `RuleSet` instances.
16/// Additionally, a `RuleSet` can depend on other `RuleSet` instances, meaning it needs them to run first.
17///
18/// To make things efficient, `RuleSet` only figures out its rules and dependencies the first time they're needed,
19/// and then it remembers them so it doesn't have to do the work again.
20///
21/// # Fields
22/// - `name`: The name of the rule set.
23/// - `order`: A number that decides the order in which this `RuleSet` should be applied.
24///     If two `RuleSet` instances have the same rule but with different priorities,
25///     the one with the higher `order` number will be the one that is used.
26/// - `rules`: A lazily initialized map of rules to their priorities.
27/// - `dependency_rs_names`: The names of the rule sets that this rule set depends on.
28/// - `dependencies`: A lazily initialized set of `RuleSet` dependencies.
29/// - `solver_families`: The solver families that this rule set applies to.
30#[derive(Clone, Debug)]
31pub struct RuleSet<'a> {
32    /// The name of the rule set.
33    pub name: &'a str,
34    /// A map of rules to their priorities. This will be lazily initialized at runtime.
35    rules: OnceLock<HashMap<&'a Rule<'a>, u16>>,
36    /// The names of the rule sets that this rule set depends on.
37    dependency_rs_names: &'a [&'a str],
38    dependencies: OnceLock<HashSet<&'a RuleSet<'a>>>,
39    /// The solver families that this rule set applies to.
40    pub solver_families: &'a [SolverFamily],
41}
42
43impl<'a> RuleSet<'a> {
44    pub const fn new(
45        name: &'a str,
46        dependencies: &'a [&'a str],
47        solver_families: &'a [SolverFamily],
48    ) -> Self {
49        Self {
50            name,
51            dependency_rs_names: dependencies,
52            solver_families,
53            rules: OnceLock::new(),
54            dependencies: OnceLock::new(),
55        }
56    }
57
58    /// Get the rules of this rule set, evaluating them lazily if necessary
59    /// Returns a `&HashMap<&Rule, u16>` where the key is the rule and the value is the priority of the rule.
60    pub fn get_rules(&self) -> &HashMap<&'a Rule<'a>, u16> {
61        match self.rules.get() {
62            None => {
63                let rules = self.resolve_rules();
64                let _ = self.rules.set(rules); // Try to set the rules, but ignore if it fails.
65
66                // At this point, the rules cell is guaranteed to be set, so we can unwrap safely.
67                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
68                #[allow(clippy::unwrap_used)]
69                self.rules.get().unwrap()
70            }
71            Some(rules) => rules,
72        }
73    }
74
75    /// Get the dependencies of this rule set, evaluating them lazily if necessary
76    /// Returns a `&HashSet<&RuleSet>` of the rule sets that this rule set depends on.
77    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
78    pub fn get_dependencies(&self) -> &HashSet<&'static RuleSet> {
79        match self.dependencies.get() {
80            None => {
81                let dependencies = self.resolve_dependencies();
82                let _ = self.dependencies.set(dependencies); // Try to set the dependencies, but ignore if it fails.
83
84                // At this point, the dependencies cell is guaranteed to be set, so we can unwrap safely.
85                // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
86                #[allow(clippy::unwrap_used)]
87                self.dependencies.get().unwrap()
88            }
89            Some(dependencies) => dependencies,
90        }
91    }
92
93    /// Get the dependencies of this rule set, including itself
94    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
95    pub fn with_dependencies(&self) -> HashSet<&'static RuleSet> {
96        let mut deps = self.get_dependencies().clone();
97        deps.insert(self);
98        deps
99    }
100
101    /// Resolve the rules of this rule set ("reverse the arrows")
102    fn resolve_rules(&self) -> HashMap<&'a Rule<'a>, u16> {
103        let mut rules = HashMap::new();
104
105        for rule in get_all_rules() {
106            let mut found = false;
107            let mut priority: u16 = 0;
108
109            for (name, p) in rule.rule_sets {
110                if *name == self.name {
111                    found = true;
112                    priority = *p;
113                    break;
114                }
115            }
116
117            if found {
118                rules.insert(rule, priority);
119            }
120        }
121
122        rules
123    }
124
125    /// Recursively resolve the dependencies of this rule set.
126    #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
127    fn resolve_dependencies(&self) -> HashSet<&'static RuleSet> {
128        let mut dependencies = HashSet::new();
129
130        for dep in self.dependency_rs_names {
131            match get_rule_set_by_name(dep) {
132                None => {
133                    warn!(
134                        "Rule set {} depends on non-existent rule set {}",
135                        &self.name, dep
136                    );
137                }
138                Some(rule_set) => {
139                    if !dependencies.contains(rule_set) {
140                        // Prevent cycles
141                        dependencies.insert(rule_set);
142                        dependencies.extend(rule_set.resolve_dependencies());
143                    }
144                }
145            }
146        }
147
148        dependencies
149    }
150}
151
152impl PartialEq for RuleSet<'_> {
153    fn eq(&self, other: &Self) -> bool {
154        self.name == other.name
155    }
156}
157
158impl Eq for RuleSet<'_> {}
159
160impl Hash for RuleSet<'_> {
161    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
162        self.name.hash(state);
163    }
164}
165
166impl Display for RuleSet<'_> {
167    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
168        let n_rules = self.get_rules().len();
169        let solver_families = self
170            .solver_families
171            .iter()
172            .map(|f| f.to_string())
173            .collect::<Vec<String>>();
174
175        write!(
176            f,
177            "RuleSet {{\n\
178            \tname: {}\n\
179            \trules: {}\n\
180            \tsolver_families: {:?}\n\
181        }}",
182            self.name, n_rules, solver_families
183        )
184    }
185}