conjure_cp_core/rule_engine/rule_set.rs
1use std::collections::{HashMap, HashSet};
2use std::fmt::{Display, Formatter};
3use std::hash::Hash;
4use std::sync::OnceLock;
5
6use log::warn;
7
8use crate::rule_engine::{Rule, get_all_rules, get_rule_set_by_name};
9use crate::solver::SolverFamily;
10
11/// A structure representing a set of rules with a name, priority, and dependencies.
12///
13/// `RuleSet` is a way to group related rules together under a single name.
14/// You can think of it like a list of rules that belong to the same category.
15/// Each `RuleSet` can also have a number that tells it what order it should run in compared to other `RuleSet` instances.
16/// Additionally, a `RuleSet` can depend on other `RuleSet` instances, meaning it needs them to run first.
17///
18/// To make things efficient, `RuleSet` only figures out its rules and dependencies the first time they're needed,
19/// and then it remembers them so it doesn't have to do the work again.
20///
21/// # Fields
22/// - `name`: The name of the rule set.
23/// - `order`: A number that decides the order in which this `RuleSet` should be applied.
24/// If two `RuleSet` instances have the same rule but with different priorities,
25/// the one with the higher `order` number will be the one that is used.
26/// - `rules`: A lazily initialized map of rules to their priorities.
27/// - `dependency_rs_names`: The names of the rule sets that this rule set depends on.
28/// - `dependencies`: A lazily initialized set of `RuleSet` dependencies.
29/// - `solver_families`: The solver families that this rule set applies to.
30#[derive(Clone, Debug)]
31pub struct RuleSet<'a> {
32 /// The name of the rule set.
33 pub name: &'a str,
34 /// A map of rules to their priorities. This will be lazily initialized at runtime.
35 rules: OnceLock<HashMap<&'a Rule<'a>, u16>>,
36 /// The names of the rule sets that this rule set depends on.
37 dependency_rs_names: &'a [&'a str],
38 dependencies: OnceLock<HashSet<&'a RuleSet<'a>>>,
39
40 /// Returns whether the rule set applies to the given solver family.
41 /// The implementation is specified via an argument to [`register_rule_set!`].
42 applies_to_family_fn: fn(&SolverFamily) -> bool,
43}
44
45impl<'a> RuleSet<'a> {
46 pub const fn new(
47 name: &'a str,
48 dependencies: &'a [&'a str],
49 applies_to_family_fn: fn(&SolverFamily) -> bool,
50 ) -> Self {
51 Self {
52 name,
53 dependency_rs_names: dependencies,
54 rules: OnceLock::new(),
55 dependencies: OnceLock::new(),
56 applies_to_family_fn,
57 }
58 }
59
60 /// Get the rules of this rule set, evaluating them lazily if necessary
61 /// Returns a `&HashMap<&Rule, u16>` where the key is the rule and the value is the priority of the rule.
62 pub fn get_rules(&self) -> &HashMap<&'a Rule<'a>, u16> {
63 match self.rules.get() {
64 None => {
65 let rules = self.resolve_rules();
66 let _ = self.rules.set(rules); // Try to set the rules, but ignore if it fails.
67
68 // At this point, the rules cell is guaranteed to be set, so we can unwrap safely.
69 // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
70 #[allow(clippy::unwrap_used)]
71 self.rules.get().unwrap()
72 }
73 Some(rules) => rules,
74 }
75 }
76
77 /// Get the dependencies of this rule set, evaluating them lazily if necessary
78 /// Returns a `&HashSet<&RuleSet>` of the rule sets that this rule set depends on.
79 #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
80 pub fn get_dependencies(&self) -> &HashSet<&'static RuleSet<'_>> {
81 match self.dependencies.get() {
82 None => {
83 let dependencies = self.resolve_dependencies();
84 let _ = self.dependencies.set(dependencies); // Try to set the dependencies, but ignore if it fails.
85
86 // At this point, the dependencies cell is guaranteed to be set, so we can unwrap safely.
87 // see: https://doc.rust-lang.org/stable/std/sync/struct.OnceLock.html#method.set
88 #[allow(clippy::unwrap_used)]
89 self.dependencies.get().unwrap()
90 }
91 Some(dependencies) => dependencies,
92 }
93 }
94
95 /// Get the dependencies of this rule set, including itself
96 #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
97 pub fn with_dependencies(&self) -> HashSet<&'static RuleSet<'_>> {
98 let mut deps = self.get_dependencies().clone();
99 deps.insert(self);
100 deps
101 }
102
103 /// Resolve the rules of this rule set ("reverse the arrows")
104 fn resolve_rules(&self) -> HashMap<&'a Rule<'a>, u16> {
105 let mut rules = HashMap::new();
106
107 for rule in get_all_rules() {
108 let mut found = false;
109 let mut priority: u16 = 0;
110
111 for (name, p) in rule.rule_sets {
112 if *name == self.name {
113 found = true;
114 priority = *p;
115 break;
116 }
117 }
118
119 if found {
120 rules.insert(rule, priority);
121 }
122 }
123
124 rules
125 }
126
127 /// Recursively resolve the dependencies of this rule set.
128 #[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
129 fn resolve_dependencies(&self) -> HashSet<&'static RuleSet<'_>> {
130 let mut dependencies = HashSet::new();
131
132 for dep in self.dependency_rs_names {
133 match get_rule_set_by_name(dep) {
134 None => {
135 warn!(
136 "Rule set {} depends on non-existent rule set {}",
137 &self.name, dep
138 );
139 }
140 Some(rule_set) => {
141 if !dependencies.contains(rule_set) {
142 // Prevent cycles
143 dependencies.insert(rule_set);
144 dependencies.extend(rule_set.resolve_dependencies());
145 }
146 }
147 }
148 }
149
150 dependencies
151 }
152
153 pub fn applies_to_family(&self, family: &SolverFamily) -> bool {
154 (self.applies_to_family_fn)(family)
155 }
156}
157
158impl PartialEq for RuleSet<'_> {
159 fn eq(&self, other: &Self) -> bool {
160 self.name == other.name
161 }
162}
163
164impl Eq for RuleSet<'_> {}
165
166impl Hash for RuleSet<'_> {
167 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
168 self.name.hash(state);
169 }
170}
171
172impl Display for RuleSet<'_> {
173 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
174 let n_rules = self.get_rules().len();
175
176 write!(
177 f,
178 "RuleSet {{\n\
179 \tname: {}\n\
180 \trules: {}\n\
181 }}",
182 self.name, n_rules
183 )
184 }
185}