conjure_core/rule_engine/
resolve_rules.rs1use itertools::Itertools;
2use std::cmp::Ordering;
3use std::collections::{BTreeSet, HashSet};
4use std::fmt::Display;
5use thiserror::Error;
6
7use crate::rule_engine::{get_rule_set_by_name, get_rule_sets_for_solver_family, Rule, RuleSet};
8use crate::solver::SolverFamily;
9
10#[derive(Debug, Clone)]
12pub struct RuleData<'a> {
13 pub rule: &'a Rule<'a>,
14 pub priority: u16,
15 pub rule_set: &'a RuleSet<'a>,
16}
17
18impl Display for RuleData<'_> {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(
21 f,
22 "Rule: {} (priority: {}, from rule set: {})",
23 self.rule.name, self.priority, self.rule_set.name
24 )
25 }
26}
27
28impl PartialEq for RuleData<'_> {
33 fn eq(&self, other: &Self) -> bool {
34 self.rule == other.rule
35 }
36}
37
38impl Eq for RuleData<'_> {}
39
40impl PartialOrd for RuleData<'_> {
42 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
43 Some(self.cmp(other))
44 }
45}
46
47impl Ord for RuleData<'_> {
48 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
49 match self.priority.cmp(&other.priority) {
50 Ordering::Equal => self.rule.name.cmp(other.rule.name),
51 ord => ord.reverse(),
52 }
53 }
54}
55
56#[derive(Debug, Error)]
58pub enum ResolveRulesError {
59 RuleSetNotFound,
60}
61
62impl Display for ResolveRulesError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 ResolveRulesError::RuleSetNotFound => write!(f, "Rule set not found."),
66 }
67 }
68}
69
70fn get_rule_set(rule_set_name: &str) -> Result<&'static RuleSet<'static>, ResolveRulesError> {
78 match get_rule_set_by_name(rule_set_name) {
79 Some(rule_set) => Ok(rule_set),
80 None => Err(ResolveRulesError::RuleSetNotFound),
81 }
82}
83
84#[allow(clippy::mutable_key_type)] fn rule_sets_by_names(
94 rule_set_names: &[&str],
95) -> Result<HashSet<&'static RuleSet<'static>>, ResolveRulesError> {
96 let mut rs_set: HashSet<&'static RuleSet<'static>> = HashSet::new();
97
98 for rule_set_name in rule_set_names {
99 let rule_set = get_rule_set(rule_set_name)?;
100 let new_dependencies = rule_set.get_dependencies();
101 rs_set.insert(rule_set);
102 rs_set.extend(new_dependencies);
103 }
104
105 Ok(rs_set)
106}
107
108pub fn get_rules<'a>(
116 rule_sets: &Vec<&'a RuleSet<'a>>,
117) -> Result<impl IntoIterator<Item = RuleData<'a>>, ResolveRulesError> {
118 #[allow(clippy::mutable_key_type)]
120 let mut ans = BTreeSet::<RuleData<'a>>::new();
121
122 for rs in rule_sets {
123 for (rule, priority) in rs.get_rules() {
124 ans.insert(RuleData {
125 rule,
126 priority: *priority,
127 rule_set: rs,
128 });
129 }
130 }
131
132 Ok(ans)
133}
134
135pub fn get_rules_grouped<'a>(
143 rule_sets: &Vec<&'a RuleSet<'a>>,
144) -> Result<impl IntoIterator<Item = (u16, Vec<RuleData<'a>>)> + 'a, ResolveRulesError> {
145 let rules = get_rules(rule_sets)?;
146 let grouped: Vec<(u16, Vec<RuleData<'a>>)> = rules
147 .into_iter()
148 .chunk_by(|rule_data| rule_data.priority)
149 .into_iter()
150 .map(|(priority, chunk)| (priority, chunk.collect()))
152 .collect();
153 Ok(grouped)
154}
155
156pub fn resolve_rule_sets(
166 target_solver: SolverFamily,
167 extra_rs_names: &[&str],
168) -> Result<Vec<&'static RuleSet<'static>>, ResolveRulesError> {
169 #[allow(clippy::mutable_key_type)]
170 let mut ans = HashSet::new();
172
173 for rs in get_rule_sets_for_solver_family(target_solver) {
174 ans.extend(rs.with_dependencies());
175 }
176
177 ans.extend(rule_sets_by_names(extra_rs_names)?);
178 Ok(ans.iter().cloned().collect())
179}