1use super::declaration::DeclarationPtr;
2use super::serde::PtrAsInner;
3use super::{DomainPtr, Expression, Name, ReturnType, SubModel, SymbolTablePtr, Typeable};
4use serde::{Deserialize, Serialize};
5use serde_with::serde_as;
6use std::fmt::{Display, Formatter};
7use std::hash::Hash;
8use uniplate::Uniplate;
9
/// An abstract comprehension: a return expression plus an ordered list of qualifiers
/// (generators, conditions and lettings).
///
/// Displayed as `[ <return_expr> | <qualifier>, <qualifier>, ... ]`.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
#[biplate(to=SymbolTablePtr)]
pub struct AbstractComprehension {
    /// The expression the comprehension returns.
    pub return_expr: Expression,
    /// The qualifiers, in order (`AbstractComprehensionBuilder` appends them in call order).
    pub qualifiers: Vec<Qualifier>,

    /// Symbol table associated with the return expression.
    ///
    /// When built with `AbstractComprehensionBuilder`, this is a child of the enclosing table.
    /// It receives the declarations of domain generators and lettings, plus the declarations
    /// derived from expression-generator variables.
    #[serde_as(as = "PtrAsInner")]
    pub return_expr_symbols: SymbolTablePtr,

    /// Symbol table associated with the generators.
    ///
    /// When built with `AbstractComprehensionBuilder`, this is a child of the enclosing table
    /// and receives the declarations of expression-generator variables.
    #[serde_as(as = "PtrAsInner")]
    pub generator_symbols: SymbolTablePtr,
}
33
/// One clause to the right of the `|` in a comprehension.
///
/// NOTE(review): unlike `Generator` and `ComprehensionLetting`, this enum does not derive
/// `Uniplate`; confirm how traversal into qualifiers is provided (possibly a manual impl
/// elsewhere).
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
pub enum Qualifier {
    /// Introduces a variable whose values come from a domain or an expression.
    Generator(Generator),
    /// A filter expression; `AbstractComprehensionBuilder::add_condition` panics unless its
    /// return type is `ReturnType::Bool`.
    Condition(Expression),
    /// A local binding, displayed as `letting <name> = <expression>`.
    ComprehensionLetting(ComprehensionLetting),
}
40
/// A `letting` qualifier that binds the declaration `decl` to `expression`.
///
/// Displayed as `letting <name> = <expression>`.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ComprehensionLetting {
    /// The declaration introduced by this letting.
    #[serde_as(as = "PtrAsInner")]
    pub decl: DeclarationPtr,
    /// The expression bound to `decl`.
    pub expression: Expression,
}
50
/// The two kinds of generator qualifier.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub enum Generator {
    /// Displayed as `<name> : <domain>`; the domain is taken from the declaration.
    DomainGenerator(DomainGenerator),
    /// Displayed as `<name> <- <expression>`.
    ExpressionGenerator(ExpressionGenerator),
}
58
/// A generator whose variable ranges over the domain stored in its declaration.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct DomainGenerator {
    /// The generator variable; its `domain()` is what `Display` prints after the `:`.
    #[serde_as(as = "PtrAsInner")]
    pub decl: DeclarationPtr,
}
67
/// A generator whose variable is drawn from `expression`.
///
/// `AbstractComprehensionBuilder::new_expression_generator` gives the variable the element
/// domain of the expression's domain.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ExpressionGenerator {
    /// The generator variable.
    #[serde_as(as = "PtrAsInner")]
    pub decl: DeclarationPtr,
    /// The expression the variable is drawn from.
    pub expression: Expression,
}
77
78impl AbstractComprehension {
79 pub fn domain_of(&self) -> Option<DomainPtr> {
80 self.return_expr.domain_of()
81 }
82}
83
84impl Typeable for AbstractComprehension {
85 fn return_type(&self) -> ReturnType {
86 self.return_expr.return_type()
87 }
88}
89
/// Incrementally builds an [`AbstractComprehension`].
///
/// Qualifiers are appended in call order. [`AbstractComprehensionBuilder::with_return_value`]
/// consumes the builder and produces the comprehension.
pub struct AbstractComprehensionBuilder {
    /// Qualifiers added so far, in call order.
    pub qualifiers: Vec<Qualifier>,

    /// Becomes [`AbstractComprehension::return_expr_symbols`].
    pub return_expr_symbols: SymbolTablePtr,

