1use super::SymbolTable;
2use super::declaration::{DeclarationPtr, serde::DeclarationPtrFull};
3use super::serde::RcRefCellAsInner;
4use crate::ast::{DomainPtr, Expression, Name, ReturnType, SubModel, Typeable};
5use serde::{Deserialize, Serialize};
6use serde_with::serde_as;
7use std::collections::VecDeque;
8use std::fmt::{Display, Formatter};
9use std::{cell::RefCell, hash::Hash, hash::Hasher, rc::Rc};
10use uniplate::{Biplate, Tree, Uniplate};
11
/// A comprehension: a return expression together with an ordered list of qualifiers
/// (generators, conditions and lettings), plus the two symbol tables that scope them.
///
/// Its `Display` form is `[ <return_expr> | <qualifier>, <qualifier>, ... ]`.
/// Construct instances with [`AbstractComprehensionBuilder`].
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct AbstractComprehension {
    /// The expression produced by the comprehension (printed before the `|`).
    pub return_expr: Expression,
    /// Qualifiers in the order they were added by the builder; they are printed in this order.
    pub qualifiers: Vec<Qualifier>,

    /// Symbol table for the return-expression scope.
    ///
    /// The builder creates it as a child of the enclosing table. It receives domain-generator
    /// declarations, letting declarations, and the given-quantified counterparts of
    /// expression-generator variables. Note that `Clone` copies the `Rc`, so clones share
    /// this table; the derived `PartialEq` compares the tables by contents.
    #[serde_as(as = "RcRefCellAsInner")]
    pub return_expr_symbols: Rc<RefCell<SymbolTable>>,

    /// Symbol table for the generator scope.
    ///
    /// Also created by the builder as a child of the enclosing table. Expression-generator
    /// variables are inserted here. Like `return_expr_symbols`, it is shared between clones.
    #[serde_as(as = "RcRefCellAsInner")]
    pub generator_symbols: Rc<RefCell<SymbolTable>>,
}
34
/// Exposes every symbol table reachable from a comprehension to uniplate.
///
/// The tree has the shape
/// `Many([One(return_expr_symbols), One(generator_symbols), <tables inside child expressions>])`.
/// The returned context rebuilds a comprehension from a tree of exactly that shape.
///
/// NOTE(review): this is hand-written rather than derived, presumably because the tables
/// sit behind `Rc<RefCell<_>>` — confirm.
impl Biplate<SymbolTable> for AbstractComprehension {
    fn biplate(
        &self,
    ) -> (
        uniplate::Tree<SymbolTable>,
        Box<dyn Fn(uniplate::Tree<SymbolTable>) -> Self>,
    ) {
        // Snapshot the comprehension's own tables by value.
        let return_expr_symbols: SymbolTable = (*self.return_expr_symbols).borrow().clone();
        let generator_symbols: SymbolTable = (*self.generator_symbols).borrow().clone();

        // Tables nested inside the comprehension's child expressions (e.g. inner
        // comprehensions), together with the context that puts them back.
        let (tables_in_exprs_tree, tables_in_exprs_ctx) =
            Biplate::<SymbolTable>::biplate(&Biplate::<Expression>::children_bi(self));

        let tree = Tree::Many(VecDeque::from([
            Tree::One(return_expr_symbols),
            Tree::One(generator_symbols),
            tables_in_exprs_tree,
        ]));

        // NOTE(review): `clone()` copies the `Rc`s, so `self2` (and every value the context
        // returns) shares its two symbol tables with `self`. The assignments through
        // `borrow_mut()` below therefore also change the tables seen by the original
        // comprehension. Confirm this write-through is intended.
        let self2 = self.clone();
        let ctx = Box::new(move |tree: Tree<SymbolTable>| {
            // Panics if the tree does not have the shape produced above
            // (indexing `vs` also panics if it has fewer than three entries).
            let Tree::Many(vs) = tree else {
                panic!();
            };

            let Tree::One(return_expr_symbols) = vs[0].clone() else {
                panic!();
            };

            let Tree::One(generator_symbols) = vs[1].clone() else {
                panic!();
            };

            // First rebuild the child expressions with their (possibly replaced) nested tables...
            let self3 = self2.with_children_bi(tables_in_exprs_ctx(vs[2].clone()));

            // ...then overwrite the comprehension's own tables in place.
            *(self3.return_expr_symbols.borrow_mut()) = return_expr_symbols;
            *(self3.generator_symbols.borrow_mut()) = generator_symbols;

            self3
        });

        (tree, ctx)
    }
}
/// One clause to the right of the `|` in a comprehension.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub enum Qualifier {
    /// A generator over a domain or over the elements of an expression.
    Generator(Generator),
    /// A condition expression. `AbstractComprehensionBuilder::add_condition` only accepts
    /// expressions whose return type is `ReturnType::Bool`.
    Condition(Expression),
    /// A `letting` that binds a fresh declaration to an expression.
    ComprehensionLetting(ComprehensionLetting),
}
91
/// A `letting` qualifier, printed as `letting <name> = <expression>`.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ComprehensionLetting {
    /// The declaration introduced by the letting. The builder creates it with `gensym` in the
    /// return-expression symbol table, using the domain of `expression`.
    #[serde_as(as = "DeclarationPtrFull")]
    pub decl: DeclarationPtr,
    /// The expression bound to `decl`.
    pub expression: Expression,
}
101
/// A generator qualifier.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub enum Generator {
    /// A generator over a domain; printed as `<name> : <domain>`.
    DomainGenerator(DomainGenerator),
    /// A generator over the elements of an expression; printed as `<name> <- <expression>`.
    ExpressionGenerator(ExpressionGenerator),
}
109
/// A generator whose variable ranges over the domain stored on its declaration.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct DomainGenerator {
    /// The generator variable. `Display` reads its domain via `decl.domain()`. The builder
    /// creates it with `gensym` in the return-expression symbol table.
    #[serde_as(as = "DeclarationPtrFull")]
    pub decl: DeclarationPtr,
}
118
/// A generator whose variable ranges over the elements of an expression.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ExpressionGenerator {
    /// The generator variable. The builder declares it with the element domain of
    /// `expression` and inserts it into the generator symbol table.
    #[serde_as(as = "DeclarationPtrFull")]
    pub decl: DeclarationPtr,
    /// The expression whose elements are generated.
    pub expression: Expression,
}
128
129impl AbstractComprehension {
130 pub fn domain_of(&self) -> Option<DomainPtr> {
131 self.return_expr.domain_of()
132 }
133}
134
135impl Typeable for AbstractComprehension {
136 fn return_type(&self) -> ReturnType {
137 self.return_expr.return_type()
138 }
139}
140
141impl Hash for AbstractComprehension {
142 fn hash<H: Hasher>(&self, state: &mut H) {
143 (*self.return_expr_symbols).borrow().hash(state);
144 self.return_expr.hash(state);
145 self.qualifiers.hash(state);
146 }
147}
148
/// Incrementally builds an [`AbstractComprehension`].
///
/// Add qualifiers with the `new_*`/`add_*` methods, then finish with `with_return_value`.
pub struct AbstractComprehensionBuilder {
    /// Qualifiers added so far, in insertion order.
    pub qualifiers: Vec<Qualifier>,

