1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable,
16
    ac_operators::ACOperatorKind,
17
    serde::{AsId, PtrAsInner},
18
};
19

            
20
#[serde_as]
21
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
22
#[biplate(to=Expression)]
23
#[biplate(to=Name)]
24
#[biplate(to=DeclarationPtr)]
25
pub enum ComprehensionQualifier {
26
    Generator {
27
        #[serde_as(as = "AsId")]
28
        ptr: DeclarationPtr,
29
    },
30
    Condition(Expression),
31
}
32

            
33
/// A comprehension.
34
#[serde_as]
35
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
36
#[biplate(to=Expression)]
37
#[biplate(to=SymbolTable)]
38
#[biplate(to=SymbolTablePtr)]
39
#[non_exhaustive]
40
pub struct Comprehension {
41
    pub return_expression: Expression,
42
    pub qualifiers: Vec<ComprehensionQualifier>,
43
    #[doc(hidden)]
44
    #[serde_as(as = "PtrAsInner")]
45
    pub symbols: SymbolTablePtr,
46
}
47

            
48
impl Comprehension {
49
    pub fn domain_of(&self) -> Option<DomainPtr> {
50
        let return_expr_domain = self.return_expression.domain_of()?;
51

            
52
        // return a list (matrix with index domain int(1..)) of return_expr elements
53
        Some(Domain::matrix(
54
            return_expr_domain,
55
            vec![Domain::int(vec![Range::UnboundedR(1)])],
56
        ))
57
    }
58

            
59
12248
    pub fn return_expression(self) -> Expression {
60
12248
        self.return_expression
61
12248
    }
62

            
63
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
64
        self.return_expression = new_expr;
65
    }
66

            
67
35108
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
68
35108
        self.symbols.read()
69
35108
    }
70

            
71
14932
    pub fn quantified_vars(&self) -> Vec<Name> {
72
14932
        self.qualifiers
73
14932
            .iter()
74
17336
            .filter_map(|q| match q {
75
16496
                ComprehensionQualifier::Generator { ptr } => Some(ptr.name().clone()),
76
840
                ComprehensionQualifier::Condition(_) => None,
77
17336
            })
78
14932
            .collect()
79
14932
    }
80

            
81
2044
    pub fn generator_conditions(&self) -> Vec<Expression> {
82
2044
        self.qualifiers
83
2044
            .iter()
84
2372
            .filter_map(|q| match q {
85
120
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
86
2252
                ComprehensionQualifier::Generator { .. } => None,
87
2372
            })
88
2044
            .collect()
89
2044
    }
90

            
91
    /// Builds a temporary model containing generator qualifiers and guards.
92
2044
    pub fn to_generator_model(&self) -> Model {
93
2044
        let mut model = self.empty_model_with_symbols();
94
2044
        model.add_constraints(self.generator_conditions());
95
2044
        model
96
2044
    }
97

            
98
    /// Builds a temporary model containing the return expression only.
99
2044
    pub fn to_return_expression_model(&self) -> Model {
100
2044
        let mut model = self.empty_model_with_symbols();
101
2044
        model.add_constraint(self.return_expression.clone());
102
2044
        model
103
2044
    }
104

            
105
4088
    fn empty_model_with_symbols(&self) -> Model {
106
4088
        let parent = self.symbols.read().parent().clone();
107
4088
        let mut model = if let Some(parent) = parent {
108
4088
            Model::new_in_parent_scope(parent)
109
        } else {
110
            Model::default()
111
        };
112
4088
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
113
4088
        model
114
4088
    }
115

            
116
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
117
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
118
        if self.is_quantified_guard(&guard) {
119
            self.qualifiers
120
                .push(ComprehensionQualifier::Condition(guard));
121
            true
122
        } else {
123
            false
124
        }
125
    }
126

            
127
    /// True iff expr only references quantified variables.
128
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
129
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
130
        is_quantified_guard(&quantified, expr)
131
    }
132
}
133

            
134
impl Typeable for Comprehension {
135
    fn return_type(&self) -> ReturnType {
136
        self.return_expression.return_type()
137
    }
138
}
139

            
140
impl Display for Comprehension {
141
2433560
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142
2433560
        let generators_and_guards = self
143
2433560
            .qualifiers
144
2433560
            .iter()
145
2834560
            .map(|qualifier| match qualifier {
146
2675500
                ComprehensionQualifier::Generator { ptr } => {
147
2675500
                    let domain = ptr.domain().expect("generator declaration has domain");
148
2675500
                    format!("{} : {domain}", ptr.name())
149
                }
150
159060
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
151
2834560
            })
152
2433560
            .join(", ");
153

            
154
2433560
        write!(
155
2433560
            f,
156
            "[ {} | {generators_and_guards} ]",
157
            self.return_expression
158
        )
159
2433560
    }