    /// Becomes [`AbstractComprehension::generator_symbols`].
    pub generator_symbols: SymbolTablePtr,
}
105
106impl AbstractComprehensionBuilder {
107 pub fn new(symbols: &SymbolTablePtr) -> Self {
115 Self {
116 qualifiers: vec![],
117 return_expr_symbols: SymbolTablePtr::with_parent(symbols.clone()),
118 generator_symbols: SymbolTablePtr::with_parent(symbols.clone()),
119 }
120 }
121
122 pub fn return_expr_symbols(&self) -> SymbolTablePtr {
123 self.return_expr_symbols.clone()
124 }
125
126 pub fn generator_symbols(&self) -> SymbolTablePtr {
127 self.generator_symbols.clone()
128 }
129
130 pub fn new_domain_generator(&mut self, domain: DomainPtr) -> DeclarationPtr {
131 let generator_decl = self.return_expr_symbols.write().gensym(&domain);
132
133 self.qualifiers
134 .push(Qualifier::Generator(Generator::DomainGenerator(
135 DomainGenerator {
136 decl: generator_decl.clone(),
137 },
138 )));
139
140 generator_decl
141 }
142
143 pub fn new_expression_generator(mut self, expr: Expression, name: Name) -> Self {
149 let domain = expr
150 .domain_of()
151 .expect("Expression must have a domain")
152 .element_domain()
153 .expect("Expression must contain elements with uniform domain");
154
155 let generator_ptr = DeclarationPtr::new_quantified(name, domain);
157 let return_expr_ptr = DeclarationPtr::new_quantified_from_generator(&generator_ptr)
158 .expect("Return expression declaration must not be None");
159
160 self.return_expr_symbols.write().insert(return_expr_ptr);
161 self.generator_symbols.write().insert(generator_ptr.clone());
162
163 self.qualifiers
164 .push(Qualifier::Generator(Generator::ExpressionGenerator(
165 ExpressionGenerator {
166 decl: generator_ptr,
167 expression: expr,
168 },
169 )));
170
171 self
172 }
173
174 pub fn add_condition(&mut self, condition: Expression) {
176 if condition.return_type() != ReturnType::Bool {
177 panic!("Condition expression must have boolean return type");
178 }
179
180 self.qualifiers.push(Qualifier::Condition(condition));
181 }
182
183 pub fn new_letting(&mut self, expression: Expression) -> DeclarationPtr {
184 let letting_decl = self.return_expr_symbols.write().gensym(
185 &expression
186 .domain_of()
187 .expect("Expression must have a domain"),
188 );
189
190 self.qualifiers
191 .push(Qualifier::ComprehensionLetting(ComprehensionLetting {
192 decl: letting_decl.clone(),
193 expression,
194 }));
195
196 letting_decl
197 }
198
199 pub fn with_return_value(self, expression: Expression) -> AbstractComprehension {
204 AbstractComprehension {
205 return_expr: expression,
206 qualifiers: self.qualifiers,
207 return_expr_symbols: self.return_expr_symbols,
208 generator_symbols: self.generator_symbols,
209 }
210 }
211}
212
213impl Display for AbstractComprehension {
214 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
215 write!(f, "[ {} | ", self.return_expr)?;
216 let mut first = true;
217 for qualifier in &self.qualifiers {
218 if !first {
219 write!(f, ", ")?;
220 }
221 first = false;
222 qualifier.fmt(f)?;
223 }
224 write!(f, " ]")
225 }
226}
227
228impl Display for Qualifier {
229 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
230 match self {
231 Qualifier::Generator(generator) => generator.fmt(f),
232 Qualifier::Condition(condition) => condition.fmt(f),
233 Qualifier::ComprehensionLetting(comp_letting) => {
234 let name = comp_letting.decl.name();
235 let expr = &comp_letting.expression;
236 write!(f, "letting {} = {}", name, expr)
237 }
238 }
239 }
240}
241
242impl Display for Generator {
243 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
244 match self {
245 Generator::DomainGenerator(DomainGenerator { decl }) => {
246 let name = decl.name();
247 let domain = decl.domain().unwrap();
248 write!(f, "{} : {}", name, domain)
249 }
250 Generator::ExpressionGenerator(ExpressionGenerator { decl, expression }) => {
251 let name = decl.name();
252 write!(f, "{} <- {}", name, expression)
253 }
254 }
255 }
256}