1
#![allow(clippy::arc_with_non_send_sync)]
2

            
3
use std::{
4
    collections::BTreeSet,
5
    fmt::Display,
6
    sync::atomic::{AtomicBool, AtomicU8, Ordering},
7
};
8

            
9
use crate::settings::QuantifiedExpander;
10
use crate::{ast::Metadata, into_matrix_expr, matrix_expr};
11
use conjure_cp_core::ast::ReturnType;
12
use itertools::Itertools as _;
13
use serde::{Deserialize, Serialize};
14
use uniplate::{Biplate, Uniplate};
15

            
16
use super::{
17
    DeclarationPtr, Domain, DomainPtr, Expression, Moo, Name, Range, SubModel, SymbolTablePtr,
18
    Typeable, ac_operators::ACOperatorKind,
19
};
20

            
21
// TODO: move this global setting somewhere better?

/// The rewriter to use for rewriting comprehensions.
///
/// `true` selects the optimised rewriter, `false` the naive one. Defaults to `false` (naive).
pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);

/// Global setting for which comprehension quantified-variable expander to use.
///
/// Stored as the `u8` encoding produced by `QuantifiedExpander::as_u8` so that it can live in
/// an atomic. It is written by [`set_quantified_expander_for_comprehensions`] and decoded by
/// [`quantified_expander_for_comprehensions`].
///
/// Defaults to [`QuantifiedExpander::Native`].
pub static QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS: AtomicU8 =
    AtomicU8::new(QuantifiedExpander::Native.as_u8());
33

            
34
5590
pub fn set_quantified_expander_for_comprehensions(expander: QuantifiedExpander) {
35
5590
    QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS.store(expander.as_u8(), Ordering::Relaxed);
36
5590
}
37

            
38
570420
pub fn quantified_expander_for_comprehensions() -> QuantifiedExpander {
39
570420
    QuantifiedExpander::from_u8(QUANTIFIED_EXPANDER_FOR_COMPREHENSIONS.load(Ordering::Relaxed))
40
570420
}
41

            
42
// TODO: do not use Names to compare variables; use DeclarationPtr and ids instead.
// See issue #930.
//
// This will simplify *a lot* of the gnarly stuff here, but can only be done once everything
// else uses DeclarationPtr.
//
// ~ nikdewally, 10/06/25
49

            
50
/// A comprehension.
///
/// Displayed as `[ return_expression | generators, guards ]`. Each generator is shown as
/// `name : domain`.
#[derive(Clone, PartialEq, Eq, Hash, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=SubModel)]
#[biplate(to=Expression)]
#[non_exhaustive]
pub struct Comprehension {
    /// The return expression, stored as the constraint(s) of a submodel.
    ///
    /// When built by [`ComprehensionBuilder`], this submodel's symbol table is the builder's
    /// return-expression table, which holds a quantified declaration for each generator.
    #[doc(hidden)]
    pub return_expression_submodel: SubModel,
    /// The generators and the guards that only reference quantified variables.
    ///
    /// The quantified variables are declarations in this submodel's symbol table. The guards
    /// are its constraints.
    #[doc(hidden)]
    pub generator_submodel: SubModel,
    /// Names of the quantified variables.
    ///
    /// When built by [`ComprehensionBuilder`], these are collected from a `BTreeSet`, so they
    /// are in sorted order.
    #[doc(hidden)]
    pub quantified_vars: Vec<Name>,
}
63

            
64
impl Comprehension {
65
    pub fn domain_of(&self) -> Option<DomainPtr> {
66
        let return_expr_domain = self
67
            .return_expression_submodel
68
            .clone()
69
            .into_single_expression()
70
            .domain_of()?;
71

            
72
        // return a list (matrix with index domain int(1..)) of return_expr elements
73
        Some(Domain::matrix(
74
            return_expr_domain,
75
            vec![Domain::int(vec![Range::UnboundedR(1)])],
76
        ))
77
    }
78

            
79
2560
    pub fn return_expression(self) -> Expression {
80
2560
        self.return_expression_submodel.into_single_expression()
81
2560
    }
82

            
83
    pub fn replace_return_expression(&mut self, new_expr: Expression) {
84
        let new_expr = match new_expr {
85
            Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
86
                Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
87
            }
88
            expr => Expression::Root(Metadata::new(), vec![expr]),
89
        };
90

            
91
        *self.return_expression_submodel.root_mut_unchecked() = new_expr;
92
    }
93

            
94
    /// Adds a guard to the comprehension. Returns false if the guard does not only reference quantified variables.
