use itertools::Itertools;
use std::cmp::Ordering;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Display;
use thiserror::Error;

use crate::rule_engine::{get_rule_set_by_name, get_rule_sets_for_solver_family, Rule, RuleSet};
use crate::solver::SolverFamily;
910/// Holds a rule and its priority, along with the rule set it came from.
11#[derive(Debug, Clone)]
12pub struct RuleData<'a> {
13pub rule: &'a Rule<'a>,
14pub priority: u16,
15pub rule_set: &'a RuleSet<'a>,
16}
1718impl Display for RuleData<'_> {
19fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20write!(
21 f,
22"Rule: {} (priority: {}, from rule set: {})",
23self.rule.name, self.priority, self.rule_set.name
24 )
25 }
26}
2728// Equality is based on the rule itself.
29// Note: this is intentional.
30// If two RuleSets reference the same rule (possibly with different priorities),
31// we only want to keep one copy of the rule.
32impl PartialEq for RuleData<'_> {
33fn eq(&self, other: &Self) -> bool {
34self.rule == other.rule
35 }
36}
3738impl Eq for RuleData<'_> {}
3940// Sort by priority (higher priority first), then by rule name (alphabetical).
41impl PartialOrd for RuleData<'_> {
42fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
43Some(self.cmp(other))
44 }
45}
4647impl Ord for RuleData<'_> {
48fn cmp(&self, other: &Self) -> std::cmp::Ordering {
49match self.priority.cmp(&other.priority) {
50 Ordering::Equal => self.rule.name.cmp(other.rule.name),
51 ord => ord.reverse(),
52 }
53 }
54}
5556/// Error type for rule resolution.
57#[derive(Debug, Error)]
58pub enum ResolveRulesError {
59 RuleSetNotFound,
60}
6162impl Display for ResolveRulesError {
63fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64match self {
65 ResolveRulesError::RuleSetNotFound => write!(f, "Rule set not found."),
66 }
67 }
68}
6970/// Helper function to get a rule set by name, or return an error if it doesn't exist.
71///
72/// # Arguments
73/// - `rule_set_name` The name of the rule set to get.
74///
75/// # Returns
76/// - The rule set with the given name or `RuleSetError::RuleSetNotFound` if it doesn't exist.
77fn get_rule_set(rule_set_name: &str) -> Result<&'static RuleSet<'static>, ResolveRulesError> {
78match get_rule_set_by_name(rule_set_name) {
79Some(rule_set) => Ok(rule_set),
80None => Err(ResolveRulesError::RuleSetNotFound),
81 }
82}
8384/// Resolve a list of rule sets (and dependencies) by their names
85///
86/// # Arguments
87/// - `rule_set_names` The names of the rule sets to resolve.
88///
89/// # Returns
90/// - A list of the given rule sets and all of their dependencies, or error
91///
92#[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
93fn rule_sets_by_names(
94 rule_set_names: &[&str],
95) -> Result<HashSet<&'static RuleSet<'static>>, ResolveRulesError> {
96let mut rs_set: HashSet<&'static RuleSet<'static>> = HashSet::new();
9798for rule_set_name in rule_set_names {
99let rule_set = get_rule_set(rule_set_name)?;
100let new_dependencies = rule_set.get_dependencies();
101 rs_set.insert(rule_set);
102 rs_set.extend(new_dependencies);
103 }
104105Ok(rs_set)
106}
107108/// Build a list of rules to apply (sorted by priority) from a list of rule sets.
109///
110/// # Arguments
111/// - `rule_sets` The rule sets to resolve the rules from.
112///
113/// # Returns
114/// - Rules to apply, sorted from highest to lowest priority.
115pub fn get_rules<'a>(
116 rule_sets: &Vec<&'a RuleSet<'a>>,
117) -> Result<impl IntoIterator<Item = RuleData<'a>>, ResolveRulesError> {
118// Hashing is done by name which never changes, and the references are 'static
119#[allow(clippy::mutable_key_type)]
120let mut ans = BTreeSet::<RuleData<'a>>::new();
121122for rs in rule_sets {
123for (rule, priority) in rs.get_rules() {
124 ans.insert(RuleData {
125 rule,
126 priority: *priority,
127 rule_set: rs,
128 });
129 }
130 }
131132Ok(ans)
133}
134135/// Get rules grouped by priority from a list of rule sets.
136///
137/// # Arguments
138/// - `rule_sets` The rule sets to resolve the rules from.
139///
140/// # Returns
141/// - Rules to apply, grouped by priority, sorted from highest to lowest priority.
142pub fn get_rules_grouped<'a>(
143 rule_sets: &Vec<&'a RuleSet<'a>>,
144) -> Result<impl IntoIterator<Item = (u16, Vec<RuleData<'a>>)> + 'a, ResolveRulesError> {
145let rules = get_rules(rule_sets)?;
146let grouped: Vec<(u16, Vec<RuleData<'a>>)> = rules
147 .into_iter()
148 .chunk_by(|rule_data| rule_data.priority)
149 .into_iter()
150// Each chunk here is short-lived, so we clone/copy out the data
151.map(|(priority, chunk)| (priority, chunk.collect()))
152 .collect();
153Ok(grouped)
154}
155156/// Resolves the final set of rule sets to apply based on target solver and extra rule set names.
157///
158/// # Arguments
159/// - `target_solver` The solver family we're targeting
160/// - `extra_rs_names` Optional extra rule set names to enable
161///
162/// # Returns
163/// - A vector of rule sets to apply.
164///
165pub fn resolve_rule_sets(
166 target_solver: SolverFamily,
167 extra_rs_names: &[&str],
168) -> Result<Vec<&'static RuleSet<'static>>, ResolveRulesError> {
169#[allow(clippy::mutable_key_type)]
170// Hashing is done by name which never changes, and the references are 'static
171let mut ans = HashSet::new();
172173for rs in get_rule_sets_for_solver_family(target_solver) {
174 ans.extend(rs.with_dependencies());
175 }
176177 ans.extend(rule_sets_by_names(extra_rs_names)?);
178Ok(ans.iter().cloned().collect())
179}