1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable,
16
    ac_operators::ACOperatorKind,
17
    categories::{Category, CategoryOf},
18
    serde::{AsId, PtrAsInner},
19
};
20

            
21
#[serde_as]
22
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
23
#[biplate(to=Expression)]
24
#[biplate(to=Name)]
25
#[biplate(to=DeclarationPtr)]
26
pub enum ComprehensionQualifier {
27
    ExpressionGenerator {
28
        #[serde_as(as = "AsId")]
29
        ptr: DeclarationPtr,
30
    },
31
    Generator {
32
        #[serde_as(as = "AsId")]
33
        ptr: DeclarationPtr,
34
    },
35
    Condition(Expression),
36
}
37

            
38
/// A comprehension.
39
#[serde_as]
40
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
41
#[biplate(to=Expression)]
42
#[biplate(to=SymbolTable)]
43
#[biplate(to=SymbolTablePtr)]
44
#[non_exhaustive]
45
pub struct Comprehension {
46
    pub return_expression: Expression,
47
    pub qualifiers: Vec<ComprehensionQualifier>,
48
    #[doc(hidden)]
49
    #[serde_as(as = "PtrAsInner")]
50
    pub symbols: SymbolTablePtr,
51
}
52

            
53
impl Comprehension {
54
26
    pub fn domain_of(&self) -> Option<DomainPtr> {
55
26
        let return_expr_domain = self.return_expression.domain_of()?;
56

            
57
        // return a list (matrix with index domain int(1..)) of return_expr elements
58
14
        Some(Domain::matrix(
59
14
            return_expr_domain,
60
14
            vec![Domain::int(vec![Range::UnboundedR(1)])],
61
14
        ))
62
26
    }
63

            
64
52332
    pub fn return_expression(self) -> Expression {
65
52332
        self.return_expression
66
52332
    }
67

            
68
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
69
        self.return_expression = new_expr;
70
    }
71

            
72
141646
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
73
141646
        self.symbols.read()
74
141646
    }
75

            
76
60380
    pub fn quantified_vars(&self) -> Vec<Name> {
77
60380
        self.qualifiers
78
60380
            .iter()
79
82154
            .filter_map(|q| match q {
80
62
                ComprehensionQualifier::ExpressionGenerator { ptr } => Some(ptr.name().clone()),
81
74294
                ComprehensionQualifier::Generator { ptr } => Some(ptr.name().clone()),
82
7798
                ComprehensionQualifier::Condition(_) => None,
83
82154
            })
84
60380
            .collect()
85
60380
    }
86

            
87
5226
    pub fn generator_conditions(&self) -> Vec<Expression> {
88
5226
        self.qualifiers
89
5226
            .iter()
90
6078
            .filter_map(|q| match q {
91
240
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
92
5838
                ComprehensionQualifier::Generator { .. } => None,
93
                ComprehensionQualifier::ExpressionGenerator { .. } => None,
94
6078
            })
95
5226
            .collect()
96
5226
    }
97

            
98
    /// Builds a temporary model containing generator qualifiers and guards.
99
5226
    pub fn to_generator_model(&self) -> Model {
100
5226
        let mut model = self.empty_model_with_symbols();
101
5226
        model.add_constraints(self.generator_conditions());
102
5226
        model
103
5226
    }
104

            
105
    /// Builds a temporary model containing the return expression only.
106
5226
    pub fn to_return_expression_model(&self) -> Model {
107
5226
        let mut model = self.empty_model_with_symbols();
108
5226
        model.add_constraint(self.return_expression.clone());
109
5226
        model
110
5226
    }
111

            
112
10452
    fn empty_model_with_symbols(&self) -> Model {
113
10452
        let parent = self.symbols.read().parent().clone();
114
10452
        let mut model = if let Some(parent) = parent {
115
10452
            Model::new_in_parent_scope(parent)
116
        } else {
117
            Model::default()
118
        };
119
10452
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
120
10452
        model
121
10452
    }
122

            
123
    /// Adds a guard to the comprehension.
124
    ///
125
    /// Returns false if the guard references non-quantified decision variables.
126
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
127
        if self.is_quantified_guard(&guard) {
128
            self.qualifiers
129
                .push(ComprehensionQualifier::Condition(guard));
130
            true
131
        } else {
132
            false
133
        }
134
    }
135

            
136
    /// True iff expr does not reference non-quantified decision variables.
137
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
138
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
139
        is_quantified_guard(&self.symbols.read(), &quantified, expr)
140
    }
141
}
142

            
143
impl Typeable for Comprehension {
144
4
    fn return_type(&self) -> ReturnType {
145
4
        self.return_expression.return_type()
146
4
    }
147
}
148

            
149
impl Display for Comprehension {
150
30252
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151
30252
        let generators_and_guards = self
152
30252
            .qualifiers
153
30252
            .iter()
154
39034
            .map(|qualifier| match qualifier {
155
36620
                ComprehensionQualifier::Generator { ptr } => {
156
36620
                    let domain = ptr.domain().expect("generator declaration has domain");
157
36620
                    format!("{} : {domain}", ptr.name())
158
                }
159
200
                ComprehensionQualifier::ExpressionGenerator { ptr } => {
160
200
                    let name = ptr.name();
161
200
                    if let Some(expr) = ptr.as_quantified_expr() {
162
200
                        format!("{name} <- {expr}")
163
                    } else {
164
                        panic!("Oh nein! Dat is nicht gut!")
165
                    }
166
                }
167
2214
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
168
39034
            })
169
30252
            .join(", ");
170

            
171
30252
        write!(
172
30252
            f,
173
            "[ {} | {generators_and_guards} ]",
174
            self.return_expression
175
        )
176
30252
    }
