1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable, ac_operators::ACOperatorKind, serde::PtrAsInner,
16
};
17

            
18
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
19
#[biplate(to=Expression)]
20
#[biplate(to=Name)]
21
pub enum ComprehensionQualifier {
22
    Generator { name: Name, domain: DomainPtr },
23
    Condition(Expression),
24
}
25

            
26
/// A comprehension.
27
#[serde_as]
28
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
29
#[biplate(to=Expression)]
30
#[biplate(to=SymbolTable)]
31
#[biplate(to=SymbolTablePtr)]
32
#[non_exhaustive]
33
pub struct Comprehension {
34
    pub return_expression: Expression,
35
    pub qualifiers: Vec<ComprehensionQualifier>,
36
    #[doc(hidden)]
37
    #[serde_as(as = "PtrAsInner")]
38
    pub symbols: SymbolTablePtr,
39
}
40

            
41
impl Comprehension {
42
    pub fn domain_of(&self) -> Option<DomainPtr> {
43
        let return_expr_domain = self.return_expression.domain_of()?;
44

            
45
        // return a list (matrix with index domain int(1..)) of return_expr elements
46
        Some(Domain::matrix(
47
            return_expr_domain,
48
            vec![Domain::int(vec![Range::UnboundedR(1)])],
49
        ))
50
    }
51

            
52
36720
    pub fn return_expression(self) -> Expression {
53
36720
        self.return_expression
54
36720
    }
55

            
56
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
57
        self.return_expression = new_expr;
58
    }
59

            
60
116940
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
61
116940
        self.symbols.read()
62
116940
    }
63

            
64
44760
    pub fn quantified_vars(&self) -> Vec<Name> {
65
44760
        self.qualifiers
66
44760
            .iter()
67
51900
            .filter_map(|q| match q {
68
49380
                ComprehensionQualifier::Generator { name, .. } => Some(name.clone()),
69
2520
                ComprehensionQualifier::Condition(_) => None,
70
51900
            })
71
44760
            .collect()
72
44760
    }
73

            
74
6120
    pub fn generator_conditions(&self) -> Vec<Expression> {
75
6120
        self.qualifiers
76
6120
            .iter()
77
7080
            .filter_map(|q| match q {
78
360
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
79
6720
                ComprehensionQualifier::Generator { .. } => None,
80
7080
            })
81
6120
            .collect()
82
6120
    }
83

            
84
    /// Builds a temporary model containing generator qualifiers and guards.
85
6120
    pub fn to_generator_model(&self) -> Model {
86
6120
        let mut model = self.empty_model_with_symbols();
87
6120
        model.add_constraints(self.generator_conditions());
88
6120
        model
89
6120
    }
90

            
91
    /// Builds a temporary model containing the return expression only.
92
6120
    pub fn to_return_expression_model(&self) -> Model {
93
6120
        let mut model = self.empty_model_with_symbols();
94
6120
        model.add_constraint(self.return_expression.clone());
95
6120
        model
96
6120
    }
97

            
98
12240
    fn empty_model_with_symbols(&self) -> Model {
99
12240
        let parent = self.symbols.read().parent().clone();
100
12240
        let mut model = if let Some(parent) = parent {
101
12240
            Model::new_in_parent_scope(parent)
102
        } else {
103
            Model::default()
104
        };
105
12240
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
106
12240
        model
107
12240
    }
108

            
109
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
110
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
111
        if self.is_quantified_guard(&guard) {
112
            self.qualifiers
113
                .push(ComprehensionQualifier::Condition(guard));
114
            true
115
        } else {
116
            false
117
        }
118
    }
119

            
120
    /// True iff expr only references quantified variables.
121
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
122
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
123
        is_quantified_guard(&quantified, expr)
124
    }
125
}
126

            
127
impl Typeable for Comprehension {
128
    fn return_type(&self) -> ReturnType {
129
        self.return_expression.return_type()
130
    }
131
}
132

            
133
impl Display for Comprehension {
134
72300
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135
72300
        let generators_and_guards = self
136
72300
            .qualifiers
137
72300
            .iter()
138
84180
            .map(|qualifier| match qualifier {
139
79140
                ComprehensionQualifier::Generator { name, domain } => {
140
79140
                    format!("{name} : {domain}")
141
                }
142
5040
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
143
84180
            })
144
72300
            .join(", ");
145

            
146
72300
        write!(
147
72300
            f,
148
            "[ {} | {generators_and_guards} ]",
149
            self.return_expression
150
        )
151
72300
    }
