Skip to main content

conjure_cp_core/ast/
comprehension.rs

1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{collections::BTreeSet, fmt::Display};
4
5use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6use conjure_cp_core::ast::ReturnType;
7use itertools::Itertools as _;
8use parking_lot::RwLockReadGuard;
9use serde::{Deserialize, Serialize};
10use serde_with::serde_as;
11use uniplate::{Biplate, Uniplate};
12
13use super::{
14    DeclarationPtr, Domain, DomainPtr, Expression, Model, Moo, Name, Range, SymbolTable,
15    SymbolTablePtr, Typeable,
16    ac_operators::ACOperatorKind,
17    serde::{AsId, PtrAsInner},
18};
19
/// One qualifier of a [`Comprehension`]: either a generator that introduces a quantified
/// variable, or a condition (guard) expression.
///
/// Qualifiers are kept in the order they were added to the comprehension.
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Uniplate)]
#[biplate(to=Expression)]
#[biplate(to=Name)]
#[biplate(to=DeclarationPtr)]
pub enum ComprehensionQualifier {
    /// A generator; `ptr` is the declaration of the quantified variable it introduces.
    Generator {
        // Serialized through `AsId` (presumably as a declaration id rather than inline) —
        // see the `serde` module.
        #[serde_as(as = "AsId")]
        ptr: DeclarationPtr,
    },
    /// A guard condition.
    Condition(Expression),
}
32
/// A comprehension.
///
/// Evaluates `return_expression` for the assignments described by `qualifiers`, producing a
/// list (see [`Comprehension::domain_of`]). Build one with [`ComprehensionBuilder`].
#[serde_as]
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=Expression)]
#[biplate(to=SymbolTable)]
#[biplate(to=SymbolTablePtr)]
#[non_exhaustive]
pub struct Comprehension {
    /// The expression produced by the comprehension.
    pub return_expression: Expression,
    /// Generators and guard conditions, in order.
    pub qualifiers: Vec<ComprehensionQualifier>,
    // The comprehension's own scope; `ComprehensionBuilder::generator` inserts the quantified
    // declarations here. Serialized through `PtrAsInner` (presumably the table contents rather
    // than the pointer identity).
    #[doc(hidden)]
    #[serde_as(as = "PtrAsInner")]
    pub symbols: SymbolTablePtr,
}
47
48impl Comprehension {
49    pub fn domain_of(&self) -> Option<DomainPtr> {
50        let return_expr_domain = self.return_expression.domain_of()?;
51
52        // return a list (matrix with index domain int(1..)) of return_expr elements
53        Some(Domain::matrix(
54            return_expr_domain,
55            vec![Domain::int(vec![Range::UnboundedR(1)])],
56        ))
57    }
58
59    pub fn return_expression(self) -> Expression {
60        self.return_expression
61    }
62
63    pub fn replace_return_expression(&mut self, new_expr: Expression) {
64        self.return_expression = new_expr;
65    }
66
67    pub fn symbols(&self) -> RwLockReadGuard<'_, SymbolTable> {
68        self.symbols.read()
69    }
70
71    pub fn quantified_vars(&self) -> Vec<Name> {
72        self.qualifiers
73            .iter()
74            .filter_map(|q| match q {
75                ComprehensionQualifier::Generator { ptr } => Some(ptr.name().clone()),
76                ComprehensionQualifier::Condition(_) => None,
77            })
78            .collect()
79    }
80
81    pub fn generator_conditions(&self) -> Vec<Expression> {
82        self.qualifiers
83            .iter()
84            .filter_map(|q| match q {
85                ComprehensionQualifier::Condition(c) => Some(c.clone()),
86                ComprehensionQualifier::Generator { .. } => None,
87            })
88            .collect()
89    }
90
91    /// Builds a temporary model containing generator qualifiers and guards.
92    pub fn to_generator_model(&self) -> Model {
93        let mut model = self.empty_model_with_symbols();
94        model.add_constraints(self.generator_conditions());
95        model
96    }
97
98    /// Builds a temporary model containing the return expression only.
99    pub fn to_return_expression_model(&self) -> Model {
100        let mut model = self.empty_model_with_symbols();
101        model.add_constraint(self.return_expression.clone());
102        model
103    }
104
105    fn empty_model_with_symbols(&self) -> Model {
106        let parent = self.symbols.read().parent().clone();
107        let mut model = if let Some(parent) = parent {
108            Model::new_in_parent_scope(parent)
109        } else {
110            Model::default()
111        };
112        *model.symbols_ptr_unchecked_mut() = self.symbols.clone();
113        model
114    }
115
116    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
117    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
118        if self.is_quantified_guard(&guard) {
119            self.qualifiers
120                .push(ComprehensionQualifier::Condition(guard));
121            true
122        } else {
123            false
124        }
125    }
126
127    /// True iff expr only references quantified variables.
128    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
129        let quantified: BTreeSet<Name> = self.quantified_vars().into_iter().collect();
130        is_quantified_guard(&quantified, expr)
131    }
132}
133
134impl Typeable for Comprehension {
135    fn return_type(&self) -> ReturnType {
136        self.return_expression.return_type()
137    }
138}
139
140impl Display for Comprehension {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        let generators_and_guards = self
143            .qualifiers
144            .iter()
145            .map(|qualifier| match qualifier {
146                ComprehensionQualifier::Generator { ptr } => {
147                    let domain = ptr.domain().expect("generator declaration has domain");
148                    format!("{} : {domain}", ptr.name())
149                }
150                ComprehensionQualifier::Condition(expr) => format!("{expr}"),
151            })
152            .join(", ");
153
154        write!(
155            f,
156            "[ {} | {generators_and_guards} ]",
157            self.return_expression
158        )
159    }
160}
161
/// A builder for a comprehension.
///
/// Collects generators and guards, then produces a [`Comprehension`] through
/// [`ComprehensionBuilder::with_return_value`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    // Generators and guards, in the order they were added.
    qualifiers: Vec<ComprehensionQualifier>,
    // A single scope for generators and return expression.
    symbols: SymbolTablePtr,
    // Names introduced by generators so far; used to decide which guards may stay qualifiers.
    quantified_variables: BTreeSet<Name>,
}
170
171impl ComprehensionBuilder {
172    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
173        ComprehensionBuilder {
174            qualifiers: vec![],
175            symbols: SymbolTablePtr::with_parent(symbol_table_ptr),
176            quantified_variables: BTreeSet::new(),
177        }
178    }
179
180    /// Backwards-compatible parser API: same table for generators and return expression.
181    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
182        self.symbols.clone()
183    }
184
185    /// Backwards-compatible parser API: same table for generators and return expression.
186    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
187        self.symbols.clone()
188    }
189
190    pub fn guard(mut self, guard: Expression) -> Self {
191        self.qualifiers
192            .push(ComprehensionQualifier::Condition(guard));
193        self
194    }
195
196    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
197        let name = declaration.name().clone();
198        assert!(!self.quantified_variables.contains(&name));
199
200        self.quantified_variables.insert(name.clone());
201
202        // insert into comprehension scope as a local quantified variable
203        let quantified_decl = DeclarationPtr::new_quantified(name, declaration.domain().unwrap());
204        self.symbols.write().insert(quantified_decl.clone());
205
206        self.qualifiers.push(ComprehensionQualifier::Generator {
207            ptr: quantified_decl,
208        });
209
210        self
211    }
212
213    /// Creates a comprehension with the given return expression.
214    ///
215    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
216    /// in the `comprehension_kind` field.
217    ///
218    /// If a comprehension kind is not given, comprehension guards containing decision variables
219    /// are invalid, and will cause a panic.
220    pub fn with_return_value(
221        self,
222        mut expression: Expression,
223        comprehension_kind: Option<ACOperatorKind>,
224    ) -> Comprehension {
225        let quantified_variables = self.quantified_variables;
226
227        let mut qualifiers = Vec::new();
228        let mut other_guards = Vec::new();
229
230        for qualifier in self.qualifiers {
231            match qualifier {
232                ComprehensionQualifier::Generator { .. } => qualifiers.push(qualifier),
233                ComprehensionQualifier::Condition(condition) => {
234                    if is_quantified_guard(&quantified_variables, &condition) {
235                        qualifiers.push(ComprehensionQualifier::Condition(condition));
236                    } else {
237                        other_guards.push(condition);
238                    }
239                }
240            }
241        }
242
243        // handle guards that reference non-quantified variables
244        if !other_guards.is_empty() {
245            let comprehension_kind = comprehension_kind.expect(
246                "if any guards reference decision variables, a comprehension kind should be given",
247            );
248
249            let guard_expr = match other_guards.as_slice() {
250                [x] => x.clone(),
251                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
252            };
253
254            expression = match comprehension_kind {
255                ACOperatorKind::And => {
256                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
257                }
258                ACOperatorKind::Or => Expression::And(
259                    Metadata::new(),
260                    Moo::new(matrix_expr![guard_expr, expression]),
261                ),
262
263                ACOperatorKind::Sum => {
264                    panic!("guards that reference decision variables not yet implemented for sum");
265                }
266
267                ACOperatorKind::Product => {
268                    panic!(
269                        "guards that reference decision variables not yet implemented for product"
270                    );
271                }
272            }
273        }
274
275        Comprehension {
276            return_expression: expression,
277            qualifiers,
278            symbols: self.symbols,
279        }
280    }
281}
282
283/// True iff the guard only references quantified variables.
284fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
285    guard
286        .universe_bi()
287        .iter()
288        .all(|x| quantified_variables.contains(x))
289}