1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{cell::RefCell, collections::BTreeSet, fmt::Display, rc::Rc, sync::atomic::AtomicBool};
4
5use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6use conjure_cp_core::ast::ReturnType;
7use itertools::Itertools as _;
8use serde::{Deserialize, Serialize};
9use uniplate::{Biplate, Uniplate};
10
11use super::{
12 DeclarationPtr, Domain, DomainPtr, Expression, Moo, Name, Range, SubModel, SymbolTable,
13 Typeable, ac_operators::ACOperatorKind,
14};
15
/// Global switch selecting whether comprehensions use the "optimised" rewriter.
///
/// Initialised to `false`. NOTE(review): nothing in this part of the file reads or writes the
/// flag — confirm where it is toggled and what the optimised path does differently.
pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);
22
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
/// A comprehension expression, displayed as `[return_expression | generators, guards]`.
///
/// The comprehension is stored as two submodels:
///
/// * the *generator* submodel, whose symbol table declares the induction variables and whose
///   constraints are the induction guards (guards that only mention induction variables);
/// * the *return expression* submodel, whose root holds the expression produced for each
///   assignment of the induction variables.
///
/// Instances are normally created with [`ComprehensionBuilder`].
#[biplate(to=SubModel)]
#[biplate(to=Expression)]
#[non_exhaustive]
pub struct Comprehension {
    /// Submodel holding the return expression as its root constraint(s). When built by
    /// `ComprehensionBuilder`, its symbol table contains a given for each induction variable.
    #[doc(hidden)]
    pub return_expression_submodel: SubModel,
    /// Submodel whose symbol table declares the generators and whose constraints are the
    /// induction guards.
    #[doc(hidden)]
    pub generator_submodel: SubModel,
    /// Names of the induction variables (in sorted order when built by
    /// `ComprehensionBuilder`, which collects them from a `BTreeSet`).
    #[doc(hidden)]
    pub induction_vars: Vec<Name>,
}
44
45impl Comprehension {
46 pub fn domain_of(&self) -> Option<DomainPtr> {
47 let return_expr_domain = self
48 .return_expression_submodel
49 .clone()
50 .into_single_expression()
51 .domain_of()?;
52
53 Some(Domain::matrix(
55 return_expr_domain,
56 vec![Domain::int(vec![Range::UnboundedR(1)])],
57 ))
58 }
59
60 pub fn return_expression(self) -> Expression {
61 self.return_expression_submodel.into_single_expression()
62 }
63
64 pub fn replace_return_expression(&mut self, new_expr: Expression) {
65 let new_expr = match new_expr {
66 Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
67 Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
68 }
69 expr => Expression::Root(Metadata::new(), vec![expr]),
70 };
71
72 *self.return_expression_submodel.root_mut_unchecked() = new_expr;
73 }
74
75 pub fn add_induction_guard(&mut self, guard: Expression) -> bool {
77 if self.is_induction_guard(&guard) {
78 self.generator_submodel.add_constraint(guard);
79 true
80 } else {
81 false
82 }
83 }
84
85 pub fn is_induction_guard(&self, expr: &Expression) -> bool {
87 is_induction_guard(&(self.induction_vars.clone().into_iter().collect()), expr)
88 }
89}
90
91impl Typeable for Comprehension {
92 fn return_type(&self) -> ReturnType {
93 self.return_expression_submodel
94 .clone()
95 .into_single_expression()
96 .return_type()
97 }
98}
99
100impl Display for Comprehension {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 let generators: String = self
103 .generator_submodel
104 .symbols()
105 .clone()
106 .into_iter_local()
107 .map(|(name, decl): (Name, DeclarationPtr)| {
108 let domain: DomainPtr = decl.domain().unwrap();
109 (name, domain)
110 })
111 .map(|(name, domain)| format!("{name}: {domain}"))
112 .join(",");
113
114 let guards = self
115 .generator_submodel
116 .constraints()
117 .iter()
118 .map(|x| format!("{x}"))
119 .join(",");
120
121 let generators_and_guards = itertools::join([generators, guards], ",");
122
123 let expression = &self.return_expression_submodel;
124 write!(f, "[{expression} | {generators_and_guards}]")
125 }
126}
127
#[derive(Clone, Debug, PartialEq, Eq)]
/// Incrementally builds a [`Comprehension`].
///
/// Add generators with `generator` and guards with `guard`, then finish with
/// `with_return_value`.
pub struct ComprehensionBuilder {
    /// Guards added so far. They are split into induction guards and other guards when the
    /// comprehension is built.
    guards: Vec<Expression>,
    /// Symbol table receiving generator declarations; a child of the table passed to `new`.
    generator_symboltable: Rc<RefCell<SymbolTable>>,
    /// Symbol table for the return expression. It receives a given per generator and is also
    /// a child of the table passed to `new`.
    return_expr_symboltable: Rc<RefCell<SymbolTable>>,
    /// Names of the generators (induction variables) added so far.
    induction_variables: BTreeSet<Name>,
}
139
140impl ComprehensionBuilder {
141 pub fn new(symbol_table_ptr: Rc<RefCell<SymbolTable>>) -> Self {
142 ComprehensionBuilder {
143 guards: vec![],
144 generator_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
145 symbol_table_ptr.clone(),
146 ))),
147 return_expr_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
148 symbol_table_ptr,
149 ))),
150 induction_variables: BTreeSet::new(),
151 }
152 }
153
154 pub fn generator_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
156 Rc::clone(&self.generator_symboltable)
157 }
158
159 pub fn return_expr_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
161 Rc::clone(&self.return_expr_symboltable)
162 }
163
164 pub fn guard(mut self, guard: Expression) -> Self {
165 self.guards.push(guard);
166 self
167 }
168
169 pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
170 let name = declaration.name().clone();
171 let domain = declaration.domain().unwrap();
172 assert!(!self.induction_variables.contains(&name));
173
174 self.induction_variables.insert(name.clone());
175
176 (*self.generator_symboltable)
178 .borrow_mut()
179 .insert(declaration);
180
181 (*self.return_expr_symboltable)
183 .borrow_mut()
184 .insert(DeclarationPtr::new_given(name, domain));
185
186 self
187 }
188
189 pub fn with_return_value(
197 self,
198 mut expression: Expression,
199 comprehension_kind: Option<ACOperatorKind>,
200 ) -> Comprehension {
201 let parent_symboltable = self
202 .generator_symboltable
203 .as_ref()
204 .borrow_mut()
205 .parent_mut_unchecked()
206 .clone()
207 .unwrap();
208 let mut generator_submodel = SubModel::new(parent_symboltable.clone());
209 let mut return_expression_submodel = SubModel::new(parent_symboltable);
210
211 *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
212 *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
213
214 let induction_variables = self.induction_variables;
217
218 let (mut induction_guards, mut other_guards): (Vec<_>, Vec<_>) = self
220 .guards
221 .into_iter()
222 .partition(|x| is_induction_guard(&induction_variables, x));
223
224 let induction_variables_2 = induction_variables.clone();
225 let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
226
227 induction_guards =
229 Biplate::<DeclarationPtr>::transform_bi(&induction_guards, &move |decl| {
230 if induction_variables_2.contains(&decl.name()) {
231 (*generator_symboltable_ptr)
232 .borrow()
233 .lookup_local(&decl.name())
234 .unwrap()
235 } else {
236 decl
237 }
238 })
239 .into_iter()
240 .collect_vec();
241
242 let induction_variables_2 = induction_variables.clone();
243 let return_expr_symboltable_ptr =
244 return_expression_submodel.symbols_ptr_unchecked().clone();
245
246 other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
248 if induction_variables_2.contains(&decl.name()) {
249 (*return_expr_symboltable_ptr)
250 .borrow()
251 .lookup_local(&decl.name())
252 .unwrap()
253 } else {
254 decl
255 }
256 })
257 .into_iter()
258 .collect_vec();
259
260 if !other_guards.is_empty() {
262 let comprehension_kind = comprehension_kind.expect(
263 "if any guards reference decision variables, a comprehension kind should be given",
264 );
265
266 let guard_expr = match other_guards.as_slice() {
267 [x] => x.clone(),
268 xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
269 };
270
271 expression = match comprehension_kind {
272 ACOperatorKind::And => {
273 Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
274 }
275 ACOperatorKind::Or => Expression::And(
276 Metadata::new(),
277 Moo::new(Expression::And(
278 Metadata::new(),
279 Moo::new(matrix_expr![guard_expr, expression]),
280 )),
281 ),
282
283 ACOperatorKind::Sum => {
284 panic!("guards that reference decision variables not yet implemented for sum");
285 }
286
287 ACOperatorKind::Product => {
288 panic!(
289 "guards that reference decision variables not yet implemented for product"
290 );
291 }
292 }
293 }
294
295 generator_submodel.add_constraints(induction_guards);
296
297 return_expression_submodel.add_constraint(expression);
298
299 Comprehension {
300 return_expression_submodel,
301 generator_submodel,
302 induction_vars: induction_variables.into_iter().collect_vec(),
303 }
304 }
305}
306
307fn is_induction_guard(induction_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
309 guard
310 .universe_bi()
311 .iter()
312 .all(|x| induction_variables.contains(x))
313}