177
}
178

            
179
/// A builder for a comprehension.
180
#[derive(Clone, Debug, PartialEq, Eq)]
181
pub struct ComprehensionBuilder {
182
    qualifiers: Vec<ComprehensionQualifier>,
183
    // A single scope for generators and return expression.
184
    symbols: SymbolTablePtr,
185
    quantified_variables: BTreeSet<Name>,
186
}
187

            
188
impl ComprehensionBuilder {
189
8420
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
190
8420
        ComprehensionBuilder {
191
8420
            qualifiers: vec![],
192
8420
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
193
8420
            quantified_variables: BTreeSet::new(),
194
8420
        }
195
8420
    }
196

            
197
    /// Backwards-compatible parser API: same table for generators and return expression.
198
5880
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
199
5880
        self.symbols.clone()
200
5880
    }
201

            
202
    /// Backwards-compatible parser API: same table for generators and return expression.
203
8414
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
204
8414
        self.symbols.clone()
205
8414
    }
206

            
207
1104
    pub fn guard(mut self, guard: Expression) -> Self {
208
1104
        self.qualifiers
209
1104
            .push(ComprehensionQualifier::Condition(guard));
210
1104
        self
211
1104
    }
212

            
213
9104
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
214
9104
        let name = declaration.name().clone();
215
9104
        assert!(!self.quantified_variables.contains(&name));
216

            
217
9104
        self.quantified_variables.insert(name.clone());
218

            
219
        // insert into comprehension scope as a local quantified variable
220
9104
        let quantified_decl = DeclarationPtr::new_quantified(name, declaration.domain().unwrap());
221
9104
        self.symbols.write().insert(quantified_decl.clone());
222

            
223
9104
        self.qualifiers.push(ComprehensionQualifier::Generator {
224
9104
            ptr: quantified_decl,
225
9104
        });
226

            
227
9104
        self
228
9104
    }
229

            
230
60
    pub fn expression_generator(mut self, name: Name, expr: Expression) -> Self {
231
60
        assert!(!self.quantified_variables.contains(&name));
232

            
233
60
        self.quantified_variables.insert(name.clone());
234

            
235
        // insert into comprehension scope as a local quantified variable
236
60
        let quantified_decl = DeclarationPtr::new_quantified_expr(name, expr);
237
60
        self.symbols.write().insert(quantified_decl.clone());
238

            
239
60
        self.qualifiers
240
60
            .push(ComprehensionQualifier::ExpressionGenerator {
241
60
                ptr: quantified_decl,
242
60
            });
243

            
244
60
        self
245
60
    }
246

            
247
    /// Creates a comprehension with the given return expression.
248
    ///
249
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
250
    /// in the `comprehension_kind` field.
251
    ///
252
    /// If a comprehension kind is not given, comprehension guards containing non-quantified
253
    /// decision variables are invalid, and will cause a panic.
254
8420
    pub fn with_return_value(
255
8420
        self,
256
8420
        mut expression: Expression,
257
8420
        comprehension_kind: Option<ACOperatorKind>,
258
8420
    ) -> Comprehension {
259
8420
        let quantified_variables = self.quantified_variables;
260
8420
        let symbols = self.symbols.read();
261

            
262
8420
        let mut qualifiers = Vec::new();
263
8420
        let mut other_guards = Vec::new();
264

            
265
10268
        for qualifier in self.qualifiers {
266
10268
            match qualifier {
267
9104
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
268
60
                ComprehensionQualifier::ExpressionGenerator { .. } => qualifiers.push(qualifier),
269
1104
                ComprehensionQualifier::Condition(condition) => {
270
1104
                    if is_quantified_guard(&symbols, &quantified_variables, &condition) {
271
744
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
272
744
                    } else {
273
360
                        other_guards.push(condition);
274
360
                    }
275
                }
276
            }
277
        }
278
8420
        drop(symbols);
279

            
280
        // handle guards that reference non-quantified decision variables
281
8420
        if !other_guards.is_empty() {
282
360
            let comprehension_kind = comprehension_kind.expect(
283
360
                "if any guards reference decision variables, a comprehension kind should be given",
284
            );
285

            
286
360
            let guard_expr = match other_guards.as_slice() {
287
360
                [x] => x.clone(),
288
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
289
            };
290

            
291
360
            expression = match comprehension_kind {
292
                ACOperatorKind::And => {
293
360
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
294
                }
295
                ACOperatorKind::Or => Expression::And(
296
                    Metadata::new(),
297
                    Moo::new(matrix_expr![guard_expr, expression]),
298
                ),
299

            
300
                ACOperatorKind::Sum => {
301
                    panic!("guards that reference decision variables not yet implemented for sum");
302
                }
303

            
304
                ACOperatorKind::Product => {
305
                    panic!(
306
                        "guards that reference decision variables not yet implemented for product"
307
                    );
308
                }
309
            }
310
8060
        }
311

            
312
8420
        Comprehension {
313
8420
            return_expression: expression,
314
8420
            qualifiers,
315
8420
            symbols: self.symbols,
316
8420
        }
317
8420
    }
318
}
319

            
320
/// True iff the guard does not reference non-quantified decision variables.
321
1104
fn is_quantified_guard(
322
1104
    symbols: &SymbolTable,
323
1104
    quantified_variables: &BTreeSet<Name>,
324
1104
    guard: &Expression,
325
1104
) -> bool {
326
1372
    guard.universe_bi().iter().all(|name| {
327
1372
        quantified_variables.contains(name)
328
648
            || symbols
329
648
                .lookup(name)
330
648
                .is_some_and(|decl| decl.category_of() != Category::Decision)
331
1372
    })
332
1104
}