    /// Symbol table for the return-expression scope; a child of the table given to `new`.
    pub return_expr_symbols: Rc<RefCell<SymbolTable>>,

    /// Symbol table for the generator scope; a separate child of the table given to `new`.
    pub generator_symbols: Rc<RefCell<SymbolTable>>,
}
164
165impl AbstractComprehensionBuilder {
166 pub fn new(symbols: &Rc<RefCell<SymbolTable>>) -> Self {
174 Self {
175 qualifiers: vec![],
176 return_expr_symbols: Rc::new(RefCell::new(SymbolTable::with_parent(symbols.clone()))),
177 generator_symbols: Rc::new(RefCell::new(SymbolTable::with_parent(symbols.clone()))),
178 }
179 }
180
181 pub fn return_expr_symbols(&self) -> Rc<RefCell<SymbolTable>> {
182 self.return_expr_symbols.clone()
183 }
184
185 pub fn generator_symbols(&self) -> Rc<RefCell<SymbolTable>> {
186 self.generator_symbols.clone()
187 }
188
189 pub fn new_domain_generator(&mut self, domain: DomainPtr) -> DeclarationPtr {
190 let generator_decl = self.return_expr_symbols.borrow_mut().gensym(&domain);
191
192 self.qualifiers
193 .push(Qualifier::Generator(Generator::DomainGenerator(
194 DomainGenerator {
195 decl: generator_decl.clone(),
196 },
197 )));
198
199 generator_decl
200 }
201
202 pub fn new_expression_generator(mut self, expr: Expression, name: Name) -> Self {
208 let domain = expr
209 .domain_of()
210 .expect("Expression must have a domain")
211 .element_domain()
212 .expect("Expression must contain elements with uniform domain");
213
214 let generator_ptr = DeclarationPtr::new_var(name, domain);
217 let return_expr_ptr = DeclarationPtr::new_given_quantified(&generator_ptr)
218 .expect("Return expression declaration must not be None");
219
220 self.return_expr_symbols
221 .borrow_mut()
222 .insert(return_expr_ptr);
223 self.generator_symbols
224 .borrow_mut()
225 .insert(generator_ptr.clone());
226
227 self.qualifiers
228 .push(Qualifier::Generator(Generator::ExpressionGenerator(
229 ExpressionGenerator {
230 decl: generator_ptr,
231 expression: expr,
232 },
233 )));
234
235 self
236 }
237
238 pub fn add_condition(&mut self, condition: Expression) {
240 if condition.return_type() != ReturnType::Bool {
241 panic!("Condition expression must have boolean return type");
242 }
243
244 self.qualifiers.push(Qualifier::Condition(condition));
245 }
246
247 pub fn new_letting(&mut self, expression: Expression) -> DeclarationPtr {
248 let letting_decl = self.return_expr_symbols.borrow_mut().gensym(
249 &expression
250 .domain_of()
251 .expect("Expression must have a domain"),
252 );
253
254 self.qualifiers
255 .push(Qualifier::ComprehensionLetting(ComprehensionLetting {
256 decl: letting_decl.clone(),
257 expression,
258 }));
259
260 letting_decl
261 }
262
263 pub fn with_return_value(self, expression: Expression) -> AbstractComprehension {
268 AbstractComprehension {
269 return_expr: expression,
270 qualifiers: self.qualifiers,
271 return_expr_symbols: self.return_expr_symbols,
272 generator_symbols: self.generator_symbols,
273 }
274 }
275}
276
277impl Display for AbstractComprehension {
278 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279 write!(f, "[ {} | ", self.return_expr)?;
280 let mut first = true;
281 for qualifier in &self.qualifiers {
282 if !first {
283 write!(f, ", ")?;
284 }
285 first = false;
286 qualifier.fmt(f)?;
287 }
288 write!(f, " ]")
289 }
290}
291
292impl Display for Qualifier {
293 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
294 match self {
295 Qualifier::Generator(generator) => generator.fmt(f),
296 Qualifier::Condition(condition) => condition.fmt(f),
297 Qualifier::ComprehensionLetting(comp_letting) => {
298 let name = comp_letting.decl.name();
299 let expr = &comp_letting.expression;
300 write!(f, "letting {} = {}", name, expr)
301 }
302 }
303 }
304}
305
306impl Display for Generator {
307 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
308 match self {
309 Generator::DomainGenerator(DomainGenerator { decl }) => {
310 let name = decl.name();
311 let domain = decl.domain().unwrap();
312 write!(f, "{} : {}", name, domain)
313 }
314 Generator::ExpressionGenerator(ExpressionGenerator { decl, expression }) => {
315 let name = decl.name();
316 write!(f, "{} <- {}", name, expression)
317 }
318 }
319 }
320}