1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable, ac_operators::ACOperatorKind, serde::PtrAsInner,
16
};
17

            
18
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
19
#[biplate(to=Expression)]
20
#[biplate(to=Name)]
21
pub enum ComprehensionQualifier {
22
    Generator { name: Name, domain: DomainPtr },
23
    Condition(Expression),
24
}
25

            
26
/// A comprehension.
27
#[serde_as]
28
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
29
#[biplate(to=Expression)]
30
#[biplate(to=SymbolTable)]
31
#[biplate(to=SymbolTablePtr)]
32
#[non_exhaustive]
33
pub struct Comprehension {
34
    pub return_expression: Expression,
35
    pub qualifiers: Vec<ComprehensionQualifier>,
36
    #[doc(hidden)]
37
    #[serde_as(as = "PtrAsInner")]
38
    pub symbols: SymbolTablePtr,
39
}
40

            
41
impl Comprehension {
42
    pub fn domain_of(&self) -> Option<DomainPtr> {
43
        let return_expr_domain = self.return_expression.domain_of()?;
44

            
45
        // return a list (matrix with index domain int(1..)) of return_expr elements
46
        Some(Domain::matrix(
47
            return_expr_domain,
48
            vec![Domain::int(vec![Range::UnboundedR(1)])],
49
        ))
50
    }
51

            
52
26710
    pub fn return_expression(self) -> Expression {
53
26710
        self.return_expression
54
26710
    }
55

            
56
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
57
        self.return_expression = new_expr;
58
    }
59

            
60
78541
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
61
78541
        self.symbols.read()
62
78541
    }
63

            
64
37005
    pub fn quantified_vars(&self) -> Vec<Name> {
65
37005
        self.qualifiers
66
37005
            .iter()
67
41841
            .filter_map(|q| match q {
68
40125
                ComprehensionQualifier::Generator { name, .. } => Some(name.clone()),
69
1716
                ComprehensionQualifier::Condition(_) => None,
70
41841
            })
71
37005
            .collect()
72
37005
    }
73

            
74
9905
    pub fn generator_conditions(&self) -> Vec<Expression> {
75
9905
        self.qualifiers
76
9905
            .iter()
77
11036
            .filter_map(|q| match q {
78
312
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
79
10724
                ComprehensionQualifier::Generator { .. } => None,
80
11036
            })
81
9905
            .collect()
82
9905
    }
83

            
84
    /// Builds a temporary model containing generator qualifiers and guards.
85
9905
    pub fn to_generator_model(&self) -> Model {
86
9905
        let mut model = self.empty_model_with_symbols();
87
9905
        model.add_constraints(self.generator_conditions());
88
9905
        model
89
9905
    }
90

            
91
    /// Builds a temporary model containing the return expression only.
92
9905
    pub fn to_return_expression_model(&self) -> Model {
93
9905
        let mut model = self.empty_model_with_symbols();
94
9905
        model.add_constraint(self.return_expression.clone());
95
9905
        model
96
9905
    }
97

            
98
19810
    fn empty_model_with_symbols(&self) -> Model {
99
19810
        let parent = self.symbols.read().parent().clone();
100
19810
        let mut model = if let Some(parent) = parent {
101
19810
            Model::new_in_parent_scope(parent)
102
        } else {
103
            Model::default()
104
        };
105
19810
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
106
19810
        model
107
19810
    }
108

            
109
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
110
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
111
        if self.is_quantified_guard(&guard) {
112
            self.qualifiers
113
                .push(ComprehensionQualifier::Condition(guard));
114
            true
115
        } else {
116
            false
117
        }
118
    }
119

            
120
    /// True iff expr only references quantified variables.
121
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
122
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
123
        is_quantified_guard(&quantified, expr)
124
    }
125
}
126

            
127
impl Typeable for Comprehension {
128
    fn return_type(&self) -> ReturnType {
129
        self.return_expression.return_type()
130
    }
131
}
132

            
133
impl Display for Comprehension {
134
49637
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135
49637
        let generators_and_guards = self
136
49637
            .qualifiers
137
49637
            .iter()
138
57047
            .map(|qualifier| match qualifier {
139
53771
                ComprehensionQualifier::Generator { name, domain } => {
140
53771
                    format!("{name} : {domain}")
141
                }
142
3276
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
143
57047
            })
144
49637
            .join(", ");
145

            
146
49637
        write!(
147
49637
            f,
148
            "[ {} | {generators_and_guards} ]",
149
            self.return_expression
150
        )
151
49637
    }
