conjure_cp_core/ast/
comprehension.rs

1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{cell::RefCell, collections::BTreeSet, fmt::Display, rc::Rc, sync::atomic::AtomicBool};
4
5use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6use conjure_cp_core::ast::ReturnType;
7use itertools::Itertools as _;
8use serde::{Deserialize, Serialize};
9use uniplate::{Biplate, Uniplate};
10
11use super::{
12    DeclarationPtr, Domain, DomainPtr, Expression, Moo, Name, Range, SubModel, SymbolTable,
13    Typeable, ac_operators::ACOperatorKind,
14};
15
16// TODO: move this global setting somewhere better?
17
/// Selects which rewriter is used for comprehensions.
///
/// This is a process-wide flag. `true` selects the optimised rewriter; `false` (the initial
/// value) selects the naive rewriter.
pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);
22
// TODO: do not use Names to compare variables; use DeclarationPtr and ids instead.
// See issue #930.
//
// This will simplify *a lot* of the gnarly code here, but can only be done once everything else
// uses DeclarationPtr.
//
// ~ nikdewally, 10/06/25
30
/// A comprehension, printed as `[return_expr | generators, guards]`.
///
/// It is stored as two submodels: one holding the generators and the guards, and one holding the
/// return expression.
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=SubModel)]
#[biplate(to=Expression)]
#[non_exhaustive]
pub struct Comprehension {
    /// Submodel containing the return expression (read back via `into_single_expression`).
    /// When built by [`ComprehensionBuilder`], its symbol table declares every induction variable
    /// as a `given`.
    #[doc(hidden)]
    pub return_expression_submodel: SubModel,
    /// Submodel whose symbol table holds the generator declarations and whose constraints are
    /// the guards that reference only induction variables.
    #[doc(hidden)]
    pub generator_submodel: SubModel,
    /// Names of the induction variables introduced by the generators.
    #[doc(hidden)]
    pub induction_vars: Vec<Name>,
}
44
45impl Comprehension {
46    pub fn domain_of(&self) -> Option<DomainPtr> {
47        let return_expr_domain = self
48            .return_expression_submodel
49            .clone()
50            .into_single_expression()
51            .domain_of()?;
52
53        // return a list (matrix with index domain int(1..)) of return_expr elements
54        Some(Domain::matrix(
55            return_expr_domain,
56            vec![Domain::int(vec![Range::UnboundedR(1)])],
57        ))
58    }
59
60    pub fn return_expression(self) -> Expression {
61        self.return_expression_submodel.into_single_expression()
62    }
63
64    pub fn replace_return_expression(&mut self, new_expr: Expression) {
65        let new_expr = match new_expr {
66            Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
67                Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
68            }
69            expr => Expression::Root(Metadata::new(), vec![expr]),
70        };
71
72        *self.return_expression_submodel.root_mut_unchecked() = new_expr;
73    }
74
75    /// Adds a guard to the comprehension. Returns false if the guard does not only reference induction variables.
76    pub fn add_induction_guard(&mut self, guard: Expression) -> bool {
77        if self.is_induction_guard(&guard) {
78            self.generator_submodel.add_constraint(guard);
79            true
80        } else {
81            false
82        }
83    }
84
85    /// True iff expr only references induction variables.
86    pub fn is_induction_guard(&self, expr: &Expression) -> bool {
87        is_induction_guard(&(self.induction_vars.clone().into_iter().collect()), expr)
88    }
89}
90
91impl Typeable for Comprehension {
92    fn return_type(&self) -> ReturnType {
93        self.return_expression_submodel
94            .clone()
95            .into_single_expression()
96            .return_type()
97    }
98}
99
100impl Display for Comprehension {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        let generators: String = self
103            .generator_submodel
104            .symbols()
105            .clone()
106            .into_iter_local()
107            .map(|(name, decl): (Name, DeclarationPtr)| {
108                let domain: DomainPtr = decl.domain().unwrap();
109                (name, domain)
110            })
111            .map(|(name, domain)| format!("{name}: {domain}"))
112            .join(",");
113
114        let guards = self
115            .generator_submodel
116            .constraints()
117            .iter()
118            .map(|x| format!("{x}"))
119            .join(",");
120
121        let generators_and_guards = itertools::join([generators, guards], ",");
122
123        let expression = &self.return_expression_submodel;
124        write!(f, "[{expression} | {generators_and_guards}]")
125    }
126}
127
/// A builder for a [`Comprehension`].
///
/// Generators and guards are added incrementally, and the comprehension is then produced by
/// [`ComprehensionBuilder::with_return_value`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    // Guards in the order they were added; split into induction / non-induction guards when the
    // comprehension is built.
    guards: Vec<Expression>,
    // Symbol table holding the generator declarations. It is created as a child of the enclosing
    // scope's symbol table and is installed as the generator submodel's symbol table when the
    // comprehension is built. The original author noted this design is used during parsing and
    // is expected to be replaced.
    generator_symboltable: Rc<RefCell<SymbolTable>>,
    // Symbol table for the return expression; each induction variable is declared here as a
    // `given` with the generator's domain.
    return_expr_symboltable: Rc<RefCell<SymbolTable>>,
    // Names of all induction variables added so far (each must be unique).
    induction_variables: BTreeSet<Name>,
}
139
140impl ComprehensionBuilder {
141    pub fn new(symbol_table_ptr: Rc<RefCell<SymbolTable>>) -> Self {
142        ComprehensionBuilder {
143            guards: vec![],
144            generator_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
145                symbol_table_ptr.clone(),
146            ))),
147            return_expr_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
148                symbol_table_ptr,
149            ))),
150            induction_variables: BTreeSet::new(),
151        }
152    }
153
154    /// The symbol table for the comprehension generators
155    pub fn generator_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
156        Rc::clone(&self.generator_symboltable)
157    }
158
159    /// The symbol table for the comprehension return expression
160    pub fn return_expr_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
161        Rc::clone(&self.return_expr_symboltable)
162    }
163
164    pub fn guard(mut self, guard: Expression) -> Self {
165        self.guards.push(guard);
166        self
167    }
168
169    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
170        let name = declaration.name().clone();
171        let domain = declaration.domain().unwrap();
172        assert!(!self.induction_variables.contains(&name));
173
174        self.induction_variables.insert(name.clone());
175
176        // insert into generator symbol table as a variable
177        (*self.generator_symboltable)
178            .borrow_mut()
179            .insert(declaration);
180
181        // insert into return expression symbol table as a given
182        (*self.return_expr_symboltable)
183            .borrow_mut()
184            .insert(DeclarationPtr::new_given(name, domain));
185
186        self
187    }
188
189    /// Creates a comprehension with the given return expression.
190    ///
191    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
192    /// in the `comprehension_kind` field.
193    ///
194    /// If a comprehension kind is not given, comprehension guards containing decision variables
195    /// are invalid, and will cause a panic.
196    pub fn with_return_value(
197        self,
198        mut expression: Expression,
199        comprehension_kind: Option<ACOperatorKind>,
200    ) -> Comprehension {
201        let parent_symboltable = self
202            .generator_symboltable
203            .as_ref()
204            .borrow_mut()
205            .parent_mut_unchecked()
206            .clone()
207            .unwrap();
208        let mut generator_submodel = SubModel::new(parent_symboltable.clone());
209        let mut return_expression_submodel = SubModel::new(parent_symboltable);
210
211        *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
212        *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
213
214        // TODO:also allow guards that reference lettings and givens.
215
216        let induction_variables = self.induction_variables;
217
218        // only guards referencing induction variables can go inside the comprehension
219        let (mut induction_guards, mut other_guards): (Vec<_>, Vec<_>) = self
220            .guards
221            .into_iter()
222            .partition(|x| is_induction_guard(&induction_variables, x));
223
224        let induction_variables_2 = induction_variables.clone();
225        let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
226
227        // fix induction guard pointers so that they all point to variables in the generator model
228        induction_guards =
229            Biplate::<DeclarationPtr>::transform_bi(&induction_guards, &move |decl| {
230                if induction_variables_2.contains(&decl.name()) {
231                    (*generator_symboltable_ptr)
232                        .borrow()
233                        .lookup_local(&decl.name())
234                        .unwrap()
235                } else {
236                    decl
237                }
238            })
239            .into_iter()
240            .collect_vec();
241
242        let induction_variables_2 = induction_variables.clone();
243        let return_expr_symboltable_ptr =
244            return_expression_submodel.symbols_ptr_unchecked().clone();
245
246        // fix other guard pointers so that they all point to variables in the return expr model
247        other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
248            if induction_variables_2.contains(&decl.name()) {
249                (*return_expr_symboltable_ptr)
250                    .borrow()
251                    .lookup_local(&decl.name())
252                    .unwrap()
253            } else {
254                decl
255            }
256        })
257        .into_iter()
258        .collect_vec();
259
260        // handle guards that reference non-induction variables
261        if !other_guards.is_empty() {
262            let comprehension_kind = comprehension_kind.expect(
263                "if any guards reference decision variables, a comprehension kind should be given",
264            );
265
266            let guard_expr = match other_guards.as_slice() {
267                [x] => x.clone(),
268                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
269            };
270
271            expression = match comprehension_kind {
272                ACOperatorKind::And => {
273                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
274                }
275                ACOperatorKind::Or => Expression::And(
276                    Metadata::new(),
277                    Moo::new(Expression::And(
278                        Metadata::new(),
279                        Moo::new(matrix_expr![guard_expr, expression]),
280                    )),
281                ),
282
283                ACOperatorKind::Sum => {
284                    panic!("guards that reference decision variables not yet implemented for sum");
285                }
286
287                ACOperatorKind::Product => {
288                    panic!(
289                        "guards that reference decision variables not yet implemented for product"
290                    );
291                }
292            }
293        }
294
295        generator_submodel.add_constraints(induction_guards);
296
297        return_expression_submodel.add_constraint(expression);
298
299        Comprehension {
300            return_expression_submodel,
301            generator_submodel,
302            induction_vars: induction_variables.into_iter().collect_vec(),
303        }
304    }
305}
306
307/// True iff the guard only references induction variables.
308fn is_induction_guard(induction_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
309    guard
310        .universe_bi()
311        .iter()
312        .all(|x| induction_variables.contains(x))
313}