Skip to main content

conjure_cp_core/ast/
comprehension.rs

1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{
4    collections::BTreeSet,
5    fmt::Display,
6    sync::atomic::{AtomicBool, AtomicU8, Ordering},
7};
8
9use crate::settings::QuantifiedExpander;
10use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
11use conjure_cp_core::ast::ReturnType;
12use itertools::Itertools as _;
13use serde::{Deserialize, Serialize};
14use uniplate::{Biplate, Uniplate};
15
16use super::{
17    DeclarationPtr, Domain, DomainPtr, Expression, Moo, Name, Range, SubModel, SymbolTablePtr,
18    Typeable, ac_operators::ACOperatorKind,
19};
20
21// TODO: move this global setting somewhere better?
22
23/// The rewriter to use for rewriting comprehensions.
24///
25/// True for optimised, false for naive
26pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);
27
28/// Global setting for which comprehension quantified-variable expander to use.
29///
30/// Defaults to [`QuantifiedExpander::Native`].
31pub static QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS: AtomicU8 =
32    AtomicU8::new(QuantifiedExpander::Native.as_u8());
33
34pub fn set_quantified_expander_for_comprehensions(expander: QuantifiedExpander) {
35    QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS.store(expander.as_u8(), Ordering::Relaxed);
36}
37
38pub fn quantified_expander_for_comprehensions() -> QuantifiedExpander {
39    QuantifiedExpander::from_u8(QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS.load(Ordering::Relaxed))
40}
41
42// TODO: do not use Names to compare variables, use DeclarationPtr and ids instead
43// see issue #930
44//
45// this will simplify *a lot* of the gnarly stuff here, but can only be done once everything else
46// uses DeclarationPtr.
47//
48// ~ nikdewally, 10/06/25
49
50/// A comprehension: a return expression, a list of quantified variables (the generators), and
/// guards over those quantified variables.
///
/// It is displayed as `[ return_expr | name : domain, ..., guard, ... ]`.
51#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
52#[biplate(to=SubModel)]
53#[biplate(to=Expression)]
54#[non_exhaustive]
55pub struct Comprehension {
56    #[doc(hidden)]
    /// Submodel holding the return expression, read back with `into_single_expression`.
    /// When built with `ComprehensionBuilder`, its symbol table declares the quantified
    /// variables for use inside the return expression.
57    pub return_expression_submodel: SubModel,
58    #[doc(hidden)]
    /// Submodel whose symbol table declares the quantified variables (with their domains) and
    /// whose constraints are the guards. Guards added via `add_quantified_guard` or
    /// `ComprehensionBuilder` only reference quantified variables.
59    pub generator_submodel: SubModel,
60    #[doc(hidden)]
    /// Names of the quantified variables, in the order they are displayed.
61    pub quantified_vars: Vec<Name>,
62}
63
64impl Comprehension {
65    pub fn domain_of(&self) -> Option<DomainPtr> {
66        let return_expr_domain = self
67            .return_expression_submodel
68            .clone()
69            .into_single_expression()
70            .domain_of()?;
71
72        // return a list (matrix with index domain int(1..)) of return_expr elements
73        Some(Domain::matrix(
74            return_expr_domain,
75            vec![Domain::int(vec![Range::UnboundedR(1)])],
76        ))
77    }
78
79    pub fn return_expression(self) -> Expression {
80        self.return_expression_submodel.into_single_expression()
81    }
82
83    pub fn replace_return_expression(&mut self, new_expr: Expression) {
84        let new_expr = match new_expr {
85            Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
86                Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
87            }
88            expr => Expression::Root(Metadata::new(), vec![expr]),
89        };
90
91        *self.return_expression_submodel.root_mut_unchecked() = new_expr;
92    }
93
94    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
95    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
96        if self.is_quantified_guard(&guard) {
97            self.generator_submodel.add_constraint(guard);
98            true
99        } else {
100            false
101        }
102    }
103
104    /// True iff expr only references quantified variables.
105    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
106        is_quantified_guard(&(self.quantified_vars.clone().into_iter().collect()), expr)
107    }
108}
109
110impl Typeable for Comprehension {
111    fn return_type(&self) -> ReturnType {
112        self.return_expression_submodel
113            .clone()
114            .into_single_expression()
115            .return_type()
116    }
117}
118
119impl Display for Comprehension {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        let return_expression = self
122            .return_expression_submodel
123            .clone()
124            .into_single_expression();
125
126        let generator_symbols = self.generator_submodel.symbols().clone();
127        let generators = self
128            .quantified_vars
129            .iter()
130            .map(|name| {
131                let decl: DeclarationPtr = generator_symbols
132                    .lookup_local(name)
133                    .expect("quantified variable should be in the generator symbol table");
134                let domain: DomainPtr = decl.domain().unwrap();
135                format!("{name} : {domain}")
136            })
137            .collect_vec();
138
139        let guards = self
140            .generator_submodel
141            .constraints()
142            .iter()
143            .map(|x| format!("{x}"))
144            .collect_vec();
145
146        let generators_and_guards = generators.into_iter().chain(guards).join(", ");
147
148        write!(f, "[ {return_expression} | {generators_and_guards} ]")
149    }
150}
151
152/// A builder for a [`Comprehension`].
///
/// Generators are added with `generator`, guards with `guard`, and the comprehension is
/// finished with `with_return_value`.
153#[derive(Clone, Debug, PartialEq, Eq)]
154pub struct ComprehensionBuilder {
    // Guards added with `guard`, in insertion order. They are split into quantified and
    // other guards when the comprehension is built.
155    guards: Vec<Expression>,
156    // Symbol table containing all the generators, as local quantified declarations.
157    // For now, this is just used during parsing. When the comprehension is built, this table
    // itself becomes the symbol table of the generator submodel.
158    // This is not ideal, but this code is expected to be replaced soon.
159    generator_symboltable: SymbolTablePtr,
    // Symbol table for the return expression. It holds one quantified declaration per
    // generator, derived from that generator's declaration.
160    return_expr_symboltable: SymbolTablePtr,
    // Names of the generators added so far. Used to reject duplicate generators and to decide
    // which guards only reference quantified variables.
161    quantified_variables: BTreeSet<Name>,
162}
163
164impl ComprehensionBuilder {
165    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
166        ComprehensionBuilder {
167            guards: vec![],
168            generator_symboltable: SymbolTablePtr::with_parent(symbol_table_ptr.clone()),
169            return_expr_symboltable: SymbolTablePtr::with_parent(symbol_table_ptr),
170            quantified_variables: BTreeSet::new(),
171        }
172    }
173
174    /// The symbol table for the comprehension generators
175    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
176        self.generator_symboltable.clone()
177    }
178
179    /// The symbol table for the comprehension return expression
180    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
181        self.return_expr_symboltable.clone()
182    }
183
184    pub fn guard(mut self, guard: Expression) -> Self {
185        self.guards.push(guard);
186        self
187    }
188
189    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
190        let name = declaration.name().clone();
191        let domain = declaration.domain().unwrap();
192        assert!(!self.quantified_variables.contains(&name));
193
194        self.quantified_variables.insert(name.clone());
195
196        // insert into generator symbol table as a local quantified variable
197        let quantified_decl = DeclarationPtr::new_quantified(name, domain);
198        self.generator_symboltable
199            .write()
200            .insert(quantified_decl.clone());
201
202        // insert into return expression symbol table as a quantified variable
203        self.return_expr_symboltable.write().insert(
204            DeclarationPtr::new_quantified_from_generator(&quantified_decl)
205                .expect("quantified variables should always have a domain"),
206        );
207
208        self
209    }
210
211    /// Creates a comprehension with the given return expression.
212    ///
213    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
214    /// in the `comprehension_kind` field.
215    ///
216    /// If a comprehension kind is not given, comprehension guards containing decision variables
217    /// are invalid, and will cause a panic.
218    pub fn with_return_value(
219        self,
220        mut expression: Expression,
221        comprehension_kind: Option<ACOperatorKind>,
222    ) -> Comprehension {
223        let parent_symboltable = self.generator_symboltable.read().parent().clone().unwrap();
224
225        let mut generator_submodel = SubModel::new(parent_symboltable.clone());
226        let mut return_expression_submodel = SubModel::new(parent_symboltable);
227
228        *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
229        *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
230
231        // TODO:also allow guards that reference lettings and givens.
232
233        let quantified_variables = self.quantified_variables;
234
235        // only guards referencing quantified variables can go inside the comprehension
236        let (mut quantified_guards, mut other_guards): (Vec<_>, Vec<_>) = self
237            .guards
238            .into_iter()
239            .partition(|x| is_quantified_guard(&quantified_variables, x));
240
241        let quantified_variables_2 = quantified_variables.clone();
242        let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
243
244        // fix quantified guard pointers so that they all point to variables in the generator model
245        quantified_guards =
246            Biplate::<DeclarationPtr>::transform_bi(&quantified_guards, &move |decl| {
247                if quantified_variables_2.contains(&decl.name()) {
248                    generator_symboltable_ptr
249                        .read()
250                        .lookup_local(&decl.name())
251                        .unwrap()
252                } else {
253                    decl
254                }
255            })
256            .into_iter()
257            .collect_vec();
258
259        let quantified_variables_2 = quantified_variables.clone();
260        let return_expr_symboltable_ptr =
261            return_expression_submodel.symbols_ptr_unchecked().clone();
262
263        // fix other guard pointers so that they all point to variables in the return expr model
264        other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
265            if quantified_variables_2.contains(&decl.name()) {
266                return_expr_symboltable_ptr
267                    .read()
268                    .lookup_local(&decl.name())
269                    .unwrap()
270            } else {
271                decl
272            }
273        })
274        .into_iter()
275        .collect_vec();
276
277        // handle guards that reference non-quantified variables
278        if !other_guards.is_empty() {
279            let comprehension_kind = comprehension_kind.expect(
280                "if any guards reference decision variables, a comprehension kind should be given",
281            );
282
283            let guard_expr = match other_guards.as_slice() {
284                [x] => x.clone(),
285                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
286            };
287
288            expression = match comprehension_kind {
289                ACOperatorKind::And => {
290                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
291                }
292                ACOperatorKind::Or => Expression::And(
293                    Metadata::new(),
294                    Moo::new(Expression::And(
295                        Metadata::new(),
296                        Moo::new(matrix_expr![guard_expr, expression]),
297                    )),
298                ),
299
300                ACOperatorKind::Sum => {
301                    panic!("guards that reference decision variables not yet implemented for sum");
302                }
303
304                ACOperatorKind::Product => {
305                    panic!(
306                        "guards that reference decision variables not yet implemented for product"
307                    );
308                }
309            }
310        }
311
312        generator_submodel.add_constraints(quantified_guards);
313
314        return_expression_submodel.add_constraint(expression);
315
316        Comprehension {
317            return_expression_submodel,
318            generator_submodel,
319            quantified_vars: quantified_variables.into_iter().collect_vec(),
320        }
321    }
322}
323
324/// True iff the guard only references quantified variables.
325fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
326    guard
327        .universe_bi()
328        .iter()
329        .all(|x| quantified_variables.contains(x))
330}