1use std::{
2 cell::RefCell,
3 collections::HashSet,
4 fmt::Display,
5 rc::Rc,
6 sync::{Arc, Mutex, RwLock},
7};
8
9use itertools::Itertools as _;
10use serde::{Deserialize, Serialize};
11use uniplate::{derive::Uniplate, Biplate as _};
12
13use crate::{
14 ast::Atom,
15 context::Context,
16 into_matrix_expr, matrix_expr,
17 metadata::Metadata,
18 solver::{Solver, SolverError},
19};
20
21use super::{Declaration, Domain, Expression, Model, Name, Range, SubModel, SymbolTable};
22
/// The kind of aggregate a comprehension is being built for.
///
/// `ComprehensionBuilder::with_return_value` uses this to decide how guards that
/// reference non-induction variables are folded into the return expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComprehensionKind {
    /// Summation; folding such guards into the return expression is not yet implemented.
    Sum,
    /// Conjunction; such guards become the antecedent of an implication.
    And,
    /// Disjunction; such guards are conjoined with the return expression.
    Or,
}
/// A comprehension: a return expression paired with the sub-model that drives it.
///
/// `solve_with_minion` enumerates the solutions of `submodel` and produces one copy of
/// `expression` per solution, with references to the solved names replaced by literals.
#[derive(Clone, PartialEq, Eq, Uniplate, Serialize, Deserialize, Debug)]
#[uniplate(walk_into=[SubModel])]
#[biplate(to=SubModel,walk_into=[Expression])]
#[biplate(to=Expression,walk_into=[SubModel])]
pub struct Comprehension {
    /// The return expression that is instantiated once per solution of `submodel`.
    expression: Expression,
    /// Sub-model holding the generator declarations and the induction guards as constraints.
    submodel: SubModel,
    /// Names of the induction variables; used as the search order when solving.
    induction_vars: Vec<Name>,
}
38
39impl Comprehension {
40 pub fn solve_with_minion(self) -> Result<Vec<Expression>, SolverError> {
42 let minion = Solver::new(crate::solver::adaptors::Minion::new());
43 let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
45
46 model.search_order = Some(self.induction_vars.clone());
48
49 *model.as_submodel_mut() = self.submodel.clone();
50
51 let minion = minion.load_model(model.clone())?;
52
53 let values = Arc::new(Mutex::new(Vec::new()));
54 let values_ptr = Arc::clone(&values);
55
56 tracing::debug!(model=%model.clone(),comprehension=%self.clone(),"Minion solving comprehension");
57 let expression = self.expression;
58 minion.solve(Box::new(move |sols| {
59 let values = &mut *values_ptr.lock().unwrap();
61 values.push(sols);
62 true
63 }))?;
64
65 let values = values.lock().unwrap().clone();
66 Ok(values
67 .clone()
68 .into_iter()
69 .map(|sols| {
70 expression
72 .clone()
73 .transform_bi(Arc::new(move |atom: Atom| match atom {
74 Atom::Reference(name) if sols.contains_key(&name) => {
75 Atom::Literal(sols.get(&name).unwrap().clone())
76 }
77 x => x,
78 }))
79 })
80 .collect_vec())
81 }
82
83 pub fn domain_of(&self) -> Option<Domain> {
84 self.expression
85 .domain_of(&self.submodel.symbols())
86 .map(|domain| {
87 Domain::DomainMatrix(
88 Box::new(domain),
89 vec![Domain::IntDomain(vec![Range::UnboundedR(1)])],
90 )
91 })
92 }
93}
94
95impl Display for Comprehension {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 let generators: String = self
98 .submodel
99 .symbols()
100 .clone()
101 .into_iter_local()
102 .map(|(name, decl)| (name, decl.domain().unwrap().clone()))
103 .map(|(name, domain)| format!("{name}: {domain}"))
104 .join(",");
105
106 let guards = self
107 .submodel
108 .constraints()
109 .iter()
110 .map(|x| format!("{x}"))
111 .join(",");
112
113 let generators_and_guards = itertools::join([generators, guards], ",");
114
115 let expression = &self.expression;
116 write!(f, "[{expression} | {generators_and_guards}]")
117 }
118}
119
/// Incrementally collects the guards and generators of a comprehension.
///
/// Add parts with `guard` and `generator`, then call `with_return_value` to produce the
/// finished `Comprehension`.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ComprehensionBuilder {
    /// Guard expressions, in the order they were added.
    guards: Vec<Expression>,
    /// Generators as `(name, domain)` pairs, in the order they were added.
    generators: Vec<(Name, Domain)>,
    /// Names of all generators added so far; used to reject duplicate names and to decide
    /// which guards only mention induction variables.
    induction_variables: HashSet<Name>,
}
127
128impl ComprehensionBuilder {
129 pub fn new() -> Self {
130 Default::default()
131 }
132 pub fn guard(mut self, guard: Expression) -> Self {
133 self.guards.push(guard);
134 self
135 }
136
137 pub fn generator(mut self, name: Name, domain: Domain) -> Self {
138 assert!(!self.induction_variables.contains(&name));
139 self.induction_variables.insert(name.clone());
140 self.generators.push((name, domain));
141 self
142 }
143
144 pub fn with_return_value(
149 self,
150 mut expression: Expression,
151 parent: Rc<RefCell<SymbolTable>>,
152 comprehension_kind: Option<ComprehensionKind>,
153 ) -> Comprehension {
154 let mut submodel = SubModel::new(parent);
155
156 let induction_variables = self.induction_variables;
159
160 let (induction_guards, other_guards): (Vec<_>, Vec<_>) = self
162 .guards
163 .into_iter()
164 .partition(|x| is_induction_guard(&induction_variables, x));
165
166 if !other_guards.is_empty() {
168 let comprehension_kind = comprehension_kind.expect(
169 "if any guards reference decision variables, a comprehension kind should be given",
170 );
171
172 let guard_expr = match other_guards.as_slice() {
173 [x] => x.clone(),
174 xs => Expression::And(Metadata::new(), Box::new(into_matrix_expr!(xs.to_vec()))),
175 };
176
177 expression = match comprehension_kind {
178 ComprehensionKind::And => {
179 Expression::Imply(Metadata::new(), Box::new(guard_expr), Box::new(expression))
180 }
181 ComprehensionKind::Or => Expression::And(
182 Metadata::new(),
183 Box::new(Expression::And(
184 Metadata::new(),
185 Box::new(matrix_expr![guard_expr, expression]),
186 )),
187 ),
188
189 ComprehensionKind::Sum => {
190 panic!("guards that reference decision variables not yet implemented for sum");
191 }
192 }
193 }
194
195 submodel.add_constraints(induction_guards);
196 for (name, domain) in self.generators {
197 submodel
198 .symbols_mut()
199 .insert(Rc::new(Declaration::new_var(name, domain)));
200 }
201
202 Comprehension {
203 expression,
204 submodel,
205 induction_vars: induction_variables.into_iter().collect_vec(),
206 }
207 }
208}
209
210fn is_induction_guard(induction_variables: &HashSet<Name>, guard: &Expression) -> bool {
212 guard
213 .universe_bi()
214 .iter()
215 .all(|x| induction_variables.contains(x))
216}