1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable, ac_operators::ACOperatorKind, serde::PtrAsInner,
16
};
17

            
18
/// One qualifier of a comprehension: either a generator or a guard condition.
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=Name)]
pub enum ComprehensionQualifier {
    /// Binds the quantified variable `name` to range over `domain` (displayed as `name : domain`).
    Generator { name: Name, domain: DomainPtr },
    /// A guard condition over the quantified variables (displayed as the expression itself).
    Condition(Expression),
}
25

            
26
/// A comprehension.
///
/// Displayed as `[ return_expression | qualifier, qualifier, ... ]`, where the qualifiers are
/// generators (binding quantified variables over domains) and guard conditions.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=Expression)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[non_exhaustive]
pub struct Comprehension {
    /// The expression yielded by the comprehension; its domain is the element domain of the
    /// resulting list.
    pub return_expression: Expression,
    /// Generators and guard conditions, in order.
    pub qualifiers: Vec<ComprehensionQualifier>,
    // Local scope shared by the generators and the return expression; the builder declares the
    // quantified variables here. Serialized as the table itself rather than as a pointer.
    #[doc(hidden)]
    #[serde_as(as = "PtrAsInner")]
    pub symbols: SymbolTablePtr,
}
40

            
41
impl Comprehension {
42
    pub fn domain_of(&self) -> Option<DomainPtr> {
43
        let return_expr_domain = self.return_expression.domain_of()?;
44

            
45
        // return a list (matrix with index domain int(1..)) of return_expr elements
46
        Some(Domain::matrix(
47
            return_expr_domain,
48
            vec![Domain::int(vec![Range::UnboundedR(1)])],
49
        ))
50
    }
51

            
52
12852
    pub fn return_expression(self) -> Expression {
53
12852
        self.return_expression
54
12852
    }
55

            
56
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
57
        self.return_expression = new_expr;
58
    }
59

            
60
40929
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
61
40929
        self.symbols.read()
62
40929
    }
63

            
64
15666
    pub fn quantified_vars(&self) -> Vec<Name> {
65
15666
        self.qualifiers
66
15666
            .iter()
67
18165
            .filter_map(|q| match q {
68
17283
                ComprehensionQualifier::Generator { name, .. } => Some(name.clone()),
69
882
                ComprehensionQualifier::Condition(_) => None,
70
18165
            })
71
15666
            .collect()
72
15666
    }
73

            
74
2142
    pub fn generator_conditions(&self) -> Vec<Expression> {
75
2142
        self.qualifiers
76
2142
            .iter()
77
2478
            .filter_map(|q| match q {
78
126
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
79
2352
                ComprehensionQualifier::Generator { .. } => None,
80
2478
            })
81
2142
            .collect()
82
2142
    }
83

            
84
    /// Builds a temporary model containing generator qualifiers and guards.
85
2142
    pub fn to_generator_model(&self) -> Model {
86
2142
        let mut model = self.empty_model_with_symbols();
87
2142
        model.add_constraints(self.generator_conditions());
88
2142
        model
89
2142
    }
90

            
91
    /// Builds a temporary model containing the return expression only.
92
2142
    pub fn to_return_expression_model(&self) -> Model {
93
2142
        let mut model = self.empty_model_with_symbols();
94
2142
        model.add_constraint(self.return_expression.clone());
95
2142
        model
96
2142
    }
97

            
98
4284
    fn empty_model_with_symbols(&self) -> Model {
99
4284
        let parent = self.symbols.read().parent().clone();
100
4284
        let mut model = if let Some(parent) = parent {
101
4284
            Model::new_in_parent_scope(parent)
102
        } else {
103
            Model::default()
104
        };
105
4284
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
106
4284
        model
107
4284
    }
108

            
109
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
110
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
111
        if self.is_quantified_guard(&guard) {
112
            self.qualifiers
113
                .push(ComprehensionQualifier::Condition(guard));
114
            true
115
        } else {
116
            false
117
        }
118
    }
119

            
120
    /// True iff expr only references quantified variables.
121
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
122
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
123
        is_quantified_guard(&quantified, expr)
124
    }
125
}
126

            
127
impl Typeable for Comprehension {
128
    fn return_type(&self) -> ReturnType {
129
        self.return_expression.return_type()
130
    }
131
}
132

            
133
impl Display for Comprehension {
134
25305
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135
25305
        let generators_and_guards = self
136
25305
            .qualifiers
137
25305
            .iter()
138
29463
            .map(|qualifier| match qualifier {
139
27699
                ComprehensionQualifier::Generator { name, domain } => {
140
27699
                    format!("{name} : {domain}")
141
                }
142
1764
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
143
29463
            })
144
25305
            .join(", ");
145

            
146
25305
        write!(
147
25305
            f,
148
            "[ {} | {generators_and_guards} ]",
149
            self.return_expression
150
        )
151
25305
    }
