1
// NOTE(review): no `Arc` is constructed directly in the code visible in this file; this allow is
// presumably needed for code generated by the derive macros used below — confirm, and narrow or
// remove it if it is no longer required.
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{cell::RefCell, collections::BTreeSet, fmt::Display, rc::Rc, sync::atomic::AtomicBool};
4

            
5
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
6
use conjure_cp_core::ast::ReturnType;
7
use itertools::Itertools as _;
8
use serde::{Deserialize, Serialize};
9
use uniplate::{Biplate, Uniplate};
10

            
11
use super::{
12
    DeclarationPtr, Domain, DomainPtr, Expression, Moo, Name, Range, SubModel, SymbolTable,
13
    Typeable, ac_operators::ACOperatorKind,
14
};
15

            
16
// TODO: move this global setting somewhere better?
17

            
18
/// Global switch selecting which rewriter handles comprehensions.
///
/// `true` selects the optimised rewriter; `false` (the initial value) selects the naive one.
pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);
22

            
23
// TODO: do not use Names to compare variables, use DeclarationPtr and ids instead
24
// see issue #930
25
//
26
// this will simplify *a lot* of the gnarly stuff here, but can only be done once everything else
27
// uses DeclarationPtr.
28
//
29
// ~ nikdewally, 10/06/25
30

            
31
/// A comprehension.
32
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
33
#[biplate(to=SubModel)]
34
#[biplate(to=Expression)]
35
#[non_exhaustive]
36
pub struct Comprehension {
37
    #[doc(hidden)]
38
    pub return_expression_submodel: SubModel,
39
    #[doc(hidden)]
40
    pub generator_submodel: SubModel,
41
    #[doc(hidden)]
42
    pub induction_vars: Vec<Name>,
43
}
44

            
45
impl Comprehension {
46
    pub fn domain_of(&self) -> Option<DomainPtr> {
47
        let return_expr_domain = self
48
            .return_expression_submodel
49
            .clone()
50
            .into_single_expression()
51
            .domain_of()?;
52

            
53
        // return a list (matrix with index domain int(1..)) of return_expr elements
54
        Some(Domain::matrix(
55
            return_expr_domain,
56
            vec![Domain::int(vec![Range::UnboundedR(1)])],
57
        ))
58
    }
59

            
60
    pub fn return_expression(self) -> Expression {
61
        self.return_expression_submodel.into_single_expression()
62
    }
63

            
64
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
65
        let new_expr = match new_expr {
66
            Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
67
                Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
68
            }
69
            expr => Expression::Root(Metadata::new(), vec![expr]),
70
        };
71

            
72
        *self.return_expression_submodel.root_mut_unchecked() = new_expr;
73
    }
74

            
75
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference induction variables.
76
    pub fn add_induction_guard(&mut self, guard: Expression) -> bool {
77
        if self.is_induction_guard(&guard) {
78
            self.generator_submodel.add_constraint(guard);
79
            true
80
        } else {
81
            false
82
        }
83
    }
84

            
85
    /// True iff expr only references induction variables.
86
    pub fn is_induction_guard(&self, expr: &Expression) -> bool {
87
        is_induction_guard(&(self.induction_vars.clone().into_iter().collect()), expr)
88
    }
89
}
90

            
91
impl Typeable for Comprehension {
92
    fn return_type(&self) -> ReturnType {
93
        self.return_expression_submodel
94
            .clone()
95
            .into_single_expression()
96
            .return_type()
97
    }
98
}
99

            
100
impl Display for Comprehension {
101
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102
        let generators: String = self
103
            .generator_submodel
104
            .symbols()
105
            .clone()
106
            .into_iter_local()
107
            .map(|(name, decl): (Name, DeclarationPtr)| {
108
                let domain: DomainPtr = decl.domain().unwrap();
109
                (name, domain)
110
            })
111
            .map(|(name, domain)| format!("{name}: {domain}"))
112
            .join(",");
113

            
114
        let guards = self
115
            .generator_submodel
116
            .constraints()
117
            .iter()
118
            .map(|x| format!("{x}"))
119
            .join(",");
120

            
121
        let generators_and_guards = itertools::join([generators, guards], ",");
122

            
123
        let expression = &self.return_expression_submodel;
124
        write!(f, "[{expression} | {generators_and_guards}]")
125
    }
126
}
127

            
128
/// A builder for a comprehension.
129
#[derive(Clone, Debug, PartialEq, Eq)]
130
pub struct ComprehensionBuilder {
131
    guards: Vec<Expression>,
132
    // symbol table containing all the generators
133
    // for now, this is just used during parsing - a new symbol table is created using this when we initialise the comprehension
134
    // this is not ideal, but i am chucking all this code very soon anyways...
135
    generator_symboltable: Rc<RefCell<SymbolTable>>,
136
    return_expr_symboltable: Rc<RefCell<SymbolTable>>,
137
    induction_variables: BTreeSet<Name>,
138
}
139

            
140
impl ComprehensionBuilder {
141
    pub fn new(symbol_table_ptr: Rc<RefCell<SymbolTable>>) -> Self {
142
        ComprehensionBuilder {
143
            guards: vec![],
144
            generator_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
145
                symbol_table_ptr.clone(),
146
            ))),
147
            return_expr_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
148
                symbol_table_ptr,
