// conjure_core/rule_engine/resolve_rules.rs

use std::collections::{HashMap, HashSet};
use std::fmt::Display;

use thiserror::Error;

use crate::rule_engine::{get_rule_set_by_name, get_rule_sets_for_solver_family, Rule, RuleSet};
use crate::solver::SolverFamily;

/// Errors produced while turning rule set names into rule sets.
#[derive(Debug, Error)]
pub enum ResolveRulesError {
    /// A requested rule set name was not found by `get_rule_set_by_name`.
    RuleSetNotFound,
}

impl Display for ResolveRulesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the user-facing message for the variant, then emit it verbatim.
        let message = match *self {
            Self::RuleSetNotFound => "Rule set not found.",
        };
        f.write_str(message)
    }
}

/// Helper function to get a rule set by name, or return an error if it doesn't exist.
///
/// # Arguments
/// - `rule_set_name` The name of the rule set to get.
///
/// # Returns
/// - The rule set with the given name.
///
/// # Errors
/// - `ResolveRulesError::RuleSetNotFound` if `get_rule_set_by_name` returns `None` for
///   `rule_set_name`.
fn get_rule_set(rule_set_name: &str) -> Result<&'static RuleSet<'static>, ResolveRulesError> {
    get_rule_set_by_name(rule_set_name).ok_or(ResolveRulesError::RuleSetNotFound)
}

/// Resolve a list of rule sets (and their dependencies) by their names.
///
/// # Arguments
/// - `rule_set_names` The names of the rule sets to resolve. A `&Vec<String>` is still accepted
///   through deref coercion.
///
/// # Returns
/// - A set containing every named rule set, together with the rule sets returned by its
///   `get_dependencies()`. Duplicates are stored only once.
///
/// # Errors
/// - `ResolveRulesError::RuleSetNotFound` for the first name that cannot be found. Names after it
///   are not processed.
#[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
pub fn rule_sets_by_names<'a>(
    rule_set_names: &[String],
) -> Result<HashSet<&'a RuleSet<'static>>, ResolveRulesError> {
    let mut rs_set: HashSet<&'static RuleSet<'static>> = HashSet::new();

    for rule_set_name in rule_set_names {
        let rule_set = get_rule_set(rule_set_name)?;
        let new_dependencies = rule_set.get_dependencies();
        rs_set.insert(rule_set);
        rs_set.extend(new_dependencies);
    }

    Ok(rs_set)
}

/// Resolves the final set of rule sets to apply based on target solver and extra rule set names.
///
/// The solver family's own rule sets, each expanded with `with_dependencies()`, are merged with
/// the rule sets named in `extra_rs_names` and their dependencies. Each rule set appears once.
///
/// # Arguments
/// - `target_solver` The solver to resolve the rule sets for.
/// - `extra_rs_names` The names of the extra rule sets to use.
///
/// # Returns
/// - A vector of rule sets to apply. It is built from a `HashSet`, so its order is unspecified.
///
/// # Errors
/// - `ResolveRulesError::RuleSetNotFound` if a name in `extra_rs_names` cannot be found.
#[allow(clippy::mutable_key_type)] // RuleSet is 'static so it's fine
pub fn resolve_rule_sets<'a>(
    target_solver: SolverFamily,
    extra_rs_names: &Vec<String>,
) -> Result<Vec<&'a RuleSet<'static>>, ResolveRulesError> {
    // Use a set so that a rule set reachable from several places is only kept once.
    let mut resolved = HashSet::new();

    // Start with the solver family's rule sets, each expanded with its dependencies.
    get_rule_sets_for_solver_family(target_solver)
        .into_iter()
        .for_each(|solver_rs| resolved.extend(solver_rs.with_dependencies()));

    // Then add the user-requested extras. An unknown name aborts the whole resolution.
    let extras = rule_sets_by_names(extra_rs_names)?;
    resolved.extend(extras);

    Ok(resolved.into_iter().collect())
}

/// Convert a list of rule sets into a final map of rules to their priorities.
///
/// If a rule appears in more than one rule set, its priority comes from the rule set with the
/// highest `order`. Among rule sets with the same `order`, the one that comes last in
/// `rule_sets` wins.
///
/// # Arguments
/// - `rule_sets` The rule sets to get the rules from. A `&Vec<&RuleSet>` is still accepted
///   through deref coercion.
///
/// # Returns
/// - A map of rules to their priorities.
///
/// # Errors
/// - The current implementation never returns `Err`.
pub fn get_rule_priorities<'a>(
    rule_sets: &[&'a RuleSet<'a>],
) -> Result<HashMap<&'a Rule<'a>, u16>, ResolveRulesError> {
    // Store the rule set each priority came from, so a later rule set can be compared by `order`.
    let mut rule_priorities: HashMap<&'a Rule<'a>, (&'a RuleSet<'a>, u16)> = HashMap::new();

    for rs in rule_sets {
        for (rule, priority) in rs.get_rules() {
            // `>=` lets a rule set of equal order override one that was seen earlier.
            let should_insert = match rule_priorities.get(rule) {
                Some((old_rs, _)) => rs.order >= old_rs.order,
                None => true,
            };
            if should_insert {
                rule_priorities.insert(rule, (rs, *priority));
            }
        }
    }

    // Drop the originating rule set and keep only the winning priority.
    Ok(rule_priorities
        .into_iter()
        .map(|(rule, (_, priority))| (rule, priority))
        .collect())
}

/// Compare two rules by their priorities and names.
///
/// Rules with a higher priority are ordered first. A rule missing from `rule_priorities` is
/// treated as having priority 0. When two priorities are equal, the rules are ordered by
/// ascending `name`.
///
/// # Arguments
/// - `a` first rule to compare.
/// - `b` second rule to compare.
/// - `rule_priorities` The priorities of the rules.
///
/// # Returns
/// - The ordering of the two rules.
pub fn rule_cmp<'a>(
    a: &Rule<'a>,
    b: &Rule<'a>,
    rule_priorities: &HashMap<&'a Rule<'a>, u16>,
) -> std::cmp::Ordering {
    let priority_of = |rule: &Rule<'a>| rule_priorities.get(rule).copied().unwrap_or(0);

    // Compare `b` against `a` so that larger priorities sort first; names break ties.
    priority_of(b)
        .cmp(&priority_of(a))
        .then_with(|| a.name.cmp(b.name))
}

/// Get a final ordering of rules based on their priorities and names.
///
/// Every key of `rule_priorities` is returned once. The rules are sorted with `rule_cmp`:
/// higher priorities first, and rules of equal priority in ascending order of name.
///
/// # Arguments
/// - `rule_priorities` The priorities of the rules.
///
/// # Returns
/// - The rules from `rule_priorities`, in the order described above.
pub fn get_rules_vec<'a>(rule_priorities: &HashMap<&'a Rule<'a>, u16>) -> Vec<&'a Rule<'a>> {
    let mut ordered: Vec<&'a Rule<'a>> = Vec::with_capacity(rule_priorities.len());
    ordered.extend(rule_priorities.keys().copied());
    ordered.sort_by(|lhs, rhs| rule_cmp(lhs, rhs, rule_priorities));
    ordered
}