1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable, ac_operators::ACOperatorKind, serde::PtrAsInner,
16
};
17

            
18
/// One qualifier of a [`Comprehension`]: either a generator that introduces a quantified
/// variable, or a guard condition.
///
/// Displayed inside a comprehension as `name : domain` (generator) or as the expression itself
/// (condition).
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
19
#[biplate(to=Expression)]
20
#[biplate(to=Name)]
21
pub enum ComprehensionQualifier {
22
    /// A generator `name : domain`; `ComprehensionBuilder::generator` also declares `name` as a
    /// quantified variable with this domain in the comprehension's local scope.
    Generator { name: Name, domain: DomainPtr },
23
    /// A guard condition. `Comprehension::to_generator_model` adds these as model constraints.
    Condition(Expression),
24
}
25

            
26
/// A comprehension, displayed as `[ <return expression> | <qualifiers> ]`.
///
/// [`ComprehensionBuilder::with_return_value`] constructs one. Its domain (see
/// [`Comprehension::domain_of`]) is a list — a matrix indexed by `int(1..)` — whose element
/// domain is the return expression's domain.
27
#[serde_as]
28
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
29
#[biplate(to=Expression)]
30
#[biplate(to=SymbolTable)]
31
#[biplate(to=SymbolTablePtr)]
32
#[non_exhaustive]
33
pub struct Comprehension {
34
    /// The expression producing each element of the resulting list.
    pub return_expression: Expression,
35
    /// Generators and guard conditions, in the order they were added.
    pub qualifiers: Vec<ComprehensionQualifier>,
36
    /// The comprehension's local scope, holding the quantified variables declared by its
    /// generators. `ComprehensionBuilder::new` creates it as a child of the enclosing scope.
    /// Serialized through `PtrAsInner`.
    #[doc(hidden)]
37
    #[serde_as(as = "PtrAsInner")]
38
    pub symbols: SymbolTablePtr,
39
}
40

            
41
impl Comprehension {
42
    pub fn domain_of(&self) -> Option<DomainPtr> {
43
        let return_expr_domain = self.return_expression.domain_of()?;
44

            
45
        // return a list (matrix with index domain int(1..)) of return_expr elements
46
        Some(Domain::matrix(
47
            return_expr_domain,
48
            vec![Domain::int(vec![Range::UnboundedR(1)])],
49
        ))
50
    }
51

            
52
12248
    pub fn return_expression(self) -> Expression {
53
12248
        self.return_expression
54
12248
    }
55

            
56
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
57
        self.return_expression = new_expr;
58
    }
59

            
60
38988
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
61
38988
        self.symbols.read()
62
38988
    }
63

            
64
14932
    pub fn quantified_vars(&self) -> Vec<Name> {
65
14932
        self.qualifiers
66
14932
            .iter()
67
17336
            .filter_map(|q| match q {
68
16496
                ComprehensionQualifier::Generator { name, .. } => Some(name.clone()),
69
840
                ComprehensionQualifier::Condition(_) => None,
70
17336
            })
71
14932
            .collect()
72
14932
    }
73

            
74
2044
    pub fn generator_conditions(&self) -> Vec<Expression> {
75
2044
        self.qualifiers
76
2044
            .iter()
77
2372
            .filter_map(|q| match q {
78
120
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
79
2252
                ComprehensionQualifier::Generator { .. } => None,
80
2372
            })
81
2044
            .collect()
82
2044
    }
83

            
84
    /// Builds a temporary model containing generator qualifiers and guards.
85
2044
    pub fn to_generator_model(&self) -> Model {
86
2044
        let mut model = self.empty_model_with_symbols();
87
2044
        model.add_constraints(self.generator_conditions());
88
2044
        model
89
2044
    }
90

            
91
    /// Builds a temporary model containing the return expression only.
92
2044
    pub fn to_return_expression_model(&self) -> Model {
93
2044
        let mut model = self.empty_model_with_symbols();
94
2044
        model.add_constraint(self.return_expression.clone());
95
2044
        model
96
2044
    }
97

            
98
4088
    fn empty_model_with_symbols(&self) -> Model {
99
4088
        let parent = self.symbols.read().parent().clone();
100
4088
        let mut model = if let Some(parent) = parent {
101
4088
            Model::new_in_parent_scope(parent)
102
        } else {
103
            Model::default()
104
        };
105
4088
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
106
4088
        model
107
4088
    }
108

            
109
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
110
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
111
        if self.is_quantified_guard(&guard) {
112
            self.qualifiers
113
                .push(ComprehensionQualifier::Condition(guard));
114
            true
115
        } else {
116
            false
117
        }
118
    }
119

            
120
    /// True iff expr only references quantified variables.
121
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
122
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
123
        is_quantified_guard(&quantified, expr)
124
    }