149
            ))),
150
            induction_variables: BTreeSet::new(),
151
        }
152
    }
153

            
154
    /// The symbol table for the comprehension generators
155
    pub fn generator_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
156
        Rc::clone(&self.generator_symboltable)
157
    }
158

            
159
    /// The symbol table for the comprehension return expression
160
    pub fn return_expr_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
161
        Rc::clone(&self.return_expr_symboltable)
162
    }
163

            
164
    pub fn guard(mut self, guard: Expression) -> Self {
165
        self.guards.push(guard);
166
        self
167
    }
168

            
169
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
170
        let name = declaration.name().clone();
171
        let domain = declaration.domain().unwrap();
172
        assert!(!self.induction_variables.contains(&name));
173

            
174
        self.induction_variables.insert(name.clone());
175

            
176
        // insert into generator symbol table as a variable
177
        (*self.generator_symboltable)
178
            .borrow_mut()
179
            .insert(declaration);
180

            
181
        // insert into return expression symbol table as a given
182
        (*self.return_expr_symboltable)
183
            .borrow_mut()
184
            .insert(DeclarationPtr::new_given(name, domain));
185

            
186
        self
187
    }
188

            
189
    /// Creates a comprehension with the given return expression.
190
    ///
191
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
192
    /// in the `comprehension_kind` field.
193
    ///
194
    /// If a comprehension kind is not given, comprehension guards containing decision variables
195
    /// are invalid, and will cause a panic.
196
    pub fn with_return_value(
197
        self,
198
        mut expression: Expression,
199
        comprehension_kind: Option<ACOperatorKind>,
200
    ) -> Comprehension {
201
        let parent_symboltable = self
202
            .generator_symboltable
203
            .as_ref()
204
            .borrow_mut()
205
            .parent_mut_unchecked()
206
            .clone()
207
            .unwrap();
208
        let mut generator_submodel = SubModel::new(parent_symboltable.clone());
209
        let mut return_expression_submodel = SubModel::new(parent_symboltable);
210

            
211
        *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
212
        *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
213

            
214
        // TODO:also allow guards that reference lettings and givens.
215

            
216
        let induction_variables = self.induction_variables;
217

            
218
        // only guards referencing induction variables can go inside the comprehension
219
        let (mut induction_guards, mut other_guards): (Vec<_>, Vec<_>) = self
220
            .guards
221
            .into_iter()
222
            .partition(|x| is_induction_guard(&induction_variables, x));
223

            
224
        let induction_variables_2 = induction_variables.clone();
225
        let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
226

            
227
        // fix induction guard pointers so that they all point to variables in the generator model
228
        induction_guards =
229
            Biplate::<DeclarationPtr>::transform_bi(&induction_guards, &move |decl| {
230
                if induction_variables_2.contains(&decl.name()) {
231
                    (*generator_symboltable_ptr)
232
                        .borrow()
233
                        .lookup_local(&decl.name())
234
                        .unwrap()
235
                } else {
236
                    decl
237
                }
238
            })
239
            .into_iter()
240
            .collect_vec();
241

            
242
        let induction_variables_2 = induction_variables.clone();
243
        let return_expr_symboltable_ptr =
244
            return_expression_submodel.symbols_ptr_unchecked().clone();
245

            
246
        // fix other guard pointers so that they all point to variables in the return expr model
247
        other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
248
            if induction_variables_2.contains(&decl.name()) {
249
                (*return_expr_symboltable_ptr)
250
                    .borrow()
251
                    .lookup_local(&decl.name())
252
                    .unwrap()
253
            } else {
254
                decl
255
            }
256
        })
257
        .into_iter()
258
        .collect_vec();
259

            
260
        // handle guards that reference non-induction variables
261
        if !other_guards.is_empty() {
262
            let comprehension_kind = comprehension_kind.expect(
263
                "if any guards reference decision variables, a comprehension kind should be given",
264
            );
265

            
266
            let guard_expr = match other_guards.as_slice() {
267
                [x] => x.clone(),
268
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
269
            };
270

            
271
            expression = match comprehension_kind {
272
                ACOperatorKind::And => {
273
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
274
                }
275
                ACOperatorKind::Or => Expression::And(
276
                    Metadata::new(),
277
                    Moo::new(Expression::And(
278
                        Metadata::new(),
279
                        Moo::new(matrix_expr![guard_expr, expression]),
280
                    )),
281
                ),
282

            
283
                ACOperatorKind::Sum => {
284
                    panic!("guards that reference decision variables not yet implemented for sum");
285
                }
286

            
287
                ACOperatorKind::Product => {
288
                    panic!(
289
                        "guards that reference decision variables not yet implemented for product"
290
                    );
291
                }
292
            }
293
        }
294

            
295
        generator_submodel.add_constraints(induction_guards);
296

            
297
        return_expression_submodel.add_constraint(expression);
298

            
299
        Comprehension {
300
            return_expression_submodel,
301
            generator_submodel,
302
            induction_vars: induction_variables.into_iter().collect_vec(),
303
        }
304
    }
305
}
306

            
307
/// True iff the guard only references induction variables.
308
fn is_induction_guard(induction_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
309
    guard
310
        .universe_bi()
311
        .iter()
312
        .all(|x| induction_variables.contains(x))
313
}