1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{collections::BTreeSet, fmt::Display};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use parking_lot::RwLockReadGuard;
9
use serde::{Deserialize, Serialize};
10
use serde_with::serde_as;
11
use uniplate::{Biplate, Uniplate};
12

            
13
use super::{
14
    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15
    SymbolTablePtr, Typeable,
16
    ac_operators::ACOperatorKind,
17
    serde::{AsId, PtrAsInner},
18
};
19

            
20
/// One qualifier of a comprehension.
///
/// A qualifier is either a generator, which binds a quantified variable and is displayed as
/// `name : domain`, or a condition (guard) expression.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=Name)]
#[biplate(to=DeclarationPtr)]
pub enum ComprehensionQualifier {
    /// A generator; `ptr` is the declaration of the quantified variable it binds.
    Generator {
        // Serialized by id (via `AsId`) rather than inline. NOTE(review): presumably the full
        // declaration is recovered from the comprehension's symbol table on load — confirm.
        #[serde_as(as = "AsId")]
        ptr: DeclarationPtr,
    },
    /// A guard condition.
    Condition(Expression),
}
32

            
33
/// A comprehension.
///
/// Displayed in Essence syntax as `[ return_expression | qualifiers ]`. Its domain is a list
/// (a matrix indexed by `int(1..)`) of elements drawn from the return expression's domain.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=Expression)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[non_exhaustive]
pub struct Comprehension {
    /// The expression producing each element of the comprehension.
    pub return_expression: Expression,
    /// Generators and guards, in the order they were added.
    pub qualifiers: Vec<ComprehensionQualifier>,
    // The comprehension's own scope, which holds the quantified variables declared by the
    // generators. Serialized as the inner table (via `PtrAsInner`) rather than as a pointer.
    #[doc(hidden)]
    #[serde_as(as = "PtrAsInner")]
    pub symbols: SymbolTablePtr,
}
47

            
48
impl Comprehension {
49
    pub fn domain_of(&self) -> Option<DomainPtr> {
50
        let return_expr_domain = self.return_expression.domain_of()?;
51

            
52
        // return a list (matrix with index domain int(1..)) of return_expr elements
53
        Some(Domain::matrix(
54
            return_expr_domain,
55
            vec![Domain::int(vec![Range::UnboundedR(1)])],
56
        ))
57
    }
58

            
59
23748
    pub fn return_expression(self) -> Expression {
60
23748
        self.return_expression
61
23748
    }
62

            
63
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
64
        self.return_expression = new_expr;
65
    }
66

            
67
69468
    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
68
69468
        self.symbols.read()
69
69468
    }
70

            
71
28672
    pub fn quantified_vars(&self) -> Vec<Name> {
72
28672
        self.qualifiers
73
28672
            .iter()
74
33456
            .filter_map(|q| match q {
75
31776
                ComprehensionQualifier::Generator { ptr } => Some(ptr.name().clone()),
76
1680
                ComprehensionQualifier::Condition(_) => None,
77
33456
            })
78
28672
            .collect()
79
28672
    }
80

            
81
3644
    pub fn generator_conditions(&self) -> Vec<Expression> {
82
3644
        self.qualifiers
83
3644
            .iter()
84
4292
            .filter_map(|q| match q {
85
240
                ComprehensionQualifier::Condition(c) => Some(c.clone()),
86
4052
                ComprehensionQualifier::Generator { .. } => None,
87
4292
            })
88
3644
            .collect()
89
3644
    }
90

            
91
    /// Builds a temporary model containing generator qualifiers and guards.
92
3644
    pub fn to_generator_model(&self) -> Model {
93
3644
        let mut model = self.empty_model_with_symbols();
94
3644
        model.add_constraints(self.generator_conditions());
95
3644
        model
96
3644
    }
97

            
98
    /// Builds a temporary model containing the return expression only.
99
3644
    pub fn to_return_expression_model(&self) -> Model {
100
3644
        let mut model = self.empty_model_with_symbols();
101
3644
        model.add_constraint(self.return_expression.clone());
102
3644
        model
103
3644
    }
104

            
105
7288
    fn empty_model_with_symbols(&self) -> Model {
106
7288
        let parent = self.symbols.read().parent().clone();
107
7288
        let mut model = if let Some(parent) = parent {
108
7288
            Model::new_in_parent_scope(parent)
109
        } else {
110
            Model::default()
111
        };
112
7288
        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
113
7288
        model
114
7288
    }
115

            
116
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
117
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
118
        if self.is_quantified_guard(&guard) {
119
            self.qualifiers
120
                .push(ComprehensionQualifier::Condition(guard));
121
            true
122
        } else {
123
            false
124
        }
125
    }
126

            
127
    /// True iff expr only references quantified variables.
128
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
129
        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
130
        is_quantified_guard(&quantified, expr)
131
    }
