Skip to main content

conjure_cp_core/ast/
comprehension.rs

1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{collections::BTreeSet, fmt::Display};
4
5use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6use conjure_cp_core::ast::ReturnType;
7use itertools::Itertools as _;
8use parking_lot::RwLockReadGuard;
9use serde::{Deserialize, Serialize};
10use serde_with::serde_as;
11use uniplate::{Biplate, Uniplate};
12
13use super::{
14    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15    SymbolTablePtr, Typeable,
16    ac_operators::ACOperatorKind,
17    categories::{Category, CategoryOf},
18    serde::{AsId, PtrAsInner},
19};
20
21#[serde_as]
22#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
23#[biplate(to=Expression)]
24#[biplate(to=Name)]
25#[biplate(to=DeclarationPtr)]
26pub enum ComprehensionQualifier {
27    ExpressionGenerator {
28        #[serde_as(as = "AsId")]
29        ptr: DeclarationPtr,
30    },
31    Generator {
32        #[serde_as(as = "AsId")]
33        ptr: DeclarationPtr,
34    },
35    Condition(Expression),
36}
37
38/// A comprehension.
39#[serde_as]
40#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
41#[biplate(to=Expression)]
42#[biplate(to=SymbolTable)]
43#[biplate(to=SymbolTablePtr)]
44#[non_exhaustive]
45pub struct Comprehension {
46    pub return_expression: Expression,
47    pub qualifiers: Vec<ComprehensionQualifier>,
48    #[doc(hidden)]
49    #[serde_as(as = "PtrAsInner")]
50    pub symbols: SymbolTablePtr,
51}
52
53impl Comprehension {
54    pub fn domain_of(&self) -> Option<DomainPtr> {
55        let return_expr_domain = self.return_expression.domain_of()?;
56
57        // return a list (matrix with index domain int(1..)) of return_expr elements
58        Some(Domain::matrix(
59            return_expr_domain,
60            vec![Domain::int(vec![Range::UnboundedR(1)])],
61        ))
62    }
63
64    pub fn return_expression(self) -> Expression {
65        self.return_expression
66    }
67
68    pub fn replace_return_expression(&mut self, new_expr: Expression) {
69        self.return_expression = new_expr;
70    }
71
72    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
73        self.symbols.read()
74    }
75
76    pub fn quantified_vars(&self) -> Vec<Name> {
77        self.qualifiers
78            .iter()
79            .filter_map(|q| match q {
80                ComprehensionQualifier::ExpressionGenerator { ptr } => Some(ptr.name().clone()),
81                ComprehensionQualifier::Generator { ptr } => Some(ptr.name().clone()),
82                ComprehensionQualifier::Condition(_) => None,
83            })
84            .collect()
85    }
86
87    pub fn generator_conditions(&self) -> Vec<Expression> {
88        self.qualifiers
89            .iter()
90            .filter_map(|q| match q {
91                ComprehensionQualifier::Condition(c) => Some(c.clone()),
92                ComprehensionQualifier::Generator { .. } => None,
93                ComprehensionQualifier::ExpressionGenerator { .. } => None,
94            })
95            .collect()
96    }
97
98    /// Builds a temporary model containing generator qualifiers and guards.
99    pub fn to_generator_model(&self) -> Model {
100        let mut model = self.empty_model_with_symbols();
101        model.add_constraints(self.generator_conditions());
102        model
103    }
104
105    /// Builds a temporary model containing the return expression only.
106    pub fn to_return_expression_model(&self) -> Model {
107        let mut model = self.empty_model_with_symbols();
108        model.add_constraint(self.return_expression.clone());
109        model
110    }
111
112    fn empty_model_with_symbols(&self) -> Model {
113        let parent = self.symbols.read().parent().clone();
114        let mut model = if let Some(parent) = parent {
115            Model::new_in_parent_scope(parent)
116        } else {
117            Model::default()
118        };
119        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
120        model
121    }
122
123    /// Adds a guard to the comprehension.
124    ///
125    /// Returns false if the guard references non-quantified decision variables.
126    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
127        if self.is_quantified_guard(&guard) {
128            self.qualifiers
129                .push(ComprehensionQualifier::Condition(guard));
130            true
131        } else {
132            false
133        }
134    }
135
136    /// True iff expr does not reference non-quantified decision variables.
137    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
138        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
139        is_quantified_guard(&self.symbols.read(), &quantified, expr)
140    }
141}
142
143impl Typeable for Comprehension {
144    fn return_type(&self) -> ReturnType {
145        self.return_expression.return_type()
146    }
147}
148
149impl Display for Comprehension {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        let generators_and_guards = self
152            .qualifiers
153            .iter()
154            .map(|qualifier| match qualifier {
155                ComprehensionQualifier::Generator { ptr } => {
156                    let domain = ptr.domain().expect("generator declaration has domain");
157                    format!("{} : {domain}", ptr.name())
158                }
159                ComprehensionQualifier::ExpressionGenerator { ptr } => {
160                    let name = ptr.name();
161                    if let Some(expr) = ptr.as_quantified_expr() {
162                        format!("{name} <- {expr}")
163                    } else {
164                        panic!("Oh nein! Dat is nicht gut!")
165                    }
166                }
167                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
168            })
169            .join(", ");
170
171        write!(
172            f,
173            "[ {} | {generators_and_guards} ]",
174            self.return_expression
175        )
176    }
177}
178
179/// A builder for a comprehension.
180#[derive(Clone, Debug, PartialEq, Eq)]
181pub struct ComprehensionBuilder {
182    qualifiers: Vec<ComprehensionQualifier>,
183    // A single scope for generators and return expression.
184    symbols: SymbolTablePtr,
185    quantified_variables: BTreeSet<Name>,
186}
187
188impl ComprehensionBuilder {
189    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
190        ComprehensionBuilder {
191            qualifiers: vec![],
192            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
193            quantified_variables: BTreeSet::new(),
194        }
195    }
196
197    /// Backwards-compatible parser API: same table for generators and return expression.
198    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
199        self.symbols.clone()
200    }
201
202    /// Backwards-compatible parser API: same table for generators and return expression.
203    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
204        self.symbols.clone()
205    }
206
207    pub fn guard(mut self, guard: Expression) -> Self {
208        self.qualifiers
209            .push(ComprehensionQualifier::Condition(guard));
210        self
211    }
212
213    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
214        let name = declaration.name().clone();
215        assert!(!self.quantified_variables.contains(&name));
216
217        self.quantified_variables.insert(name.clone());
218
219        // insert into comprehension scope as a local quantified variable
220        let quantified_decl = DeclarationPtr::new_quantified(name, declaration.domain().unwrap());
221        self.symbols.write().insert(quantified_decl.clone());
222
223        self.qualifiers.push(ComprehensionQualifier::Generator {
224            ptr: quantified_decl,
225        });
226
227        self
228    }
229
230    pub fn expression_generator(mut self, name: Name, expr: Expression) -> Self {
231        assert!(!self.quantified_variables.contains(&name));
232
233        self.quantified_variables.insert(name.clone());
234
235        // insert into comprehension scope as a local quantified variable
236        let quantified_decl = DeclarationPtr::new_quantified_expr(name, expr);
237        self.symbols.write().insert(quantified_decl.clone());
238
239        self.qualifiers
240            .push(ComprehensionQualifier::ExpressionGenerator {
241                ptr: quantified_decl,
242            });
243
244        self
245    }
246
247    /// Creates a comprehension with the given return expression.
248    ///
249    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
250    /// in the `comprehension_kind` field.
251    ///
252    /// If a comprehension kind is not given, comprehension guards containing non-quantified
253    /// decision variables are invalid, and will cause a panic.
254    pub fn with_return_value(
255        self,
256        mut expression: Expression,
257        comprehension_kind: Option<ACOperatorKind>,
258    ) -> Comprehension {
259        let quantified_variables = self.quantified_variables;
260        let symbols = self.symbols.read();
261
262        let mut qualifiers = Vec::new();
263        let mut other_guards = Vec::new();
264
265        for qualifier in self.qualifiers {
266            match qualifier {
267                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
268                ComprehensionQualifier::ExpressionGenerator { .. } => qualifiers.push(qualifier),
269                ComprehensionQualifier::Condition(condition) => {
270                    if is_quantified_guard(&symbols, &quantified_variables, &condition) {
271                        qualifiers.push(ComprehensionQualifier::Condition(condition));
272                    } else {
273                        other_guards.push(condition);
274                    }
275                }
276            }
277        }
278        drop(symbols);
279
280        // handle guards that reference non-quantified decision variables
281        if !other_guards.is_empty() {
282            let comprehension_kind = comprehension_kind.expect(
283                "if any guards reference decision variables, a comprehension kind should be given",
284            );
285
286            let guard_expr = match other_guards.as_slice() {
287                [x] => x.clone(),
288                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
289            };
290
291            expression = match comprehension_kind {
292                ACOperatorKind::And => {
293                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
294                }
295                ACOperatorKind::Or => Expression::And(
296                    Metadata::new(),
297                    Moo::new(matrix_expr![guard_expr, expression]),
298                ),
299
300                ACOperatorKind::Sum => {
301                    panic!("guards that reference decision variables not yet implemented for sum");
302                }
303
304                ACOperatorKind::Product => {
305                    panic!(
306                        "guards that reference decision variables not yet implemented for product"
307                    );
308                }
309            }
310        }
311
312        Comprehension {
313            return_expression: expression,
314            qualifiers,
315            symbols: self.symbols,
316        }
317    }
318}
319
320/// True iff the guard does not reference non-quantified decision variables.
321fn is_quantified_guard(
322    symbols: &SymbolTable,
323    quantified_variables: &BTreeSet<Name>,
324    guard: &Expression,
325) -> bool {
326    guard.universe_bi().iter().all(|name| {
327        quantified_variables.contains(name)
328            || symbols
329                .lookup(name)
330                .is_some_and(|decl| decl.category_of() != Category::Decision)
331    })
332}