152
}
153

            
154
/// A builder for a comprehension.
///
/// Generators and guards are collected with [`ComprehensionBuilder::generator`] and
/// [`ComprehensionBuilder::guard`]. [`ComprehensionBuilder::with_return_value`] finishes the build.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    // Generators and guards, in the order they were added.
    qualifiers: Vec<ComprehensionQualifier>,
    // A single scope for generators and return expression.
    symbols: SymbolTablePtr,
    // Names bound by generators so far; used to classify guards and to reject duplicate
    // generators.
    quantified_variables: BTreeSet<Name>,
}
162

            
163
impl ComprehensionBuilder {
164
2226
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
165
2226
        ComprehensionBuilder {
166
2226
            qualifiers: vec![],
167
2226
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
168
2226
            quantified_variables: BTreeSet::new(),
169
2226
        }
170
2226
    }
171

            
172
    /// Backwards-compatible parser API: same table for generators and return expression.
173
2184
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
174
2184
        self.symbols.clone()
175
2184
    }
176

            
177
    /// Backwards-compatible parser API: same table for generators and return expression.
178
2226
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
179
2226
        self.symbols.clone()
180
2226
    }
181

            
182
315
    pub fn guard(mut self, guard: Expression) -> Self {
183
315
        self.qualifiers
184
315
            .push(ComprehensionQualifier::Condition(guard));
185
315
        self
186
315
    }
187

            
188
2436
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
189
2436
        let name = declaration.name().clone();
190
2436
        let domain = declaration.domain().unwrap();
191
2436
        assert!(!self.quantified_variables.contains(&name));
192

            
193
2436
        self.quantified_variables.insert(name.clone());
194

            
195
        // insert into comprehension scope as a local quantified variable
196
2436
        let quantified_decl = DeclarationPtr::new_quantified(name.clone(), domain.clone());
197
2436
        self.symbols.write().insert(quantified_decl);
198

            
199
2436
        self.qualifiers
200
2436
            .push(ComprehensionQualifier::Generator { name, domain });
201

            
202
2436
        self
203
2436
    }
204

            
205
    /// Creates a comprehension with the given return expression.
206
    ///
207
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
208
    /// in the `comprehension_kind` field.
209
    ///
210
    /// If a comprehension kind is not given, comprehension guards containing decision variables
211
    /// are invalid, and will cause a panic.
212
2205
    pub fn with_return_value(
213
2205
        self,
214
2205
        mut expression: Expression,
215
2205
        comprehension_kind: Option<ACOperatorKind>,
216
2205
    ) -> Comprehension {
217
2205
        let quantified_variables = self.quantified_variables;
218

            
219
2205
        let mut qualifiers = Vec::new();
220
2205
        let mut other_guards = Vec::new();
221

            
222
2751
        for qualifier in self.qualifiers {
223
2751
            match qualifier {
224
2436
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
225
315
                ComprehensionQualifier::Condition(condition) => {
226
315
                    if is_quantified_guard(&quantified_variables, &condition) {
227
189
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
228
189
                    } else {
229
126
                        other_guards.push(condition);
230
126
                    }
231
                }
232
            }
233
        }
234

            
235
        // handle guards that reference non-quantified variables
236
2205
        if !other_guards.is_empty() {
237
126
            let comprehension_kind = comprehension_kind.expect(
238
126
                "if any guards reference decision variables, a comprehension kind should be given",
239
            );
240

            
241
126
            let guard_expr = match other_guards.as_slice() {
242
126
                [x] => x.clone(),
243
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
244
            };
245

            
246
126
            expression = match comprehension_kind {
247
                ACOperatorKind::And => {
248
126
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
249
                }
250
                ACOperatorKind::Or => Expression::And(
251
                    Metadata::new(),
252
                    Moo::new(matrix_expr![guard_expr, expression]),
253
                ),
254

            
255
                ACOperatorKind::Sum => {
256
                    panic!("guards that reference decision variables not yet implemented for sum");
257
                }
258

            
259
                ACOperatorKind::Product => {
260
                    panic!(
261
                        "guards that reference decision variables not yet implemented for product"
262
                    );
263
                }
264
            }
265
2079
        }
266

            
267
2205
        Comprehension {
268
2205
            return_expression: expression,
269
2205
            qualifiers,
270
2205
            symbols: self.symbols,
271
2205
        }
272
2205
    }
273
}
274

            
275
/// True iff the guard only references quantified variables.
276
315
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
277
315
    guard
278
315
        .universe_bi()
279
315
        .iter()
280
315
        .all(|x| quantified_variables.contains(x))
281
315
}