160
}
161

            
162
/// A builder for a comprehension.
163
#[derive(Clone, Debug, PartialEq, Eq)]
164
pub struct ComprehensionBuilder {
165
    qualifiers: Vec<ComprehensionQualifier>,
166
    // A single scope for generators and return expression.
167
    symbols: SymbolTablePtr,
168
    quantified_variables: BTreeSet<Name>,
169
}
170

            
171
impl ComprehensionBuilder {
172
2144
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
173
2144
        ComprehensionBuilder {
174
2144
            qualifiers: vec![],
175
2144
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
176
2144
            quantified_variables: BTreeSet::new(),
177
2144
        }
178
2144
    }
179

            
180
    /// Backwards-compatible parser API: same table for generators and return expression.
181
2084
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
182
2084
        self.symbols.clone()
183
2084
    }
184

            
185
    /// Backwards-compatible parser API: same table for generators and return expression.
186
2144
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
187
2144
        self.symbols.clone()
188
2144
    }
189

            
190
300
    pub fn guard(mut self, guard: Expression) -> Self {
191
300
        self.qualifiers
192
300
            .push(ComprehensionQualifier::Condition(guard));
193
300
        self
194
300
    }
195

            
196
2372
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
197
2372
        let name = declaration.name().clone();
198
2372
        assert!(!self.quantified_variables.contains(&name));
199

            
200
2372
        self.quantified_variables.insert(name.clone());
201

            
202
        // insert into comprehension scope as a local quantified variable
203
2372
        let quantified_decl = DeclarationPtr::new_quantified(name, declaration.domain().unwrap());
204
2372
        self.symbols.write().insert(quantified_decl.clone());
205

            
206
2372
        self.qualifiers.push(ComprehensionQualifier::Generator {
207
2372
            ptr: quantified_decl,
208
2372
        });
209

            
210
2372
        self
211
2372
    }
212

            
213
    /// Creates a comprehension with the given return expression.
214
    ///
215
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
216
    /// in the `comprehension_kind` field.
217
    ///
218
    /// If a comprehension kind is not given, comprehension guards containing decision variables
219
    /// are invalid, and will cause a panic.
220
2124
    pub fn with_return_value(
221
2124
        self,
222
2124
        mut expression: Expression,
223
2124
        comprehension_kind: Option<ACOperatorKind>,
224
2124
    ) -> Comprehension {
225
2124
        let quantified_variables = self.quantified_variables;
226

            
227
2124
        let mut qualifiers = Vec::new();
228
2124
        let mut other_guards = Vec::new();
229

            
230
2672
        for qualifier in self.qualifiers {
231
2672
            match qualifier {
232
2372
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
233
300
                ComprehensionQualifier::Condition(condition) => {
234
300
                    if is_quantified_guard(&quantified_variables, &condition) {
235
180
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
236
180
                    } else {
237
120
                        other_guards.push(condition);
238
120
                    }
239
                }
240
            }
241
        }
242

            
243
        // handle guards that reference non-quantified variables
244
2124
        if !other_guards.is_empty() {
245
120
            let comprehension_kind = comprehension_kind.expect(
246
120
                "if any guards reference decision variables, a comprehension kind should be given",
247
            );
248

            
249
120
            let guard_expr = match other_guards.as_slice() {
250
120
                [x] => x.clone(),
251
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
252
            };
253

            
254
120
            expression = match comprehension_kind {
255
                ACOperatorKind::And => {
256
120
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
257
                }
258
                ACOperatorKind::Or => Expression::And(
259
                    Metadata::new(),
260
                    Moo::new(matrix_expr![guard_expr, expression]),
261
                ),
262

            
263
                ACOperatorKind::Sum => {
264
                    panic!("guards that reference decision variables not yet implemented for sum");
265
                }
266

            
267
                ACOperatorKind::Product => {
268
                    panic!(
269
                        "guards that reference decision variables not yet implemented for product"
270
                    );
271
                }
272
            }
273
2004
        }
274

            
275
2124
        Comprehension {
276
2124
            return_expression: expression,
277
2124
            qualifiers,
278
2124
            symbols: self.symbols,
279
2124
        }
280
2124
    }
281
}
282

            
283
/// True iff the guard only references quantified variables.
284
300
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
285
300
    guard
286
300
        .universe_bi()
287
300
        .iter()
288
300
        .all(|x| quantified_variables.contains(x))
289
300
}