132
}
133

            
134
impl Typeable for Comprehension {
135
    fn return_type(&self) -> ReturnType {
136
        self.return_expression.return_type()
137
    }
138
}
139

            
140
impl Display for Comprehension {
141
4687040
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142
4687040
        let generators_and_guards = self
143
4687040
            .qualifiers
144
4687040
            .iter()
145
5486200
            .map(|qualifier| match qualifier {
146
5168080
                ComprehensionQualifier::Generator { ptr } => {
147
5168080
                    let domain = ptr.domain().expect("generator declaration has domain");
148
5168080
                    format!("{} : {domain}", ptr.name())
149
                }
150
318120
                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
151
5486200
            })
152
4687040
            .join(", ");
153

            
154
4687040
        write!(
155
4687040
            f,
156
            "[ {} | {generators_and_guards} ]",
157
            self.return_expression
158
        )
159
4687040
    }
160
}
161

            
162
/// A builder for a comprehension.
///
/// Generators and guards are collected in order. [`ComprehensionBuilder::with_return_value`]
/// then produces the final [`Comprehension`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    // Generators and guards, in the order they were added.
    qualifiers: Vec<ComprehensionQualifier>,
    // A single scope for generators and return expression.
    symbols: SymbolTablePtr,
    // Names bound by the generators added so far. Used to reject duplicate generators and to
    // decide which guards stay as qualifiers in `with_return_value`.
    quantified_variables: BTreeSet<Name>,
}
170

            
171
impl ComprehensionBuilder {
172
4144
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
173
4144
        ComprehensionBuilder {
174
4144
            qualifiers: vec![],
175
4144
            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
176
4144
            quantified_variables: BTreeSet::new(),
177
4144
        }
178
4144
    }
179

            
180
    /// Backwards-compatible parser API: same table for generators and return expression.
181
2844
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
182
2844
        self.symbols.clone()
183
2844
    }
184

            
185
    /// Backwards-compatible parser API: same table for generators and return expression.
186
4144
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
187
4144
        self.symbols.clone()
188
4144
    }
189

            
190
600
    pub fn guard(mut self, guard: Expression) -> Self {
191
600
        self.qualifiers
192
600
            .push(ComprehensionQualifier::Condition(guard));
193
600
        self
194
600
    }
195

            
196
4572
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
197
4572
        let name = declaration.name().clone();
198
4572
        assert!(!self.quantified_variables.contains(&name));
199

            
200
4572
        self.quantified_variables.insert(name.clone());
201

            
202
        // insert into comprehension scope as a local quantified variable
203
4572
        let quantified_decl = DeclarationPtr::new_quantified(name, declaration.domain().unwrap());
204
4572
        self.symbols.write().insert(quantified_decl.clone());
205

            
206
4572
        self.qualifiers.push(ComprehensionQualifier::Generator {
207
4572
            ptr: quantified_decl,
208
4572
        });
209

            
210
4572
        self
211
4572
    }
212

            
213
    /// Creates a comprehension with the given return expression.
214
    ///
215
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
216
    /// in the `comprehension_kind` field.
217
    ///
218
    /// If a comprehension kind is not given, comprehension guards containing decision variables
219
    /// are invalid, and will cause a panic.
220
4124
    pub fn with_return_value(
221
4124
        self,
222
4124
        mut expression: Expression,
223
4124
        comprehension_kind: Option<ACOperatorKind>,
224
4124
    ) -> Comprehension {
225
4124
        let quantified_variables = self.quantified_variables;
226

            
227
4124
        let mut qualifiers = Vec::new();
228
4124
        let mut other_guards = Vec::new();
229

            
230
5172
        for qualifier in self.qualifiers {
231
5172
            match qualifier {
232
4572
                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
233
600
                ComprehensionQualifier::Condition(condition) => {
234
600
                    if is_quantified_guard(&quantified_variables, &condition) {
235
360
                        qualifiers.push(ComprehensionQualifier::Condition(condition));
236
360
                    } else {
237
240
                        other_guards.push(condition);
238
240
                    }
239
                }
240
            }
241
        }
242

            
243
        // handle guards that reference non-quantified variables
244
4124
        if !other_guards.is_empty() {
245
240
            let comprehension_kind = comprehension_kind.expect(
246
240
                "if any guards reference decision variables, a comprehension kind should be given",
247
            );
248

            
249
240
            let guard_expr = match other_guards.as_slice() {
250
240
                [x] => x.clone(),
251
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
252
            };
253

            
254
240
            expression = match comprehension_kind {
255
                ACOperatorKind::And => {
256
240
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
257
                }
258
                ACOperatorKind::Or => Expression::And(
259
                    Metadata::new(),
260
                    Moo::new(matrix_expr![guard_expr, expression]),
261
                ),
262

            
263
                ACOperatorKind::Sum => {
264
                    panic!("guards that reference decision variables not yet implemented for sum");
265
                }
266

            
267
                ACOperatorKind::Product => {
268
                    panic!(
269
                        "guards that reference decision variables not yet implemented for product"
270
                    );
271
                }
272
            }
273
3884
        }
274

            
275
4124
        Comprehension {
276
4124
            return_expression: expression,
277
4124
            qualifiers,
278
4124
            symbols: self.symbols,
279
4124
        }
280
4124
    }
281
}
282

            
283
/// True iff the guard only references quantified variables.
284
600
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
285
600
    guard
286
600
        .universe_bi()
287
600
        .iter()
288
600
        .all(|x| quantified_variables.contains(x))
289
600
}