152
}
153

            
154
/// A builder for a comprehension.
155
#[derive(Clone, Debug, PartialEq, Eq)]
156
pub struct ComprehensionBuilder {
157
    qualifiers: Vec<ComprehensionQualifier>,
158
    // A single scope for generators and return expression.
159
    symbols: SymbolTablePtr,
160
    quantified_variables: BTreeSet<Name>,
161
}
162

            
163
impl ComprehensionBuilder {
164
6360
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
165
6360
        ComprehensionBuilder {
166
6360
            qualifiers: vec![],
167
6360
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
168
6360
            quantified_variables: BTreeSet::new(),
169
6360
        }
170
6360
    }
171

            
172
    /// Backwards-compatible parser API: same table for generators and return expression.
173
6240
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
174
6240
        self.symbols.clone()
175
6240
    }
176

            
177
    /// Backwards-compatible parser API: same table for generators and return expression.
178
6360
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
179
6360
        self.symbols.clone()
180
6360
    }
181

            
182
900
    pub fn guard(mut self, guard: Expression) -> Self {
183
900
        self.qualifiers
184
900
            .push(ComprehensionQualifier::Condition(guard));
185
900
        self
186
900
    }
187

            
188
6960
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
189
6960
        let name = declaration.name().clone();
190
6960
        let domain = declaration.domain().unwrap();
191
6960
        assert!(!self.quantified_variables.contains(&name));
192

            
193
6960
        self.quantified_variables.insert(name.clone());
194

            
195
        // insert into comprehension scope as a local quantified variable
196
6960
        let quantified_decl = DeclarationPtr::new_quantified(name.clone(), domain.clone());
197
6960
        self.symbols.write().insert(quantified_decl);
198

            
199
6960
        self.qualifiers
200
6960
            .push(ComprehensionQualifier::Generator { name, domain });
201

            
202
6960
        self
203
6960
    }
204

            
205
    /// Creates a comprehension with the given return expression.
206
    ///
207
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
208
    /// in the `comprehension_kind` field.
209
    ///
210
    /// If a comprehension kind is not given, comprehension guards containing decision variables
211
    /// are invalid, and will cause a panic.
212
6300
    pub fn with_return_value(
213
6300
        self,
214
6300
        mut expression: Expression,
215
6300
        comprehension_kind: Option<ACOperatorKind>,
216
6300
    ) -> Comprehension {
217
6300
        let quantified_variables = self.quantified_variables;
218

            
219
6300
        let mut qualifiers = Vec::new();
220
6300
        let mut other_guards = Vec::new();
221

            
222
7860
        for qualifier in self.qualifiers {
223
7860
            match qualifier {
224
6960
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
225
900
                ComprehensionQualifier::Condition(condition) => {
226
900
                    if is_quantified_guard(&quantified_variables, &condition) {
227
540
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
228
540
                    } else {
229
360
                        other_guards.push(condition);
230
360
                    }
231
                }
232
            }
233
        }
234

            
235
        // handle guards that reference non-quantified variables
236
6300
        if !other_guards.is_empty() {
237
360
            let comprehension_kind = comprehension_kind.expect(
238
360
                "if any guards reference decision variables, a comprehension kind should be given",
239
            );
240

            
241
360
            let guard_expr = match other_guards.as_slice() {
242
360
                [x] => x.clone(),
243
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
244
            };
245

            
246
360
            expression = match comprehension_kind {
247
                ACOperatorKind::And => {
248
360
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
249
                }
250
                ACOperatorKind::Or => Expression::And(
251
                    Metadata::new(),
252
                    Moo::new(matrix_expr![guard_expr, expression]),
253
                ),
254

            
255
                ACOperatorKind::Sum => {
256
                    panic!("guards that reference decision variables not yet implemented for sum");
257
                }
258

            
259
                ACOperatorKind::Product => {
260
                    panic!(
261
                        "guards that reference decision variables not yet implemented for product"
262
                    );
263
                }
264
            }
265
5940
        }
266

            
267
6300
        Comprehension {
268
6300
            return_expression: expression,
269
6300
            qualifiers,
270
6300
            symbols: self.symbols,
271
6300
        }
272
6300
    }
273
}
274

            
275
/// True iff the guard only references quantified variables.
276
900
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
277
900
    guard
278
900
        .universe_bi()
279
900
        .iter()
280
900
        .all(|x| quantified_variables.contains(x))
281
900
}