// conjure_core/rule_engine/resolve_rules.rs
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
use thiserror::Error;
use crate::rule_engine::{get_rule_set_by_name, get_rule_sets_for_solver_family, Rule, RuleSet};
use crate::solver::SolverFamily;
/// Errors produced while resolving rule sets.
#[derive(Debug, Error)]
pub enum ResolveRulesError {
/// A rule set was requested by name, but the lookup returned nothing.
RuleSetNotFound,
}
impl Display for ResolveRulesError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ResolveRulesError::RuleSetNotFound => write!(f, "Rule set not found."),
}
}
}
/// Looks up a rule set by name.
///
/// # Errors
///
/// Returns [`ResolveRulesError::RuleSetNotFound`] when `get_rule_set_by_name`
/// yields `None` for `rule_set_name`.
fn get_rule_set(rule_set_name: &str) -> Result<&'static RuleSet<'static>, ResolveRulesError> {
    get_rule_set_by_name(rule_set_name).ok_or(ResolveRulesError::RuleSetNotFound)
}
/// Resolves a list of rule set names to the corresponding rule sets, plus the
/// dependencies each of them reports via `get_dependencies`.
///
/// The parameter is a slice, so callers may pass a `&Vec<String>`, an array
/// reference or any other `&[String]`.
///
/// # Errors
///
/// Returns [`ResolveRulesError::RuleSetNotFound`] as soon as one of the names
/// cannot be resolved; nothing is returned for the names processed before it.
// NOTE(review): `RuleSet` is used as a hash-set key here. The lint is presumably
// triggered by interior mutability inside `RuleSet` -- confirm that its
// `Hash`/`Eq` implementations do not depend on that mutable state.
#[allow(clippy::mutable_key_type)]
pub fn rule_sets_by_names<'a>(
    rule_set_names: &[String],
) -> Result<HashSet<&'a RuleSet<'static>>, ResolveRulesError> {
    let mut rs_set: HashSet<&'static RuleSet<'static>> = HashSet::new();

    for rule_set_name in rule_set_names {
        let rule_set = get_rule_set(rule_set_name)?;
        // The set de-duplicates rule sets shared between several requested names.
        rs_set.insert(rule_set);
        rs_set.extend(rule_set.get_dependencies());
    }

    Ok(rs_set)
}
/// Collects every rule set that applies to a solver run.
///
/// The result combines the rule sets returned by
/// `get_rule_sets_for_solver_family` for `target_solver` (each expanded with
/// `with_dependencies`) with the rule sets named in `extra_rs_names` (resolved
/// through `rule_sets_by_names`). Duplicates are removed. The order of the
/// returned vector follows hash-set iteration and is therefore unspecified.
///
/// # Errors
///
/// Propagates the error of `rule_sets_by_names` when one of the extra names
/// cannot be resolved.
// NOTE(review): `RuleSet` is used as a hash-set key here. The lint is presumably
// triggered by interior mutability inside `RuleSet` -- confirm that its
// `Hash`/`Eq` implementations do not depend on that mutable state.
#[allow(clippy::mutable_key_type)]
pub fn resolve_rule_sets<'a>(
    target_solver: SolverFamily,
    extra_rs_names: &Vec<String>,
) -> Result<Vec<&'a RuleSet<'static>>, ResolveRulesError> {
    let mut resolved: HashSet<&'a RuleSet<'static>> = HashSet::new();

    // Rule sets implied by the target solver family.
    for family_rs in get_rule_sets_for_solver_family(target_solver) {
        resolved.extend(family_rs.with_dependencies());
    }

    // Rule sets explicitly requested by name.
    let requested = rule_sets_by_names(extra_rs_names)?;
    resolved.extend(requested);

    Ok(resolved.into_iter().collect())
}
/// Computes the effective priority of every rule provided by `rule_sets`.
///
/// When the same rule is listed by several rule sets, the priority from the
/// rule set with the greatest `order` wins. If two rule sets have the same
/// `order`, the one appearing later in `rule_sets` wins.
///
/// The parameter is a slice, so callers may pass a `&Vec<&RuleSet>`, an array
/// reference or any other `&[&RuleSet]`.
///
/// # Errors
///
/// This function has no failing path at present and always returns `Ok`; the
/// `Result` is part of its existing signature.
pub fn get_rule_priorities<'a>(
    rule_sets: &[&'a RuleSet<'a>],
) -> Result<HashMap<&'a Rule<'a>, u16>, ResolveRulesError> {
    // rule -> (rule set currently providing the priority, that priority)
    let mut rule_priorities: HashMap<&'a Rule<'a>, (&'a RuleSet<'a>, u16)> = HashMap::new();

    for rs in rule_sets {
        for (rule, priority) in rs.get_rules() {
            // A rule seen for the first time is always recorded; otherwise the
            // current rule set must have at least the order of the recorded one.
            let takes_precedence = match rule_priorities.get(rule) {
                Some((current_rs, _)) => rs.order >= current_rs.order,
                None => true,
            };
            if takes_precedence {
                rule_priorities.insert(rule, (rs, *priority));
            }
        }
    }

    // Drop the bookkeeping rule set and keep only the priority.
    Ok(rule_priorities
        .into_iter()
        .map(|(rule, (_, priority))| (rule, priority))
        .collect())
}
/// Orders two rules for application.
///
/// The rule with the higher priority in `rule_priorities` sorts first; a rule
/// missing from the map is treated as having priority `0`. Rules with equal
/// priority are ordered by `name`, ascending.
pub fn rule_cmp<'a>(
    a: &Rule<'a>,
    b: &Rule<'a>,
    rule_priorities: &HashMap<&'a Rule<'a>, u16>,
) -> std::cmp::Ordering {
    let priority_of = |rule: &Rule<'a>| rule_priorities.get(rule).copied().unwrap_or(0);

    // Operands are swapped on purpose: larger priorities come first.
    priority_of(b)
        .cmp(&priority_of(a))
        .then_with(|| a.name.cmp(b.name))
}
/// Returns every rule that has an entry in `rule_priorities`, sorted with
/// `rule_cmp` (higher priority first, ties broken by rule name).
pub fn get_rules_vec<'a>(rule_priorities: &HashMap<&'a Rule<'a>, u16>) -> Vec<&'a Rule<'a>> {
    let mut ordered: Vec<&'a Rule<'a>> = Vec::with_capacity(rule_priorities.len());
    ordered.extend(rule_priorities.keys().copied());
    ordered.sort_by(|lhs, rhs| rule_cmp(lhs, rhs, rule_priorities));
    ordered
}