Skip to main content

conjure_cp_core/ast/
abstract_comprehension.rs

1use super::SymbolTable;
2use super::declaration::{DeclarationPtr, serde::DeclarationPtrFull};
3use super::serde::RcRefCellAsInner;
4use crate::ast::{DomainPtr, Expression, Name, ReturnType, SubModel, Typeable};
5use serde::{Deserialize, Serialize};
6use serde_with::serde_as;
7use std::collections::VecDeque;
8use std::fmt::{Display, Formatter};
9use std::{cell::RefCell, hash::Hash, hash::Hasher, rc::Rc};
10use uniplate::{Biplate, Tree, Uniplate};
11
/// A comprehension of the form `[ return_expr | q1, q2, ... ]` (the shape printed by its
/// `Display` impl).
///
/// [`AbstractComprehensionBuilder::with_return_value`] produces values of this type. `Hash` is
/// implemented by hand rather than derived.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct AbstractComprehension {
    /// The expression produced for each expansion of the comprehension.
    pub return_expr: Expression,
    /// Generators, conditions and lettings, in the order they were added (and are displayed).
    pub qualifiers: Vec<Qualifier>,

    /// The symbol table used in the return expression.
    ///
    /// Variables from generator expressions are "given" in the context of the return expression.
    /// That is, they are constants which are different for each expansion of the comprehension.
    #[serde_as(as = "RcRefCellAsInner")]
    pub return_expr_symbols: Rc<RefCell<SymbolTable>>,

    /// The scope for variables in generator expressions.
    ///
    /// Variables declared in generator expressions are decision variables, since they do not
    /// have a constant value.
    #[serde_as(as = "RcRefCellAsInner")]
    pub generator_symbols: Rc<RefCell<SymbolTable>>,
}
34
35// FIXME: remove this: https://github.com/conjure-cp/conjure-oxide/issues/1428
36impl Biplate<SymbolTable> for AbstractComprehension {
37    fn biplate(
38        &self,
39    ) -> (
40        uniplate::Tree<SymbolTable>,
41        Box<dyn Fn(uniplate::Tree<SymbolTable>) -> Self>,
42    ) {
43        let return_expr_symbols: SymbolTable = (*self.return_expr_symbols).borrow().clone();
44        let generator_symbols: SymbolTable = (*self.generator_symbols).borrow().clone();
45
46        let (tables_in_exprs_tree, tables_in_exprs_ctx) =
47            Biplate::<SymbolTable>::biplate(&Biplate::<Expression>::children_bi(self));
48
49        let tree = Tree::Many(VecDeque::from([
50            Tree::One(return_expr_symbols),
51            Tree::One(generator_symbols),
52            tables_in_exprs_tree,
53        ]));
54
55        let self2 = self.clone();
56        let ctx = Box::new(move |tree: Tree<SymbolTable>| {
57            let Tree::Many(vs) = tree else {
58                panic!();
59            };
60
61            let Tree::One(return_expr_symbols) = vs[0].clone() else {
62                panic!();
63            };
64
65            let Tree::One(generator_symbols) = vs[1].clone() else {
66                panic!();
67            };
68
69            let self3 = self2.with_children_bi(tables_in_exprs_ctx(vs[2].clone()));
70
71            // WARN: I can't remember if i should change inside the refcell here, or make an new
72            // one (resulting in this symbol table being detached).
73
74            *(self3.return_expr_symbols.borrow_mut()) = return_expr_symbols;
75            *(self3.generator_symbols.borrow_mut()) = generator_symbols;
76
77            self3
78        });
79
80        (tree, ctx)
81    }
82}
/// One clause to the right of the `|` in a comprehension.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub enum Qualifier {
    /// Introduces a variable that ranges over a domain or over the elements of an expression.
    Generator(Generator),
    /// A guard expression; the builder's `add_condition` requires a boolean return type.
    Condition(Expression),
    /// Binds a declaration to the value of an expression, displayed as `letting name = expr`.
    ComprehensionLetting(ComprehensionLetting),
}
91
/// A `letting name = expression` qualifier inside a comprehension.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ComprehensionLetting {
    /// The declaration being bound. The builder creates it with `gensym` in the
    /// return-expression scope.
    #[serde_as(as = "DeclarationPtrFull")]
    pub decl: DeclarationPtr,
    /// The expression whose value `decl` is bound to.
    pub expression: Expression,
}
101
/// The two kinds of generator a comprehension qualifier can use.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub enum Generator {
    /// `name : domain` — the variable ranges over a domain.
    DomainGenerator(DomainGenerator),
    /// `name <- expression` — the variable ranges over the elements of an expression.
    ExpressionGenerator(ExpressionGenerator),
}
109
/// A generator `name : domain`. Both the name and the domain are read from `decl`.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct DomainGenerator {
    /// Declaration of the generator variable. Its domain is the set of values it ranges over.
    #[serde_as(as = "DeclarationPtrFull")]
    pub decl: DeclarationPtr,
}
118
/// A generator `name <- expression`: the variable can be any element of `expression`.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ExpressionGenerator {
    /// Declaration of the generator variable as seen by generator expressions. The builder
    /// creates it as a decision variable and inserts it into `generator_symbols`.
    #[serde_as(as = "DeclarationPtrFull")]
    pub decl: DeclarationPtr,
    /// The expression whose elements the variable takes.
    pub expression: Expression,
}
128
129impl AbstractComprehension {
130    pub fn domain_of(&self) -> Option<DomainPtr> {
131        self.return_expr.domain_of()
132    }
133}
134
135impl Typeable for AbstractComprehension {
136    fn return_type(&self) -> ReturnType {
137        self.return_expr.return_type()
138    }
139}
140
141impl Hash for AbstractComprehension {
142    fn hash<H: Hasher>(&self, state: &mut H) {
143        (*self.return_expr_symbols).borrow().hash(state);
144        self.return_expr.hash(state);
145        self.qualifiers.hash(state);
146    }
147}
148
/// Incrementally constructs an [`AbstractComprehension`].
///
/// Create it with `new`, add qualifiers, then finish with `with_return_value`.
pub struct AbstractComprehensionBuilder {
    /// Qualifiers added so far, in insertion order.
    pub qualifiers: Vec<Qualifier>,

