1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable, ac_operators::ACOperatorKind, serde::PtrAsInner,
16
};
17

            
18
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
19
#[biplate(to=Expression)]
20
#[biplate(to=Name)]
21
pub enum ComprehensionQualifier {
22
    Generator { name: Name, domain: DomainPtr },
23
    Condition(Expression),
24
}
25

            
26
/// A comprehension.
27
#[serde_as]
28
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
29
#[biplate(to=Expression)]
30
#[biplate(to=SymbolTable)]
31
#[biplate(to=SymbolTablePtr)]
32
#[non_exhaustive]
33
pub struct Comprehension {
34
    pub return_expression: Expression,
35
    pub qualifiers: Vec<ComprehensionQualifier>,
36
    #[doc(hidden)]
37
    #[serde_as(as = "PtrAsInner")]
38
    pub symbols: SymbolTablePtr,
39
}
40

            
41
impl Comprehension {
42
    pub fn domain_of(&self) -> Option<DomainPtr> {
43
        let return_expr_domain = self.return_expression.domain_of()?;
44

            
45
        // return a list (matrix with index domain int(1..)) of return_expr elements
46
        Some(Domain::matrix(
47
            return_expr_domain,
48
            vec![Domain::int(vec![Range::UnboundedR(1)])],
49
        ))
50
    }
51

            
52
12248
    pub fn return_expression(self) -> Expression {
53
12248
        self.return_expression
54
12248
    }
55

            
56
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
57
        self.return_expression = new_expr;
58
    }
59
23264

            
60
62252
    /// Takes a read lock on this comprehension's symbol table and returns the guard.
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
        let table = &self.symbols;
        table.read()
    }
63

            
64
14932
    pub fn quantified_vars(&self) -> Vec<Name> {
65
14932
        self.qualifiers
66
14932
            .iter()
67
84034
            .filter_map(|q| match q {
68
83194
                ComprehensionQualifier::Generator { name, .. } => Some(name.clone()),
69
67538
                ComprehensionQualifier::Condition(_) => None,
70
17336
            })
71
43292
            .collect()
72
43292
    }
73
28360

            
74
34950
    pub fn generator_conditions(&self) -> Vec<Expression> {
75
33354
        self.qualifiers
76
3640
            .iter()
77
35278
            .filter_map(|q| match q {
78
28480
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
79
30612
                ComprehensionQualifier::Generator { .. } => None,
80
2372
            })
81
5924
            .collect()
82
5924
    }
83
3880

            
84
    /// Builds a temporary model containing generator qualifiers and guards.
85
2272
    pub fn to_generator_model(&self) -> Model {
86
6312
        let mut model = self.empty_model_with_symbols();
87
6540
        model.add_constraints(self.generator_conditions());
88
5924
        model
89
5924
    }
90

            
91
    /// Builds a temporary model containing the return expression only.
92
5924
    pub fn to_return_expression_model(&self) -> Model {
93
5924
        let mut model = self.empty_model_with_symbols();
94
5924
        model.add_constraint(self.return_expression.clone());
95
5924
        model
96
5924
    }
97

            
98
4088
    fn empty_model_with_symbols(&self) -> Model {
99
7968
        let parent = self.symbols.read().parent().clone();
100
7968
        let mut model = if let Some(parent) = parent {
101
7968
            Model::new_in_parent_scope(parent)
102
3880
        } else {
103
3880
            Model::default()
104
        };
105
11848
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
106
11848
        model
107
11848
    }
108
7760

            
109
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
110
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
111
        if self.is_quantified_guard(&guard) {
112
7760
            self.qualifiers
113
7760
                .push(ComprehensionQualifier::Condition(guard));
114
7760
            true
115
        } else {
116
            false
117
        }
118
    }
119

            
120
    /// True iff expr only references quantified variables.
121
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
122
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
123
        is_quantified_guard(&quantified, expr)
124
    }
125
}
126

            
127
impl Typeable for Comprehension {
128
    fn return_type(&self) -> ReturnType {
129
        self.return_expression.return_type()
130
    }
131
}
132

            
133
impl Display for Comprehension {
134
2433560
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135
2433560
        let generators_and_guards = self
136
2433560
            .qualifiers
137
2433560
            .iter()
138
2834560
            .map(|qualifier| match qualifier {
139
2675500
                ComprehensionQualifier::Generator { name, domain } => {
140
2675500
                    format!("{name} : {domain}")
141
4622486
                }
142
4781546
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
143
7457046
            })
144
7056046
            .join(", ");
145
5381830

            
146
7513176
        write!(
147
7513176
            f,
148
5079616
            "[ {} | {generators_and_guards} ]",
149
            self.return_expression
150
302214
        )
151
7815390
    }