95
    pub fn add_quantified_guard(&mut self, guard: Expression) -> bool {
96
        if self.is_quantified_guard(&guard) {
97
            self.generator_submodel.add_constraint(guard);
98
            true
99
        } else {
100
            false
101
        }
102
    }
103

            
104
    /// True iff expr only references quantified variables.
105
    pub fn is_quantified_guard(&self, expr: &Expression) -> bool {
106
        is_quantified_guard(&(self.quantified_vars.clone().into_iter().collect()), expr)
107
    }
108
}
109

            
110
impl Typeable for Comprehension {
111
    fn return_type(&self) -> ReturnType {
112
        self.return_expression_submodel
113
            .clone()
114
            .into_single_expression()
115
            .return_type()
116
    }
117
}
118

            
119
impl Display for Comprehension {
120
8280
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121
8280
        let return_expression = self
122
8280
            .return_expression_submodel
123
8280
            .clone()
124
8280
            .into_single_expression();
125

            
126
8280
        let generator_symbols = self.generator_submodel.symbols().clone();
127
8280
        let generators = self
128
8280
            .quantified_vars
129
8280
            .iter()
130
8600
            .map(|name| {
131
8600
                let decl: DeclarationPtr = generator_symbols
132
8600
                    .lookup_local(name)
133
8600
                    .expect("quantified variable should be in the generator symbol table");
134
8600
                let domain: DomainPtr = decl.domain().unwrap();
135
8600
                format!("{name} : {domain}")
136
8600
            })
137
8280
            .collect_vec();
138

            
139
8280
        let guards = self
140
8280
            .generator_submodel
141
8280
            .constraints()
142
8280
            .iter()
143
8280
            .map(|x| format!("{x}"))
144
8280
            .collect_vec();
145

            
146
8280
        let generators_and_guards = generators.into_iter().chain(guards).join(", ");
147

            
148
8280
        write!(f, "[ {return_expression} | {generators_and_guards} ]")
149
8280
    }
150
}
151

            
152
/// A builder for a comprehension.
///
/// Add generators with [`ComprehensionBuilder::generator`] and guards with
/// [`ComprehensionBuilder::guard`]. Then call [`ComprehensionBuilder::with_return_value`] to
/// produce the [`Comprehension`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    // guards added with `guard`, in insertion order
    guards: Vec<Expression>,
    // Symbol table containing all the generators: a child of the enclosing scope.
    // When the comprehension is built, this table is installed as the generator submodel's
    // symbol table.
    // This is not ideal, but this code is due to be replaced soon anyway.
    generator_symboltable: SymbolTablePtr,
    // Child of the enclosing scope holding a quantified declaration per generator. It becomes
    // the return-expression submodel's symbol table.
    return_expr_symboltable: SymbolTablePtr,
    // names of all generators added so far; used to reject duplicates and classify guards
    quantified_variables: BTreeSet<Name>,
}
163

            
164
impl ComprehensionBuilder {
165
820
    pub fn new(symbol_table_ptr: SymbolTablePtr) -> Self {
166
820
        ComprehensionBuilder {
167
820
            guards: vec![],
168
820
            generator_symboltable: SymbolTablePtr::with_parent(symbol_table_ptr.clone()),
169
820
            return_expr_symboltable: SymbolTablePtr::with_parent(symbol_table_ptr),
170
820
            quantified_variables: BTreeSet::new(),
171
820
        }
172
820
    }
173

            
174
    /// The symbol table for the comprehension generators
175
820
    pub fn generator_symboltable(&mut self) -> SymbolTablePtr {
176
820
        self.generator_symboltable.clone()
177
820
    }
178

            
179
    /// The symbol table for the comprehension return expression
180
820
    pub fn return_expr_symboltable(&mut self) -> SymbolTablePtr {
181
820
        self.return_expr_symboltable.clone()
182
820
    }
183

            
184
100
    pub fn guard(mut self, guard: Expression) -> Self {
185
100
        self.guards.push(guard);
186
100
        self
187
100
    }
188

            
189
880
    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
190
880
        let name = declaration.name().clone();
191
880
        let domain = declaration.domain().unwrap();
192
880
        assert!(!self.quantified_variables.contains(&name));
193

            
194
880
        self.quantified_variables.insert(name.clone());
195

            
196
        // insert into generator symbol table as a local quantified variable
197
880
        let quantified_decl = DeclarationPtr::new_quantified(name, domain);
198
880
        self.generator_symboltable
199
880
            .write()
200
880
            .insert(quantified_decl.clone());
201

            
202
        // insert into return expression symbol table as a quantified variable
203
880
        self.return_expr_symboltable.write().insert(
204
880
            DeclarationPtr::new_quantified_from_generator(&quantified_decl)
205
880
                .expect("quantified variables should always have a domain"),
206
880
        );
207

            
208
880
        self
209
880
    }
