1
use super::declaration::DeclarationPtr;
2
use super::serde::PtrAsInner;
3
use super::{DomainPtr, Expression, Name, ReturnType, SubModel, SymbolTablePtr, Typeable};
4
use serde::{Deserialize, Serialize};
5
use serde_with::serde_as;
6
use std::fmt::{Display, Formatter};
7
use std::hash::Hash;
8
use uniplate::Uniplate;
9

            
10
/// A comprehension in abstract form: a return expression together with the qualifiers
/// (generators, conditions and lettings) to the right of the `|`.
///
/// Its `Display` impl renders it as `[ return_expr | qualifier, qualifier, ... ]`.
/// Values of this type are produced by [`AbstractComprehensionBuilder::with_return_value`].
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
#[biplate(to=SymbolTablePtr)]
pub struct AbstractComprehension {
    /// The expression produced for each expansion of the comprehension.
    pub return_expr: Expression,
    /// The qualifiers to the right of the `|`; `Display` renders them in this order,
    /// comma-separated.
    pub qualifiers: Vec<Qualifier>,

    /// The symbol table used in the return expression.
    ///
    /// Variables from generator expressions are "given" in the context of the return expression.
    /// That is, they are constants which are different for each expansion of the comprehension.
    #[serde_as(as = "PtrAsInner")]
    pub return_expr_symbols: SymbolTablePtr,

    /// The scope for variables in generator expressions.
    ///
    /// Variables declared in generator expressions are decision variables in their original
    /// context, since they do not have a constant value.
    #[serde_as(as = "PtrAsInner")]
    pub generator_symbols: SymbolTablePtr,
}
33

            
34
/// One clause to the right of the `|` in a comprehension.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
pub enum Qualifier {
    /// Introduces a variable, displayed as `x : domain` or `x <- expression`.
    Generator(Generator),
    /// A guard expression; `AbstractComprehensionBuilder::add_condition` only accepts
    /// expressions whose return type is `ReturnType::Bool`.
    Condition(Expression),
    /// A local `letting`, displayed as `letting name = expression`.
    ComprehensionLetting(ComprehensionLetting),
}
40

            
41
/// A `letting` qualifier inside a comprehension: binds the declaration `decl` to `expression`
/// (displayed as `letting name = expression`).
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ComprehensionLetting {
    /// The declaration introduced by this letting.
    #[serde_as(as = "PtrAsInner")]
    pub decl: DeclarationPtr,
    /// The expression the declaration is bound to.
    pub expression: Expression,
}
50

            
51
/// A generator qualifier: introduces a variable that ranges either over a domain or over the
/// elements of an expression.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub enum Generator {
    /// `x : domain` — see [`DomainGenerator`].
    DomainGenerator(DomainGenerator),
    /// `x <- expression` — see [`ExpressionGenerator`].
    ExpressionGenerator(ExpressionGenerator),
}
58

            
59
/// A generator of the form `x : domain`.
///
/// The domain is not stored separately: it is read from the declaration (`decl.domain()`)
/// when the generator is displayed.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct DomainGenerator {
    /// The generator variable's declaration, which also carries its domain.
    #[serde_as(as = "PtrAsInner")]
    pub decl: DeclarationPtr,
}
67

            
68
/// A generator of the form `x <- expression`: the variable takes any element of `expression`
/// (e.g. in `[ x | x <- some_set ]`, `x` can be any element of `some_set`).
#[serde_as]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=SubModel)]
pub struct ExpressionGenerator {
    /// The generator variable's declaration (the one registered in the generator scope).
    #[serde_as(as = "PtrAsInner")]
    pub decl: DeclarationPtr,
    /// The expression whose elements the variable ranges over.
    pub expression: Expression,
}
77

            
78
impl AbstractComprehension {
79
    pub fn domain_of(&self) -> Option<DomainPtr> {
80
        self.return_expr.domain_of()
81
    }
82
}
83

            
84
impl Typeable for AbstractComprehension {
85
    fn return_type(&self) -> ReturnType {
86
        self.return_expr.return_type()
87
    }
88
}
89

            
90
/// Incrementally builds an [`AbstractComprehension`]: add qualifiers (generators, conditions,
/// lettings), then finish with [`AbstractComprehensionBuilder::with_return_value`].
pub struct AbstractComprehensionBuilder {
    /// The qualifiers added so far, in insertion order.
    pub qualifiers: Vec<Qualifier>,

    /// The symbol table used in the return expression.
    ///
    /// Variables from generator expressions are "given" in the context of the return expression.
    /// That is, they are constants which are different for each expansion of the comprehension.
    pub return_expr_symbols: SymbolTablePtr,