152
4622486
}
153

            
154
/// A builder for a comprehension.
155
#[derive(Clone, Debug, PartialEq, Eq)]
156
pub struct ComprehensionBuilder {
157
    qualifiers: Vec<ComprehensionQualifier>,
158
    // A single scope for generators and return expression.
159
4622486
    symbols: SymbolTablePtr,
160
    quantified_variables: BTreeSet<Name>,
161
}
162

            
163
impl ComprehensionBuilder {
164
2124
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
165
2124
        ComprehensionBuilder {
166
2124
            qualifiers: vec![],
167
2124
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
168
2124
            quantified_variables: BTreeSet::new(),
169
2124
        }
170
2124
    }
171

            
172
    /// Backwards-compatible parser API: same table for generators and return expression.
173
6116
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
174
6116
        self.symbols.clone()
175
6116
    }
176
4032

            
177
    /// Backwards-compatible parser API: same table for generators and return expression.
178
6156
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
179
2124
        self.symbols.clone()
180
2124
    }
181
3956

            
182
4256
    /// Appends `guard` as a condition qualifier and returns the builder.
    ///
    /// The guard is not inspected here; `with_return_value` later decides whether it stays
    /// a qualifier or is folded into the return expression.
    pub fn guard(mut self, guard: Expression) -> Self {
        let condition = ComprehensionQualifier::Condition(guard);
        self.qualifiers.push(condition);
        self
    }
187
4032

            
188
6364
    /// Adds a generator for the variable described by `declaration` and returns the builder.
    ///
    /// The declaration's name and domain are recorded as a generator qualifier, and a new
    /// quantified declaration with the same name and domain is inserted into the
    /// comprehension's symbol table.
    ///
    /// # Panics
    ///
    /// Panics if `declaration` has no domain, or if a generator with the same name has
    /// already been added to this builder.
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
        let name = declaration.name().clone();
        let domain = declaration
            .domain()
            .expect("a comprehension generator must be declared with a domain");

        // `insert` returns false when the name was already present, so a single call both
        // records the name and detects duplicates.
        assert!(
            self.quantified_variables.insert(name.clone()),
            "generator `{name}` was already added to this comprehension"
        );

        // insert into comprehension scope as a local quantified variable
        let quantified_decl = DeclarationPtr::new_quantified(name.clone(), domain.clone());
        self.symbols.write().insert(quantified_decl);

        self.qualifiers
            .push(ComprehensionQualifier::Generator { name, domain });

        self
    }
204
4420

            
205
    /// Creates a comprehension with the given return expression.
206
    ///
207
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
208
    /// in the `comprehension_kind` field.
209
    ///
210
    /// If a comprehension kind is not given, comprehension guards containing decision variables
211
    /// are invalid, and will cause a panic.
212
2104
    pub fn with_return_value(
213
2104
        self,
214
2104
        mut expression: Expression,
215
2104
        comprehension_kind: Option<ACOperatorKind>,
216
2104
    ) -> Comprehension {
217
2104
        let quantified_variables = self.quantified_variables;
218

            
219
2104
        let mut qualifiers = Vec::new();
220
6098
        let mut other_guards = Vec::new();
221
3994

            
222
6626
        for qualifier in self.qualifiers {
223
6626
            match qualifier {
224
6326
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
225
4294
                ComprehensionQualifier::Condition(condition) => {
226
300
                    if is_quantified_guard(&quantified_variables, &condition) {
227
4174
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
228
4174
                    } else {
229
120
                        other_guards.push(condition);
230
5110
                    }
231
4990
                }
232
4420
            }
233
570
        }
234
570

            
235
        // handle guards that reference non-quantified variables
236
2446
        if !other_guards.is_empty() {
237
348
            let comprehension_kind = comprehension_kind.expect(
238
348
                "if any guards reference decision variables, a comprehension kind should be given",
239
            );
240

            
241
120
            let guard_expr = match other_guards.as_slice() {
242
120
                [x] => x.clone(),
243
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
244
3994
            };
245
228

            
246
348
            expression = match comprehension_kind {
247
                ACOperatorKind::And => {
248
120
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
249
228
                }
250
228
                ACOperatorKind::Or => Expression::And(
251
                    Metadata::new(),
252
                    Moo::new(matrix_expr![guard_expr, expression]),
253
                ),
254
228

            
255
                ACOperatorKind::Sum => {
256
228
                    panic!("guards that reference decision variables not yet implemented for sum");
257
                }
258

            
259
                ACOperatorKind::Product => {
260
                    panic!(
261
                        "guards that reference decision variables not yet implemented for product"
262
                    );
263
                }
264
            }
265
1984
        }
266

            
267
2104
        Comprehension {
268
2104
            return_expression: expression,
269
2104
            qualifiers,
270
2104
            symbols: self.symbols,
271
2104
        }
272
2104
    }
273
3766
}
274

            
275
/// True iff the guard only references quantified variables.
276
4294
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
277
4294
    guard
278
4294
        .universe_bi()
279
4294
        .iter()
280
4294
        .all(|x| quantified_variables.contains(x))
281
300
}