1
use itertools::Itertools;
2
use std::cmp::Ordering;
3
use std::collections::{BTreeSet, HashSet};
4
use std::fmt::Display;
5
use thiserror::Error;
6

            
7
use crate::rule_engine::{get_rule_set_by_name, get_rule_sets_for_solver_family, Rule, RuleSet};
8
use crate::solver::SolverFamily;
9

            
10
/// Holds a rule and its priority, along with the rule set it came from.
11
#[derive(Debug, Clone)]
12
pub struct RuleData<'a> {
13
    pub rule: &'a Rule<'a>,
14
    pub priority: u16,
15
    pub rule_set: &'a RuleSet<'a>,
16
}
17

            
18
impl Display for RuleData<'_> {
19
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20
        write!(
21
            f,
22
            "Rule: {} (priority: {}, from rule set: {})",
23
            self.rule.name, self.priority, self.rule_set.name
24
        )
25
    }
26
}
27

            
28
// Equality is based on the rule itself.
29
// Note: this is intentional.
30
// If two RuleSets reference the same rule (possibly with different priorities),
31
// we only want to keep one copy of the rule.
32
impl PartialEq for RuleData<'_> {
33
    fn eq(&self, other: &Self) -> bool {
34
        self.rule == other.rule
35
    }
36
}
37

            
38
impl Eq for RuleData<'_> {}
39

            
40
// Sort by priority (higher priority first), then by rule name (alphabetical).
41
impl PartialOrd for RuleData<'_> {
42
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
43
        Some(self.cmp(other))
44
    }
45
}
46

            
47
impl Ord for RuleData<'_> {
48
547502
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
49
547502
        match self.priority.cmp(&other.priority) {
50
152558
            Ordering::Equal => self.rule.name.cmp(other.rule.name),
51
394944
            ord => ord.reverse(),
52
        }
53
547502
    }
54
}
55

            
56
/// Error type for rule resolution.
57
#[derive(Debug, Error)]
58
pub enum ResolveRulesError {
59
    RuleSetNotFound,
60
}
61

            
62
impl Display for ResolveRulesError {
63
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64
        match self {
65
            ResolveRulesError::RuleSetNotFound => write!(f, "Rule set not found."),
66
        }
67
    }
68
}
69

            
70
/// Helper function to get a rule set by name, or return an error if it doesn't exist.
71
///
72
/// # Arguments
73
/// - `rule_set_name` The name of the rule set to get.
74
///
75
/// # Returns
76
/// - The rule set with the given name or `RuleSetError::RuleSetNotFound` if it doesn't exist.
77
4590
fn get_rule_set(rule_set_name: &str) -> Result<&'static RuleSet<'static>, ResolveRulesError> {
78
4590
    match get_rule_set_by_name(rule_set_name) {
79
4590
        Some(rule_set) => Ok(rule_set),
80
        None => Err(ResolveRulesError::RuleSetNotFound),
81
    }
82
4590
}
83

            
84
/// Resolve a list of rule sets (and dependencies) by their names
85
///
86
/// # Arguments
87
/// - `rule_set_names` The names of the rule sets to resolve.
88
///
89
/// # Returns
90
/// - A list of the given rule sets and all of their dependencies, or error
91
///
92
#[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
93
1564
fn rule_sets_by_names(
94
1564
    rule_set_names: &Vec<String>,
95
1564
) -> Result<HashSet<&'static RuleSet<'static>>, ResolveRulesError> {
96
1564
    let mut rs_set: HashSet<&'static RuleSet<'static>> = HashSet::new();
97

            
98
6154
    for rule_set_name in rule_set_names {
99
4590
        let rule_set = get_rule_set(rule_set_name)?;
100
4590
        let new_dependencies = rule_set.get_dependencies();
101
4590
        rs_set.insert(rule_set);
102
4590
        rs_set.extend(new_dependencies);
103
    }
104

            
105
1564
    Ok(rs_set)
106
1564
}
107

            
108
/// Build a list of rules to apply (sorted by priority) from a list of rule sets.
109
///
110
/// # Arguments
111
/// - `rule_sets` The rule sets to resolve the rules from.
112
///
113
/// # Returns
114
/// - Rules to apply, sorted from highest to lowest priority.
115
1547
pub fn get_rules<'a>(
116
1547
    rule_sets: &Vec<&'a RuleSet<'a>>,
117
1547
) -> Result<impl IntoIterator<Item = RuleData<'a>>, ResolveRulesError> {
118
1547
    // Hashing is done by name which never changes, and the references are 'static
119
1547
    #[allow(clippy::mutable_key_type)]
120
1547
    let mut ans = BTreeSet::<RuleData<'a>>::new();
121

            
122
7701
    for rs in rule_sets {
123
84915
        for (rule, priority) in rs.get_rules() {
124
84915
            ans.insert(RuleData {
125
84915
                rule,
126
84915
                priority: *priority,
127
84915
                rule_set: rs,
128
84915
            });
129
84915
        }
130
    }
131

            
132
1547
    Ok(ans)
133
1547
}
134

            
135
/// Get rules grouped by priority from a list of rule sets.
136
///
137
/// # Arguments
138
/// - `rule_sets` The rule sets to resolve the rules from.
139
///
140
/// # Returns
141
/// - Rules to apply, grouped by priority, sorted from highest to lowest priority.
142
1530
pub fn get_rules_grouped<'a>(
143
1530
    rule_sets: &Vec<&'a RuleSet<'a>>,
144
1530
) -> Result<impl IntoIterator<Item = (u16, Vec<RuleData<'a>>)> + 'a, ResolveRulesError> {
145
1530
    let rules = get_rules(rule_sets)?;
146
1530
    let grouped: Vec<(u16, Vec<RuleData<'a>>)> = rules
147
1530
        .into_iter()
148
84065
        .chunk_by(|rule_data| rule_data.priority)
149
1530
        .into_iter()
150
1530
        // Each chunk here is short-lived, so we clone/copy out the data
151
21420
        .map(|(priority, chunk)| (priority, chunk.collect()))
152
1530
        .collect();
153
1530
    Ok(grouped)
154
1530
}
155

            
156
/// Resolves the final set of rule sets to apply based on target solver and extra rule set names.
157
///
158
/// # Arguments
159
/// - `target_solver` The solver family we're targeting
160
/// - `extra_rs_names` Optional extra rule set names to enable
161
///
162
/// # Returns
163
/// - A vector of rule sets to apply.
164
///
165
1564
pub fn resolve_rule_sets(
166
1564
    target_solver: SolverFamily,
167
1564
    extra_rs_names: &Vec<String>,
168
1564
) -> Result<Vec<&'static RuleSet<'static>>, ResolveRulesError> {
169
1564
    #[allow(clippy::mutable_key_type)]
170
1564
    // Hashing is done by name which never changes, and the references are 'static
171
1564
    let mut ans = HashSet::new();
172

            
173
1564
    for rs in get_rule_sets_for_solver_family(target_solver) {
174
1564
        ans.extend(rs.with_dependencies());
175
1564
    }
176

            
177
1564
    ans.extend(rule_sets_by_names(extra_rs_names)?);
178
1564
    Ok(ans.iter().cloned().collect())
179
1564
}