152
}
153

            
154
/// A builder for a comprehension.
155
#[derive(Clone, Debug, PartialEq, Eq)]
156
pub struct ComprehensionBuilder {
157
    qualifiers: Vec<ComprehensionQualifier>,
158
    // A single scope for generators and return expression.
159
    symbols: SymbolTablePtr,
160
    quantified_variables: BTreeSet<Name>,
161
}
162

            
163
impl ComprehensionBuilder {
164
3626
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
165
3626
        ComprehensionBuilder {
166
3626
            qualifiers: vec![],
167
3626
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
168
3626
            quantified_variables: BTreeSet::new(),
169
3626
        }
170
3626
    }
171

            
172
    /// Backwards-compatible parser API: same table for generators and return expression.
173
3588
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
174
3588
        self.symbols.clone()
175
3588
    }
176

            
177
    /// Backwards-compatible parser API: same table for generators and return expression.
178
3626
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
179
3626
        self.symbols.clone()
180
3626
    }
181

            
182
585
    pub fn guard(mut self, guard: Expression) -> Self {
183
585
        self.qualifiers
184
585
            .push(ComprehensionQualifier::Condition(guard));
185
585
        self
186
585
    }
187

            
188
4016
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
189
4016
        let name = declaration.name().clone();
190
4016
        let domain = declaration.domain().unwrap();
191
4016
        assert!(!self.quantified_variables.contains(&name));
192

            
193
4016
        self.quantified_variables.insert(name.clone());
194

            
195
        // insert into comprehension scope as a local quantified variable
196
4016
        let quantified_decl = DeclarationPtr::new_quantified(name.clone(), domain.clone());
197
4016
        self.symbols.write().insert(quantified_decl);
198

            
199
4016
        self.qualifiers
200
4016
            .push(ComprehensionQualifier::Generator { name, domain });
201

            
202
4016
        self
203
4016
    }
204

            
205
    /// Creates a comprehension with the given return expression.
206
    ///
207
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
208
    /// in the `comprehension_kind` field.
209
    ///
210
    /// If a comprehension kind is not given, comprehension guards containing decision variables
211
    /// are invalid, and will cause a panic.
212
3587
    pub fn with_return_value(
213
3587
        self,
214
3587
        mut expression: Expression,
215
3587
        comprehension_kind: Option<ACOperatorKind>,
216
3587
    ) -> Comprehension {
217
3587
        let quantified_variables = self.quantified_variables;
218

            
219
3587
        let mut qualifiers = Vec::new();
220
3587
        let mut other_guards = Vec::new();
221

            
222
4601
        for qualifier in self.qualifiers {
223
4601
            match qualifier {
224
4016
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
225
585
                ComprehensionQualifier::Condition(condition) => {
226
585
                    if is_quantified_guard(&quantified_variables, &condition) {
227
351
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
228
351
                    } else {
229
234
                        other_guards.push(condition);
230
234
                    }
231
                }
232
            }
233
        }
234

            
235
        // handle guards that reference non-quantified variables
236
3587
        if !other_guards.is_empty() {
237
234
            let comprehension_kind = comprehension_kind.expect(
238
234
                "if any guards reference decision variables, a comprehension kind should be given",
239
            );
240

            
241
234
            let guard_expr = match other_guards.as_slice() {
242
234
                [x] => x.clone(),
243
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
244
            };
245

            
246
234
            expression = match comprehension_kind {
247
                ACOperatorKind::And => {
248
234
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
249
                }
250
                ACOperatorKind::Or => Expression::And(
251
                    Metadata::new(),
252
                    Moo::new(matrix_expr![guard_expr, expression]),
253
                ),
254

            
255
                ACOperatorKind::Sum => {
256
                    panic!("guards that reference decision variables not yet implemented for sum");
257
                }
258

            
259
                ACOperatorKind::Product => {
260
                    panic!(
261
                        "guards that reference decision variables not yet implemented for product"
262
                    );
263
                }
264
            }
265
3353
        }
266

            
267
3587
        Comprehension {
268
3587
            return_expression: expression,
269
3587
            qualifiers,
270
3587
            symbols: self.symbols,
271
3587
        }
272
3587
    }
273
}
274

            
275
/// True iff the guard only references quantified variables.
276
585
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
277
585
    guard
278
585
        .universe_bi()
279
585
        .iter()
280
585
        .all(|x| quantified_variables.contains(x))
281
585
}