    /// The scope for variables in generator expressions.
    ///
    /// Variables declared in generator expressions are decision variables in their original
    /// context, since they do not have a constant value.
    pub generator_symbols: SymbolTablePtr,
}
105

            
106
impl AbstractComprehensionBuilder {
107
    /// Creates an [AbstractComprehensionBuilder] with:
108
    /// - An inner scope which inherits from the given symbol table
109
    /// - An empty list of qualifiers
110
    ///
111
    /// Changes to the inner scope do not affect the given symbol table.
112
    ///
113
    /// The return expression is passed when finalizing the comprehension, in [with_return_value].
114
20
    pub fn new(symbols: &SymbolTablePtr) -> Self {
115
20
        Self {
116
20
            qualifiers: vec![],
117
20
            return_expr_symbols: SymbolTablePtr::with_parent(symbols.clone()),
118
20
            generator_symbols: SymbolTablePtr::with_parent(symbols.clone()),
119
20
        }
120
20
    }
121

            
122
20
    pub fn return_expr_symbols(&self) -> SymbolTablePtr {
123
20
        self.return_expr_symbols.clone()
124
20
    }
125

            
126
    pub fn generator_symbols(&self) -> SymbolTablePtr {
127
        self.generator_symbols.clone()
128
    }
129

            
130
    pub fn new_domain_generator(&mut self, domain: DomainPtr) -> DeclarationPtr {
131
        let generator_decl = self.return_expr_symbols.write().gensym(&domain);
132

            
133
        self.qualifiers
134
            .push(Qualifier::Generator(Generator::DomainGenerator(
135
                DomainGenerator {
136
                    decl: generator_decl.clone(),
137
                },
138
            )));
139

            
140
        generator_decl
141
    }
142

            
143
    /// Creates a new expression generator with the given expression and variable name.
144
    ///
145
    /// The variable "takes from" the expression, that is, it can be any element in the expression.
146
    ///
147
    /// E.g. in `[ x | x <- some_set ]`, `x` can be any element of `some_set`.
148
20
    pub fn new_expression_generator(mut self, expr: Expression, name: Name) -> Self {
149
20
        let domain = expr
150
20
            .domain_of()
151
20
            .expect("Expression must have a domain")
152
20
            .element_domain()
153
20
            .expect("Expression must contain elements with uniform domain");
154

            
155
        // The variable is quantified in both scopes.
156
20
        let generator_ptr = DeclarationPtr::new_quantified(name, domain);
157
20
        let return_expr_ptr = DeclarationPtr::new_quantified_from_generator(&generator_ptr)
158
20
            .expect("Return expression declaration must not be None");
159

            
160
20
        self.return_expr_symbols.write().insert(return_expr_ptr);
161
20
        self.generator_symbols.write().insert(generator_ptr.clone());
162

            
163
20
        self.qualifiers
164
20
            .push(Qualifier::Generator(Generator::ExpressionGenerator(
165
20
                ExpressionGenerator {
166
20
                    decl: generator_ptr,
167
20
                    expression: expr,
168
20
                },
169
20
            )));
170

            
171
20
        self
172
20
    }
173

            
174
    /// See [crate::ast::comprehension::ComprehensionBuilder::guard]
175
    pub fn add_condition(&mut self, condition: Expression) {
176
        if condition.return_type() != ReturnType::Bool {
177
            panic!("Condition expression must have boolean return type");
178
        }
179

            
180
        self.qualifiers.push(Qualifier::Condition(condition));
181
    }
182

            
183
    pub fn new_letting(&mut self, expression: Expression) -> DeclarationPtr {
184
        let letting_decl = self.return_expr_symbols.write().gensym(
185
            &expression
186
                .domain_of()
187
                .expect("Expression must have a domain"),
188
        );
189

            
190
        self.qualifiers
191
            .push(Qualifier::ComprehensionLetting(ComprehensionLetting {
192
                decl: letting_decl.clone(),
193
                expression,
194
            }));
195

            
196
        letting_decl
197
    }
198

            
199
    // The lack of the generator_symboltable and return_expr_symboltable
200
    // are explained bc 1. we dont have separate symboltables for each part
201
    // 2. it is unclear why there would be a need to access each one uniquely
202

            
203
20
    pub fn with_return_value(self, expression: Expression) -> AbstractComprehension {
204
20
        AbstractComprehension {
205
20
            return_expr: expression,
206
20
            qualifiers: self.qualifiers,
207
20
            return_expr_symbols: self.return_expr_symbols,
208
20
            generator_symbols: self.generator_symbols,
209
20
        }
210
20
    }
211
}
212

            
213
impl Display for AbstractComprehension {
214
80
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
215
80
        write!(f, "[ {} | ", self.return_expr)?;
216
80
        let mut first = true;
217
80
        for qualifier in &self.qualifiers {
218
80
            if !first {
219
                write!(f, ", ")?;
220
80
            }
221
80
            first = false;
222
80
            qualifier.fmt(f)?;
223
        }
224
80
        write!(f, " ]")
225
80
    }
226
}
227

            
228
impl Display for Qualifier {
229
80
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
230
80
        match self {
231
80
            Qualifier::Generator(generator) => generator.fmt(f),
232
            Qualifier::Condition(condition) => condition.fmt(f),
233
            Qualifier::ComprehensionLetting(comp_letting) => {
234
                let name = comp_letting.decl.name();
235
                let expr = &comp_letting.expression;
236
                write!(f, "letting {} = {}", name, expr)
237
            }
238
        }
239
80
    }
240
}
241

            
242
impl Display for Generator {
243
80
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
244
80
        match self {
245
            Generator::DomainGenerator(DomainGenerator { decl }) => {
246
                let name = decl.name();
247
                let domain = decl.domain().unwrap();
248
                write!(f, "{} : {}", name, domain)
249
            }
250
80
            Generator::ExpressionGenerator(ExpressionGenerator { decl, expression }) => {
251
80
                let name = decl.name();
252
80
                write!(f, "{} <- {}", name, expression)
253
            }
254
        }
255
80
    }
256
}