210

            
211
    /// Creates a comprehension with the given return expression.
212
    ///
213
    /// If this comprehension is inside an AC-operator, the kind of this operator should be passed
214
    /// in the `comprehension_kind` field.
215
    ///
216
    /// If a comprehension kind is not given, comprehension guards containing decision variables
217
    /// are invalid, and will cause a panic.
218
800
    pub fn with_return_value(
219
800
        self,
220
800
        mut expression: Expression,
221
800
        comprehension_kind: Option<ACOperatorKind>,
222
800
    ) -> Comprehension {
223
800
        let parent_symboltable = self.generator_symboltable.read().parent().clone().unwrap();
224

            
225
800
        let mut generator_submodel = SubModel::new(parent_symboltable.clone());
226
800
        let mut return_expression_submodel = SubModel::new(parent_symboltable);
227

            
228
800
        *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
229
800
        *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
230

            
231
        // TODO:also allow guards that reference lettings and givens.
232

            
233
800
        let quantified_variables = self.quantified_variables;
234

            
235
        // only guards referencing quantified variables can go inside the comprehension
236
800
        let (mut quantified_guards, mut other_guards): (Vec<_>, Vec<_>) = self
237
800
            .guards
238
800
            .into_iter()
239
800
            .partition(|x| is_quantified_guard(&quantified_variables, x));
240

            
241
800
        let quantified_variables_2 = quantified_variables.clone();
242
800
        let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
243

            
244
        // fix quantified guard pointers so that they all point to variables in the generator model
245
800
        quantified_guards =
246
800
            Biplate::<DeclarationPtr>::transform_bi(&quantified_guards, &move |decl| {
247
60
                if quantified_variables_2.contains(&decl.name()) {
248
60
                    generator_symboltable_ptr
249
60
                        .read()
250
60
                        .lookup_local(&decl.name())
251
60
                        .unwrap()
252
                } else {
253
                    decl
254
                }
255
60
            })
256
800
            .into_iter()
257
800
            .collect_vec();
258

            
259
800
        let quantified_variables_2 = quantified_variables.clone();
260
800
        let return_expr_symboltable_ptr =
261
800
            return_expression_submodel.symbols_ptr_unchecked().clone();
262

            
263
        // fix other guard pointers so that they all point to variables in the return expr model
264
800
        other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
265
60
            if quantified_variables_2.contains(&decl.name()) {
266
20
                return_expr_symboltable_ptr
267
20
                    .read()
268
20
                    .lookup_local(&decl.name())
269
20
                    .unwrap()
270
            } else {
271
40
                decl
272
            }
273
60
        })
274
800
        .into_iter()
275
800
        .collect_vec();
276

            
277
        // handle guards that reference non-quantified variables
278
800
        if !other_guards.is_empty() {
279
40
            let comprehension_kind = comprehension_kind.expect(
280
40
                "if any guards reference decision variables, a comprehension kind should be given",
281
            );
282

            
283
40
            let guard_expr = match other_guards.as_slice() {
284
40
                [x] => x.clone(),
285
                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
286
            };
287

            
288
40
            expression = match comprehension_kind {
289
                ACOperatorKind::And => {
290
40
                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
291
                }
292
                ACOperatorKind::Or => Expression::And(
293
                    Metadata::new(),
294
                    Moo::new(Expression::And(
295
                        Metadata::new(),
296
                        Moo::new(matrix_expr![guard_expr, expression]),
297
                    )),
298
                ),
299

            
300
                ACOperatorKind::Sum => {
301
                    panic!("guards that reference decision variables not yet implemented for sum");
302
                }
303

            
304
                ACOperatorKind::Product => {
305
                    panic!(
306
                        "guards that reference decision variables not yet implemented for product"
307
                    );
308
                }
309
            }
310
760
        }
311

            
312
800
        generator_submodel.add_constraints(quantified_guards);
313

            
314
800
        return_expression_submodel.add_constraint(expression);
315

            
316
800
        Comprehension {
317
800
            return_expression_submodel,
318
800
            generator_submodel,
319
800
            quantified_vars: quantified_variables.into_iter().collect_vec(),
320
800
        }
321
800
    }
322
}
323

            
324
/// True iff the guard only references quantified variables.
325
100
fn is_quantified_guard(quantified_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
326
100
    guard
327
100
        .universe_bi()
328
100
        .iter()
329
100
        .all(|x| quantified_variables.contains(x))
330
100
}