conjure_core/rule_engine/resolve_rules.rs

1use itertools::Itertools;
2use std::cmp::Ordering;
3use std::collections::{BTreeSet, HashSet};
4use std::fmt::Display;
5use thiserror::Error;
6
7use crate::rule_engine::{get_rule_set_by_name, get_rule_sets_for_solver_family, Rule, RuleSet};
8use crate::solver::SolverFamily;
9
10/// Holds a rule and its priority, along with the rule set it came from.
11#[derive(Debug, Clone)]
12pub struct RuleData<'a> {
13    pub rule: &'a Rule<'a>,
14    pub priority: u16,
15    pub rule_set: &'a RuleSet<'a>,
16}
17
18impl Display for RuleData<'_> {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(
21            f,
22            "Rule: {} (priority: {}, from rule set: {})",
23            self.rule.name, self.priority, self.rule_set.name
24        )
25    }
26}
27
28// Equality is based on the rule itself.
29// Note: this is intentional.
30// If two RuleSets reference the same rule (possibly with different priorities),
31// we only want to keep one copy of the rule.
32impl PartialEq for RuleData<'_> {
33    fn eq(&self, other: &Self) -> bool {
34        self.rule == other.rule
35    }
36}
37
38impl Eq for RuleData<'_> {}
39
40// Sort by priority (higher priority first), then by rule name (alphabetical).
41impl PartialOrd for RuleData<'_> {
42    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
43        Some(self.cmp(other))
44    }
45}
46
47impl Ord for RuleData<'_> {
48    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
49        match self.priority.cmp(&other.priority) {
50            Ordering::Equal => self.rule.name.cmp(other.rule.name),
51            ord => ord.reverse(),
52        }
53    }
54}
55
56/// Error type for rule resolution.
57#[derive(Debug, Error)]
58pub enum ResolveRulesError {
59    RuleSetNotFound,
60}
61
62impl Display for ResolveRulesError {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            ResolveRulesError::RuleSetNotFound => write!(f, "Rule set not found."),
66        }
67    }
68}
69
70/// Helper function to get a rule set by name, or return an error if it doesn't exist.
71///
72/// # Arguments
73/// - `rule_set_name` The name of the rule set to get.
74///
75/// # Returns
76/// - The rule set with the given name or `RuleSetError::RuleSetNotFound` if it doesn't exist.
77fn get_rule_set(rule_set_name: &str) -> Result<&'static RuleSet<'static>, ResolveRulesError> {
78    match get_rule_set_by_name(rule_set_name) {
79        Some(rule_set) => Ok(rule_set),
80        None => Err(ResolveRulesError::RuleSetNotFound),
81    }
82}
83
84/// Resolve a list of rule sets (and dependencies) by their names
85///
86/// # Arguments
87/// - `rule_set_names` The names of the rule sets to resolve.
88///
89/// # Returns
90/// - A list of the given rule sets and all of their dependencies, or error
91///
92#[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
93fn rule_sets_by_names(
94    rule_set_names: &[&str],
95) -> Result<HashSet<&'static RuleSet<'static>>, ResolveRulesError> {
96    let mut rs_set: HashSet<&'static RuleSet<'static>> = HashSet::new();
97
98    for rule_set_name in rule_set_names {
99        let rule_set = get_rule_set(rule_set_name)?;
100        let new_dependencies = rule_set.get_dependencies();
101        rs_set.insert(rule_set);
102        rs_set.extend(new_dependencies);
103    }
104
105    Ok(rs_set)
106}
107
108/// Build a list of rules to apply (sorted by priority) from a list of rule sets.
109///
110/// # Arguments
111/// - `rule_sets` The rule sets to resolve the rules from.
112///
113/// # Returns
114/// - Rules to apply, sorted from highest to lowest priority.
115pub fn get_rules<'a>(
116    rule_sets: &Vec<&'a RuleSet<'a>>,
117) -> Result<impl IntoIterator<Item = RuleData<'a>>, ResolveRulesError> {
118    // Hashing is done by name which never changes, and the references are 'static
119    #[allow(clippy::mutable_key_type)]
120    let mut ans = BTreeSet::<RuleData<'a>>::new();
121
122    for rs in rule_sets {
123        for (rule, priority) in rs.get_rules() {
124            ans.insert(RuleData {
125                rule,
126                priority: *priority,
127                rule_set: rs,
128            });
129        }
130    }
131
132    Ok(ans)
133}
134
135/// Get rules grouped by priority from a list of rule sets.
136///
137/// # Arguments
138/// - `rule_sets` The rule sets to resolve the rules from.
139///
140/// # Returns
141/// - Rules to apply, grouped by priority, sorted from highest to lowest priority.
142pub fn get_rules_grouped<'a>(
143    rule_sets: &Vec<&'a RuleSet<'a>>,
144) -> Result<impl IntoIterator<Item = (u16, Vec<RuleData<'a>>)> + 'a, ResolveRulesError> {
145    let rules = get_rules(rule_sets)?;
146    let grouped: Vec<(u16, Vec<RuleData<'a>>)> = rules
147        .into_iter()
148        .chunk_by(|rule_data| rule_data.priority)
149        .into_iter()
150        // Each chunk here is short-lived, so we clone/copy out the data
151        .map(|(priority, chunk)| (priority, chunk.collect()))
152        .collect();
153    Ok(grouped)
154}
155
156/// Resolves the final set of rule sets to apply based on target solver and extra rule set names.
157///
158/// # Arguments
159/// - `target_solver` The solver family we're targeting
160/// - `extra_rs_names` Optional extra rule set names to enable
161///
162/// # Returns
163/// - A vector of rule sets to apply.
164///
165pub fn resolve_rule_sets(
166    target_solver: SolverFamily,
167    extra_rs_names: &[&str],
168) -> Result<Vec<&'static RuleSet<'static>>, ResolveRulesError> {
169    #[allow(clippy::mutable_key_type)]
170    // Hashing is done by name which never changes, and the references are 'static
171    let mut ans = HashSet::new();
172
173    for rs in get_rule_sets_for_solver_family(target_solver) {
174        ans.extend(rs.with_dependencies());
175    }
176
177    ans.extend(rule_sets_by_names(extra_rs_names)?);
178    Ok(ans.iter().cloned().collect())
179}