    /// The symbol table used in the return expression.
    ///
    /// Variables from generator expressions are "given" in the context of the return expression.
    /// That is, they are constants which are different for each expansion of the comprehension.
    pub return_expr_symbols: Rc<RefCell<SymbolTable>>,

    /// The scope for variables in generator expressions.
    ///
    /// Variables declared in generator expressions are decision variables in their original
    /// context, since they do not have a constant value.
    pub generator_symbols: Rc<RefCell<SymbolTable>>,
}
164
165impl AbstractComprehensionBuilder {
166    /// Creates an [AbstractComprehensionBuilder] with:
167    /// - An inner scope which inherits from the given symbol table
168    /// - An empty list of qualifiers
169    ///
170    /// Changes to the inner scope do not affect the given symbol table.
171    ///
172    /// The return expression is passed when finalizing the comprehension, in [with_return_value].
173    pub fn new(symbols: &Rc<RefCell<SymbolTable>>) -> Self {
174        Self {
175            qualifiers: vec![],
176            return_expr_symbols: Rc::new(RefCell::new(SymbolTable::with_parent(symbols.clone()))),
177            generator_symbols: Rc::new(RefCell::new(SymbolTable::with_parent(symbols.clone()))),
178        }
179    }
180
181    pub fn return_expr_symbols(&self) -> Rc<RefCell<SymbolTable>> {
182        self.return_expr_symbols.clone()
183    }
184
185    pub fn generator_symbols(&self) -> Rc<RefCell<SymbolTable>> {
186        self.generator_symbols.clone()
187    }
188
189    pub fn new_domain_generator(&mut self, domain: DomainPtr) -> DeclarationPtr {
190        let generator_decl = self.return_expr_symbols.borrow_mut().gensym(&domain);
191
192        self.qualifiers
193            .push(Qualifier::Generator(Generator::DomainGenerator(
194                DomainGenerator {
195                    decl: generator_decl.clone(),
196                },
197            )));
198
199        generator_decl
200    }
201
202    /// Creates a new expression generator with the given expression and variable name.
203    ///
204    /// The variable "takes from" the expression, that is, it can be any element in the expression.
205    ///
206    /// E.g. in `[ x | x <- some_set ]`, `x` can be any element of `some_set`.
207    pub fn new_expression_generator(mut self, expr: Expression, name: Name) -> Self {
208        let domain = expr
209            .domain_of()
210            .expect("Expression must have a domain")
211            .element_domain()
212            .expect("Expression must contain elements with uniform domain");
213
214        // The variable is given (a constant) in the return expression, and a decision var
215        // in the generator expression
216        let generator_ptr = DeclarationPtr::new_var(name, domain);
217        let return_expr_ptr = DeclarationPtr::new_given_quantified(&generator_ptr)
218            .expect("Return expression declaration must not be None");
219
220        self.return_expr_symbols
221            .borrow_mut()
222            .insert(return_expr_ptr);
223        self.generator_symbols
224            .borrow_mut()
225            .insert(generator_ptr.clone());
226
227        self.qualifiers
228            .push(Qualifier::Generator(Generator::ExpressionGenerator(
229                ExpressionGenerator {
230                    decl: generator_ptr,
231                    expression: expr,
232                },
233            )));
234
235        self
236    }
237
238    /// See [crate::ast::comprehension::ComprehensionBuilder::guard]
239    pub fn add_condition(&mut self, condition: Expression) {
240        if condition.return_type() != ReturnType::Bool {
241            panic!("Condition expression must have boolean return type");
242        }
243
244        self.qualifiers.push(Qualifier::Condition(condition));
245    }
246
247    pub fn new_letting(&mut self, expression: Expression) -> DeclarationPtr {
248        let letting_decl = self.return_expr_symbols.borrow_mut().gensym(
249            &expression
250                .domain_of()
251                .expect("Expression must have a domain"),
252        );
253
254        self.qualifiers
255            .push(Qualifier::ComprehensionLetting(ComprehensionLetting {
256                decl: letting_decl.clone(),
257                expression,
258            }));
259
260        letting_decl
261    }
262
263    // The lack of the generator_symboltable and return_expr_symboltable
264    // are explained bc 1. we dont have separate symboltables for each part
265    // 2. it is unclear why there would be a need to access each one uniquely
266
267    pub fn with_return_value(self, expression: Expression) -> AbstractComprehension {
268        AbstractComprehension {
269            return_expr: expression,
270            qualifiers: self.qualifiers,
271            return_expr_symbols: self.return_expr_symbols,
272            generator_symbols: self.generator_symbols,
273        }
274    }
275}
276
277impl Display for AbstractComprehension {
278    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279        write!(f, "[ {} | ", self.return_expr)?;
280        let mut first = true;
281        for qualifier in &self.qualifiers {
282            if !first {
283                write!(f, ", ")?;
284            }
285            first = false;
286            qualifier.fmt(f)?;
287        }
288        write!(f, " ]")
289    }
290}
291
292impl Display for Qualifier {
293    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
294        match self {
295            Qualifier::Generator(generator) => generator.fmt(f),
296            Qualifier::Condition(condition) => condition.fmt(f),
297            Qualifier::ComprehensionLetting(comp_letting) => {
298                let name = comp_letting.decl.name();
299                let expr = &comp_letting.expression;
300                write!(f, "letting {} = {}", name, expr)
301            }
302        }
303    }
304}
305
306impl Display for Generator {
307    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
308        match self {
309            Generator::DomainGenerator(DomainGenerator { decl }) => {
310                let name = decl.name();
311                let domain = decl.domain().unwrap();
312                write!(f, "{} : {}", name, domain)
313            }
314            Generator::ExpressionGenerator(ExpressionGenerator { decl, expression }) => {
315                let name = decl.name();
316                write!(f, "{} <- {}", name, expression)
317            }
318        }
319    }
320}