conjure_cp_core/ast/
comprehension.rs1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{
4 collections::BTreeSet,
5 fmt::Display,
6 sync::atomic::{AtomicBool, AtomicU8, Ordering},
7};
8
9use crate::settings::QuantifiedExpander;
10use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
11use conjure_cp_core::ast::ReturnType;
12use itertools::Itertools as _;
13use serde::{Deserialize, Serialize};
14use uniplate::{Biplate, Uniplate};
15
16use super::{
17 DeclarationPtr, Domain, DomainPtr, Expression, Moo, Name, Range, SubModel, SymbolTablePtr,
18 Typeable, ac_operators::ACOperatorKind,
19};
20
/// Global switch, initialised to `false`, whose name marks it as selecting the optimised
/// rewriter for comprehensions.
///
/// NOTE(review): this module never reads the flag itself; confirm its consumers elsewhere.
pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);
27
28pub static QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS: AtomicU8 =
32 AtomicU8::new(QuantifiedExpander::Native.as_u8());
33
34pub fn set_quantified_expander_for_comprehensions(expander: QuantifiedExpander) {
35 QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS.store(expander.as_u8(), Ordering::Relaxed);
36}
37
38pub fn quantified_expander_for_comprehensions() -> QuantifiedExpander {
39 QuantifiedExpander::from_u8(QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS.load(Ordering::Relaxed))
40}
41
42#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
52#[biplate(to=SubModel)]
53#[biplate(to=Expression)]
54#[non_exhaustive]
55pub struct Comprehension {
56 #[doc(hidden)]
57 pub return_expression_submodel: SubModel,
58 #[doc(hidden)]
59 pub generator_submodel: SubModel,
60 #[doc(hidden)]
61 pub quantified_vars: Vec<Name>,
62}
63
64impl Comprehension {
65 pub fn domain_of(&self) -> Option<DomainPtr> {
66 let return_expr_domain = self
67 .return_expression_submodel
68 .clone()
69 .into_single_expression()
70 .domain_of()?;
71
72 Some(Domain::matrix(
74 return_expr_domain,
75 vec![Domain::int(vec![Range::UnboundedR(1)])],
76 ))
77 }
78
79 pub fn return_expression(self) -> Expression {
80 self.return_expression_submodel.into_single_expression()
81 }
82
83 pub fn replace_return_expression(&mut self, new_expr: Expression) {
84 let new_expr = match new_expr {
85 Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
86 Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
87 }
88 expr => Expression::Root(Metadata::new(), vec![expr]),
89 };
90
91 *self.return_expression_submodel.root_mut_unchecked() = new_expr;
92 }
93
94 pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
96 if self.is_quantified_guard(&guard) {
97 self.generator_submodel.add_constraint(guard);
98 true
99 } else {
100 false
101 }
102 }
103
104 pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
106 is_quantified_guard(&(self.quantified_vars.clone().into_iter().collect()), expr)
107 }
108}
109
110impl Typeable for Comprehension {
111 fn return_type(&self) -> ReturnType {
112 self.return_expression_submodel
113 .clone()
114 .into_single_expression()
115 .return_type()
116 }
117}
118
119impl Display for Comprehension {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 let return_expression = self
122 .return_expression_submodel
123 .clone()
124 .into_single_expression();
125
126 let generator_symbols = self.generator_submodel.symbols().clone();
127 let generators = self
128 .quantified_vars
129 .iter()
130 .map(|name| {
131 let decl: DeclarationPtr = generator_symbols
132 .lookup_local(name)
133 .expect("quantified variable should be in the generator symbol table");
134 let domain: DomainPtr = decl.domain().unwrap();
135 format!("{name} : {domain}")
136 })
137 .collect_vec();
138
139 let guards = self
140 .generator_submodel
141 .constraints()
142 .iter()
143 .map(|x| format!("{x}"))
144 .collect_vec();
145
146 let generators_and_guards = generators.into_iter().chain(guards).join(", ");
147
148 write!(f, "[ {return_expression} | {generators_and_guards} ]")
149 }
150}
151
152#[derive(Clone, Debug, PartialEq, Eq)]
154pub struct ComprehensionBuilder {
155 guards: Vec<Expression>,
156 generator_symboltable: SymbolTablePtr,
160 return_expr_symboltable: SymbolTablePtr,
161 quantified_variables: BTreeSet<Name>,
162}
163
164impl ComprehensionBuilder {
165 pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
166 ComprehensionBuilder {
167 guards: vec![],
168 generator_symboltable: SymbolTablePtr::with_parent(symbol_table_ptr.clone()),
169 return_expr_symboltable: SymbolTablePtr::with_parent(symbol_table_ptr),
170 quantified_variables: BTreeSet::new(),
171 }
172 }
173
174 pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
176 self.generator_symboltable.clone()
177 }
178
179 pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
181 self.return_expr_symboltable.clone()
182 }
183
184 pub fn guard(mut self, guard: Expression) -> Self {
185 self.guards.push(guard);
186 self
187 }
188
189 pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
190 let name = declaration.name().clone();
191 let domain = declaration.domain().unwrap();
192 assert!(!self.quantified_variables.contains(&name));
193
194 self.quantified_variables.insert(name.clone());
195
196 let quantified_decl = DeclarationPtr::new_quantified(name, domain);
198 self.generator_symboltable
199 .write()
200 .insert(quantified_decl.clone());
201
202 self.return_expr_symboltable.write().insert(
204 DeclarationPtr::new_quantified_from_generator(&quantified_decl)
205 .expect("quantified variables should always have a domain"),
206 );
207
208 self
209 }
210
211 pub fn with_return_value(
219 self,
220 mut expression: Expression,
221 comprehension_kind: Option<ACOperatorKind>,
222 ) -> Comprehension {
223 let parent_symboltable = self.generator_symboltable.read().parent().clone().unwrap();
224
225 let mut generator_submodel = SubModel::new(parent_symboltable.clone());
226 let mut return_expression_submodel = SubModel::new(parent_symboltable);
227
228 *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
229 *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
230
231 let quantified_variables = self.quantified_variables;
234
235 let (mut quantified_guards, mut other_guards): (Vec<_>, Vec<_>) = self
237 .guards
238 .into_iter()
239 .partition(|x| is_quantified_guard(&quantified_variables, x));
240
241 let quantified_variables_2 = quantified_variables.clone();
242 let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
243
244 quantified_guards =
246 Biplate::<DeclarationPtr>::transform_bi(&quantified_guards, &move |decl| {
247 if quantified_variables_2.contains(&decl.name()) {
248 generator_symboltable_ptr
249 .read()
250 .lookup_local(&decl.name())
251 .unwrap()
252 } else {
253 decl
254 }
255 })
256 .into_iter()
257 .collect_vec();
258
259 let quantified_variables_2 = quantified_variables.clone();
260 let return_expr_symboltable_ptr =
261 return_expression_submodel.symbols_ptr_unchecked().clone();
262
263 other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
265 if quantified_variables_2.contains(&decl.name()) {
266 return_expr_symboltable_ptr
267 .read()
268 .lookup_local(&decl.name())
269 .unwrap()
270 } else {
271 decl
272 }
273 })
274 .into_iter()
275 .collect_vec();
276
277 if !other_guards.is_empty() {
279 let comprehension_kind = comprehension_kind.expect(
280 "if any guards reference decision variables, a comprehension kind should be given",
281 );
282
283 let guard_expr = match other_guards.as_slice() {
284 [x] => x.clone(),
285 xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
286 };
287
288 expression = match comprehension_kind {
289 ACOperatorKind::And => {
290 Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
291 }
292 ACOperatorKind::Or => Expression::And(
293 Metadata::new(),
294 Moo::new(Expression::And(
295 Metadata::new(),
296 Moo::new(matrix_expr![guard_expr, expression]),
297 )),
298 ),
299
300 ACOperatorKind::Sum => {
301 panic!("guards that reference decision variables not yet implemented for sum");
302 }
303
304 ACOperatorKind::Product => {
305 panic!(
306 "guards that reference decision variables not yet implemented for product"
307 );
308 }
309 }
310 }
311
312 generator_submodel.add_constraints(quantified_guards);
313
314 return_expression_submodel.add_constraint(expression);
315
316 Comprehension {
317 return_expression_submodel,
318 generator_submodel,
319 quantified_vars: quantified_variables.into_iter().collect_vec(),
320 }
321 }
322}
323
324fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
326 guard
327 .universe_bi()
328 .iter()
329 .all(|x| quantified_variables.contains(x))
330}