1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{collections::BTreeSet, fmt::Display};
4
5use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6use conjure_cp_core::ast::ReturnType;
7use itertools::Itertools as _;
8use parking_lot::RwLockReadGuard;
9use serde::{Deserialize, Serialize};
10use serde_with::serde_as;
11use uniplate::{Biplate, Uniplate};
12
13use super::{
14 DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15 SymbolTablePtr, Typeable,
16 ac_operators::ACOperatorKind,
17 categories::{Category, CategoryOf},
18 serde::{AsId, PtrAsInner},
19};
20
/// One qualifier of a [`Comprehension`]: either a generator that introduces a quantified
/// variable, or a guard condition.
///
/// [`ComprehensionBuilder`] stores qualifiers in the order they were added.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=Name)]
#[biplate(to=DeclarationPtr)]
pub enum ComprehensionQualifier {
    /// A generator whose quantified variable is bound to an expression, displayed as
    /// `name <- expr`. `ptr` is the quantified-expression declaration
    /// (see `DeclarationPtr::new_quantified_expr`), serialized by id.
    ExpressionGenerator {
        #[serde_as(as = "AsId")]
        ptr: DeclarationPtr,
    },
    /// A generator whose quantified variable ranges over a domain, displayed as
    /// `name : domain`. `ptr` is the quantified declaration (see
    /// `DeclarationPtr::new_quantified`), serialized by id.
    Generator {
        #[serde_as(as = "AsId")]
        ptr: DeclarationPtr,
    },
    /// A guard expression, displayed as-is.
    Condition(Expression),
}
37
/// A comprehension, displayed as `[ return_expression | qualifier, qualifier, ... ]`.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate obtains one from
/// [`ComprehensionBuilder::with_return_value`].
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=Expression)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[non_exhaustive]
pub struct Comprehension {
    /// The expression on the left of the `|`.
    pub return_expression: Expression,
    /// Generators and guards on the right of the `|`.
    pub qualifiers: Vec<ComprehensionQualifier>,
    // Local symbol table holding the quantified variables' declarations; the builder
    // creates it as a child of the enclosing scope. Serialized by its inner value.
    #[doc(hidden)]
    #[serde_as(as = "PtrAsInner")]
    pub symbols: SymbolTablePtr,
}
52
53impl Comprehension {
54 pub fn domain_of(&self) -> Option<DomainPtr> {
55 let return_expr_domain = self.return_expression.domain_of()?;
56
57 Some(Domain::matrix(
59 return_expr_domain,
60 vec![Domain::int(vec![Range::UnboundedR(1)])],
61 ))
62 }
63
64 pub fn return_expression(self) -> Expression {
65 self.return_expression
66 }
67
68 pub fn replace_return_expression(&mut self, new_expr: Expression) {
69 self.return_expression = new_expr;
70 }
71
72 pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
73 self.symbols.read()
74 }
75
76 pub fn quantified_vars(&self) -> Vec<Name> {
77 self.qualifiers
78 .iter()
79 .filter_map(|q| match q {
80 ComprehensionQualifier::ExpressionGenerator { ptr } => Some(ptr.name().clone()),
81 ComprehensionQualifier::Generator { ptr } => Some(ptr.name().clone()),
82 ComprehensionQualifier::Condition(_) => None,
83 })
84 .collect()
85 }
86
87 pub fn generator_conditions(&self) -> Vec<Expression> {
88 self.qualifiers
89 .iter()
90 .filter_map(|q| match q {
91 ComprehensionQualifier::Condition(c) => Some(c.clone()),
92 ComprehensionQualifier::Generator { .. } => None,
93 ComprehensionQualifier::ExpressionGenerator { .. } => None,
94 })
95 .collect()
96 }
97
98 pub fn to_generator_model(&self) -> Model {
100 let mut model = self.empty_model_with_symbols();
101 model.add_constraints(self.generator_conditions());
102 model
103 }
104
105 pub fn to_return_expression_model(&self) -> Model {
107 let mut model = self.empty_model_with_symbols();
108 model.add_constraint(self.return_expression.clone());
109 model
110 }
111
112 fn empty_model_with_symbols(&self) -> Model {
113 let parent = self.symbols.read().parent().clone();
114 let mut model = if let Some(parent) = parent {
115 Model::new_in_parent_scope(parent)
116 } else {
117 Model::default()
118 };
119 *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
120 model
121 }
122
123 pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
127 if self.is_quantified_guard(&guard) {
128 self.qualifiers
129 .push(ComprehensionQualifier::Condition(guard));
130 true
131 } else {
132 false
133 }
134 }
135
136 pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
138 let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
139 is_quantified_guard(&self.symbols.read(), &quantified, expr)
140 }
141}
142
143impl Typeable for Comprehension {
144 fn return_type(&self) -> ReturnType {
145 self.return_expression.return_type()
146 }
147}
148
149impl Display for Comprehension {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 let generators_and_guards = self
152 .qualifiers
153 .iter()
154 .map(|qualifier| match qualifier {
155 ComprehensionQualifier::Generator { ptr } => {
156 let domain = ptr.domain().expect("generator declaration has domain");
157 format!("{} : {domain}", ptr.name())
158 }
159 ComprehensionQualifier::ExpressionGenerator { ptr } => {
160 let name = ptr.name();
161 if let Some(expr) = ptr.as_quantified_expr() {
162 format!("{name} <- {expr}")
163 } else {
164 panic!("Oh nein! Dat is nicht gut!")
165 }
166 }
167 ComprehensionQualifier::Condition(expr) => format!("{expr}"),
168 })
169 .join(", ");
170
171 write!(
172 f,
173 "[ {} | {generators_and_guards} ]",
174 self.return_expression
175 )
176 }
177}
178
/// Incrementally constructs a [`Comprehension`].
///
/// Add generators and guards in order, then finish with
/// [`ComprehensionBuilder::with_return_value`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    /// Generators and guards, in the order they were added.
    qualifiers: Vec<ComprehensionQualifier>,
    /// The comprehension's local scope, created by `new` as a child of the scope passed in.
    /// Quantified declarations are inserted here.
    symbols: SymbolTablePtr,
    /// Names of the quantified variables declared so far. Used to reject duplicates and to
    /// classify guards.
    quantified_variables: BTreeSet<Name>,
}
187
188impl ComprehensionBuilder {
189 pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
190 ComprehensionBuilder {
191 qualifiers: vec![],
192 symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
193 quantified_variables: BTreeSet::new(),
194 }
195 }
196
197 pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
199 self.symbols.clone()
200 }
201
202 pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
204 self.symbols.clone()
205 }
206
207 pub fn guard(mut self, guard: Expression) -> Self {
208 self.qualifiers
209 .push(ComprehensionQualifier::Condition(guard));
210 self
211 }
212
213 pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
214 let name = declaration.name().clone();
215 assert!(!self.quantified_variables.contains(&name));
216
217 self.quantified_variables.insert(name.clone());
218
219 let quantified_decl = DeclarationPtr::new_quantified(name, declaration.domain().unwrap());
221 self.symbols.write().insert(quantified_decl.clone());
222
223 self.qualifiers.push(ComprehensionQualifier::Generator {
224 ptr: quantified_decl,
225 });
226
227 self
228 }
229
230 pub fn expression_generator(mut self, name: Name, expr: Expression) -> Self {
231 assert!(!self.quantified_variables.contains(&name));
232
233 self.quantified_variables.insert(name.clone());
234
235 let quantified_decl = DeclarationPtr::new_quantified_expr(name, expr);
237 self.symbols.write().insert(quantified_decl.clone());
238
239 self.qualifiers
240 .push(ComprehensionQualifier::ExpressionGenerator {
241 ptr: quantified_decl,
242 });
243
244 self
245 }
246
247 pub fn with_return_value(
255 self,
256 mut expression: Expression,
257 comprehension_kind: Option<ACOperatorKind>,
258 ) -> Comprehension {
259 let quantified_variables = self.quantified_variables;
260 let symbols = self.symbols.read();
261
262 let mut qualifiers = Vec::new();
263 let mut other_guards = Vec::new();
264
265 for qualifier in self.qualifiers {
266 match qualifier {
267 ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
268 ComprehensionQualifier::ExpressionGenerator { .. } => qualifiers.push(qualifier),
269 ComprehensionQualifier::Condition(condition) => {
270 if is_quantified_guard(&symbols, &quantified_variables, &condition) {
271 qualifiers.push(ComprehensionQualifier::Condition(condition));
272 } else {
273 other_guards.push(condition);
274 }
275 }
276 }
277 }
278 drop(symbols);
279
280 if !other_guards.is_empty() {
282 let comprehension_kind = comprehension_kind.expect(
283 "if any guards reference decision variables, a comprehension kind should be given",
284 );
285
286 let guard_expr = match other_guards.as_slice() {
287 [x] => x.clone(),
288 xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
289 };
290
291 expression = match comprehension_kind {
292 ACOperatorKind::And => {
293 Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
294 }
295 ACOperatorKind::Or => Expression::And(
296 Metadata::new(),
297 Moo::new(matrix_expr![guard_expr, expression]),
298 ),
299
300 ACOperatorKind::Sum => {
301 panic!("guards that reference decision variables not yet implemented for sum");
302 }
303
304 ACOperatorKind::Product => {
305 panic!(
306 "guards that reference decision variables not yet implemented for product"
307 );
308 }
309 }
310 }
311
312 Comprehension {
313 return_expression: expression,
314 qualifiers,
315 symbols: self.symbols,
316 }
317 }
318}
319
320fn is_quantified_guard(
322 symbols: &SymbolTable,
323 quantified_variables: &BTreeSet<Name>,
324 guard: &Expression,
325) -> bool {
326 guard.universe_bi().iter().all(|name| {
327 quantified_variables.contains(name)
328 || symbols
329 .lookup(name)
330 .is_some_and(|decl| decl.category_of() != Category::Decision)
331 })
332}