1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{collections::BTreeSet, fmt::Display};
4
5use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6use conjure_cp_core::ast::ReturnType;
7use itertools::Itertools as _;
8use parking_lot::RwLockReadGuard;
9use serde::{Deserialize, Serialize};
10use serde_with::serde_as;
11use uniplate::{Biplate, Uniplate};
12
13use super::{
14 DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15 SymbolTablePtr, Typeable,
16 ac_operators::ACOperatorKind,
17 serde::{AsId, PtrAsInner},
18};
19
/// One qualifier of a [`Comprehension`]; qualifiers are kept in order in
/// `Comprehension::qualifiers`.
///
/// A qualifier is either a generator, which binds a quantified variable, or a condition
/// (guard) expression.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=Name)]
#[biplate(to=DeclarationPtr)]
pub enum ComprehensionQualifier {
    /// Binds the quantified variable declared by `ptr`; the variable's name and domain are
    /// read from that declaration.
    Generator {
        /// Serialized through the `AsId` adapter rather than as an inline declaration.
        #[serde_as(as = "AsId")]
        ptr: DeclarationPtr,
    },
    /// A guard expression over the comprehension's variables.
    Condition(Expression),
}
32
/// A comprehension `[ return_expression | qualifiers ]`.
///
/// Instances are produced by [`ComprehensionBuilder::with_return_value`]; the struct is
/// `#[non_exhaustive]`, so it cannot be built with a struct literal outside this crate.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=Expression)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[non_exhaustive]
pub struct Comprehension {
    /// The expression produced for each combination of generator values.
    pub return_expression: Expression,
    /// Generators and guards, in order.
    pub qualifiers: Vec<ComprehensionQualifier>,
    /// The comprehension's local symbol table; the builder inserts the generator
    /// (quantified) declarations here. Serialized through the `PtrAsInner` adapter.
    #[doc(hidden)]
    #[serde_as(as = "PtrAsInner")]
    pub symbols: SymbolTablePtr,
}
47
48impl Comprehension {
49 pub fn domain_of(&self) -> Option<DomainPtr> {
50 let return_expr_domain = self.return_expression.domain_of()?;
51
52 Some(Domain::matrix(
54 return_expr_domain,
55 vec![Domain::int(vec![Range::UnboundedR(1)])],
56 ))
57 }
58
59 pub fn return_expression(self) -> Expression {
60 self.return_expression
61 }
62
63 pub fn replace_return_expression(&mut self, new_expr: Expression) {
64 self.return_expression = new_expr;
65 }
66
67 pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
68 self.symbols.read()
69 }
70
71 pub fn quantified_vars(&self) -> Vec<Name> {
72 self.qualifiers
73 .iter()
74 .filter_map(|q| match q {
75 ComprehensionQualifier::Generator { ptr } => Some(ptr.name().clone()),
76 ComprehensionQualifier::Condition(_) => None,
77 })
78 .collect()
79 }
80
81 pub fn generator_conditions(&self) -> Vec<Expression> {
82 self.qualifiers
83 .iter()
84 .filter_map(|q| match q {
85 ComprehensionQualifier::Condition(c) => Some(c.clone()),
86 ComprehensionQualifier::Generator { .. } => None,
87 })
88 .collect()
89 }
90
91 pub fn to_generator_model(&self) -> Model {
93 let mut model = self.empty_model_with_symbols();
94 model.add_constraints(self.generator_conditions());
95 model
96 }
97
98 pub fn to_return_expression_model(&self) -> Model {
100 let mut model = self.empty_model_with_symbols();
101 model.add_constraint(self.return_expression.clone());
102 model
103 }
104
105 fn empty_model_with_symbols(&self) -> Model {
106 let parent = self.symbols.read().parent().clone();
107 let mut model = if let Some(parent) = parent {
108 Model::new_in_parent_scope(parent)
109 } else {
110 Model::default()
111 };
112 *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
113 model
114 }
115
116 pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
118 if self.is_quantified_guard(&guard) {
119 self.qualifiers
120 .push(ComprehensionQualifier::Condition(guard));
121 true
122 } else {
123 false
124 }
125 }
126
127 pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
129 let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
130 is_quantified_guard(&quantified, expr)
131 }
132}
133
134impl Typeable for Comprehension {
135 fn return_type(&self) -> ReturnType {
136 self.return_expression.return_type()
137 }
138}
139
140impl Display for Comprehension {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 let generators_and_guards = self
143 .qualifiers
144 .iter()
145 .map(|qualifier| match qualifier {
146 ComprehensionQualifier::Generator { ptr } => {
147 let domain = ptr.domain().expect("generator declaration has domain");
148 format!("{} : {domain}", ptr.name())
149 }
150 ComprehensionQualifier::Condition(expr) => format!("{expr}"),
151 })
152 .join(", ");
153
154 write!(
155 f,
156 "[ {} | {generators_and_guards} ]",
157 self.return_expression
158 )
159 }
160}
161
/// Incrementally collects generators and guards, then produces a [`Comprehension`] via
/// [`ComprehensionBuilder::with_return_value`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    /// Generators and guards in the order they were added.
    qualifiers: Vec<ComprehensionQualifier>,
    /// The comprehension's local symbol table, created as a child of the table passed to
    /// `new`; generator declarations are inserted here.
    symbols: SymbolTablePtr,
    /// Names bound by generators so far; used to reject duplicate generators and to decide
    /// which guards may remain qualifiers.
    quantified_variables: BTreeSet<Name>,
}
170
171impl ComprehensionBuilder {
172 pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
173 ComprehensionBuilder {
174 qualifiers: vec![],
175 symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
176 quantified_variables: BTreeSet::new(),
177 }
178 }
179
180 pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
182 self.symbols.clone()
183 }
184
185 pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
187 self.symbols.clone()
188 }
189
190 pub fn guard(mut self, guard: Expression) -> Self {
191 self.qualifiers
192 .push(ComprehensionQualifier::Condition(guard));
193 self
194 }
195
196 pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
197 let name = declaration.name().clone();
198 assert!(!self.quantified_variables.contains(&name));
199
200 self.quantified_variables.insert(name.clone());
201
202 let quantified_decl = DeclarationPtr::new_quantified(name, declaration.domain().unwrap());
204 self.symbols.write().insert(quantified_decl.clone());
205
206 self.qualifiers.push(ComprehensionQualifier::Generator {
207 ptr: quantified_decl,
208 });
209
210 self
211 }
212
213 pub fn with_return_value(
221 self,
222 mut expression: Expression,
223 comprehension_kind: Option<ACOperatorKind>,
224 ) -> Comprehension {
225 let quantified_variables = self.quantified_variables;
226
227 let mut qualifiers = Vec::new();
228 let mut other_guards = Vec::new();
229
230 for qualifier in self.qualifiers {
231 match qualifier {
232 ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
233 ComprehensionQualifier::Condition(condition) => {
234 if is_quantified_guard(&quantified_variables, &condition) {
235 qualifiers.push(ComprehensionQualifier::Condition(condition));
236 } else {
237 other_guards.push(condition);
238 }
239 }
240 }
241 }
242
243 if !other_guards.is_empty() {
245 let comprehension_kind = comprehension_kind.expect(
246 "if any guards reference decision variables, a comprehension kind should be given",
247 );
248
249 let guard_expr = match other_guards.as_slice() {
250 [x] => x.clone(),
251 xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
252 };
253
254 expression = match comprehension_kind {
255 ACOperatorKind::And => {
256 Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
257 }
258 ACOperatorKind::Or => Expression::And(
259 Metadata::new(),
260 Moo::new(matrix_expr![guard_expr, expression]),
261 ),
262
263 ACOperatorKind::Sum => {
264 panic!("guards that reference decision variables not yet implemented for sum");
265 }
266
267 ACOperatorKind::Product => {
268 panic!(
269 "guards that reference decision variables not yet implemented for product"
270 );
271 }
272 }
273 }
274
275 Comprehension {
276 return_expression: expression,
277 qualifiers,
278 symbols: self.symbols,
279 }
280 }
281}
282
283fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
285 guard
286 .universe_bi()
287 .iter()
288 .all(|x| quantified_variables.contains(x))
289}