125
}
126

            
127
impl Typeable for Comprehension {
128
    fn return_type(&self) -> ReturnType {
129
        self.return_expression.return_type()
130
    }
131
}
132

            
133
impl Display for Comprehension {
134
2433560
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135
2433560
        let generators_and_guards = self
136
2433560
            .qualifiers
137
2433560
            .iter()
138
2834560
            .map(|qualifier| match qualifier {
139
2675500
                ComprehensionQualifier::Generator { name, domain } => {
140
2675500
                    format!("{name} : {domain}")
141
                }
142
159060
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
143
2834560
            })
144
2433560
            .join(", ");
145

            
146
2433560
        write!(
147
2433560
            f,
148
            "[ {} | {generators_and_guards} ]",
149
            self.return_expression
150
        )
151
2433560
    }
152
}
153

            
154
/// A builder for a [`Comprehension`]: add generators and guards, then call
/// `with_return_value` to produce the comprehension.
155
#[derive(Clone, Debug, PartialEq, Eq)]
156
pub struct ComprehensionBuilder {
157
    // Generators and guards, in the order they were added.
    qualifiers: Vec<ComprehensionQualifier>,
158
    // A single scope for generators and return expression; created in `new` as a child of the
    // enclosing scope.
159
    symbols: SymbolTablePtr,
160
    // Names bound by the generators added so far. `generator` uses it to reject duplicates, and
    // `with_return_value` uses it to classify guards.
    quantified_variables: BTreeSet<Name>,
161
}
162

            
163
impl ComprehensionBuilder {
164
2124
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
165
2124
        ComprehensionBuilder {
166
2124
            qualifiers: vec![],
167
2124
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
168
2124
            quantified_variables: BTreeSet::new(),
169
2124
        }
170
2124
    }
171

            
172
    /// Backwards-compatible parser API: same table for generators and return expression.
173
2084
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
174
2084
        self.symbols.clone()
175
2084
    }
176

            
177
    /// Backwards-compatible parser API: same table for generators and return expression.
178
2124
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
179
2124
        self.symbols.clone()
180
2124
    }
181

            
182
300
    pub fn guard(mut self, guard: Expression) -> Self {
183
300
        self.qualifiers
184
300
            .push(ComprehensionQualifier::Condition(guard));
185
300
        self
186
300
    }
187

            
188
2332
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
189
2332
        let name = declaration.name().clone();
190
2332
        let domain = declaration.domain().unwrap();
191
2332
        assert!(!self.quantified_variables.contains(&name));
192

            
193
2332
        self.quantified_variables.insert(name.clone());
194

            
195
        // insert into comprehension scope as a local quantified variable
196
2332
        let quantified_decl = DeclarationPtr::new_quantified(name.clone(), domain.clone());
197
2332
        self.symbols.write().insert(quantified_decl);
198

            
199
2332
        self.qualifiers
200
2332
            .push(ComprehensionQualifier::Generator { name, domain });
201

            
202
2332
        self
203
2332
    }
204

            
205
    /// Creates a comprehension with the given return expression.
206
    ///
207
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
208
    /// in the `comprehension_kind` field.
209
    ///
210
    /// If a comprehension kind is not given, comprehension guards containing decision variables
211
    /// are invalid, and will cause a panic.
212
2104
    pub fn with_return_value(
213
2104
        self,
214
2104
        mut expression: Expression,
215
2104
        comprehension_kind: Option<ACOperatorKind>,
216
2104
    ) -> Comprehension {
217
2104
        let quantified_variables = self.quantified_variables;
218

            
219
2104
        let mut qualifiers = Vec::new();
220
2104
        let mut other_guards = Vec::new();
221

            
222
2632
        for qualifier in self.qualifiers {
223
2632
            match qualifier {
224
2332
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
225
300
                ComprehensionQualifier::Condition(condition) => {
226
300
                    if is_quantified_guard(&quantified_variables, &condition) {
227
180
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
228
180
                    } else {
229
120
                        other_guards.push(condition);
230
120
                    }
231
                }
232
            }
233
        }
234

            
235
        // handle guards that reference non-quantified variables
236
2104
        if !other_guards.is_empty() {
237
120
            let comprehension_kind = comprehension_kind.expect(
238
120
                "if any guards reference decision variables, a comprehension kind should be given",
239
            );
240

            
241
120
            let guard_expr = match other_guards.as_slice() {
242
120
                [x] => x.clone(),
243
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
244
            };
245

            
246
120
            expression = match comprehension_kind {
247
                ACOperatorKind::And => {
248
120
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
249
                }
250
                ACOperatorKind::Or => Expression::And(
251
                    Metadata::new(),
252
                    Moo::new(matrix_expr![guard_expr, expression]),
253
                ),
254

            
255
                ACOperatorKind::Sum => {
256
                    panic!("guards that reference decision variables not yet implemented for sum");
257
                }
258

            
259
                ACOperatorKind::Product => {
260
                    panic!(
261
                        "guards that reference decision variables not yet implemented for product"
262
                    );
263
                }
264
            }
265
1984
        }
266

            
267
2104
        Comprehension {
268
2104
            return_expression: expression,
269
2104
            qualifiers,
270
2104
            symbols: self.symbols,
271
2104
        }
272
2104
    }
273
}
274

            
275
/// True iff the guard only references quantified variables.
276
300
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
277
300
    guard
278
300
        .universe_bi()
279
300
        .iter()
280
300
        .all(|x| quantified_variables.contains(x))
281
300
}