1
use std::{
2
    collections::{HashMap, VecDeque},
3
    sync::{Arc, Mutex, RwLock, atomic::Ordering},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Metadata, Model, Moo,
9
        Name, ReturnType, SubModel, SymbolTable, Typeable as _,
10
        ac_operators::ACOperatorKind,
11
        comprehension::{Comprehension, USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS},
12
        serde::{HasId as _, ObjId},
13
    },
14
    bug,
15
    context::Context,
16
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
17
    settings::SolverFamily,
18
    solver::{Solver, SolverError, adaptors::Minion},
19
};
20
use tracing::warn;
21
use uniplate::{Biplate, Uniplate as _, zipper::Zipper};
22

            
23
/// Expands the comprehension using Minion, returning the resulting expressions.
24
///
25
/// This method is only suitable for comprehensions inside an AC operator. The AC operator that
26
/// contains this comprehension should be passed into the `ac_operator` argument.
27
///
28
/// This method performs additional pruning of "uninteresting" values, only possible when the
29
/// comprehension is inside an AC operator.
30
///
31
/// If successful, this modifies the symbol table given to add aux-variables needed inside the
32
/// expanded expressions.
33
111
pub fn expand_via_solver_ac(
34
111
    comprehension: Comprehension,
35
111
    symtab: &mut SymbolTable,
36
111
    ac_operator: ACOperatorKind,
37
111
) -> Result<Vec<Expression>, SolverError> {
38
    // ADD RETURN EXPRESSION TO GENERATOR MODEL AS CONSTRAINT
39
    // ======================================================
40

            
41
    // References to quantified variables in the return expression point to entries in the
42
    // return_expression symbol table.
43
    //
44
    // Change these to point to the corresponding entry in the generator symbol table instead.
45
    //
46
    // In the generator symbol-table, quantified variables are decision variables (as we are
47
    // solving for them), but in the return expression symbol table they are givens.
48
111
    let quantified_vars_2 = comprehension.quantified_vars.clone();
49
111
    let generator_symtab_ptr = comprehension.generator_submodel.symbols_ptr_unchecked();
50
111
    let return_expression =
51
111
        comprehension
52
111
            .clone()
53
111
            .return_expression()
54
597
            .transform_bi(&move |decl: DeclarationPtr| {
55
                // if this variable is a quantified var...
56
597
                if quantified_vars_2.contains(&decl.name()) {
57
                    // ... use the generator symbol tables version of it
58

            
59
324
                    generator_symtab_ptr
60
324
                        .read()
61
324
                        .lookup_local(&decl.name())
62
324
                        .unwrap()
63
                } else {
64
273
                    decl
65
                }
66
597
            });
67

            
68
    // Replace all boolean expressions referencing non-quantified variables in the return
69
    // expression with dummy variables. This allows us to add it as a constraint to the
70
    // generator model.
71
111
    let generator_submodel = add_return_expression_to_generator_model(
72
111
        comprehension.generator_submodel.clone(),
73
111
        return_expression,
74
111
        &ac_operator,
75
    );
76

            
77
    // REWRITE GENERATOR MODEL AND PASS TO MINION
78
    // ==========================================
79

            
80
111
    let mut generator_model = Model::new(Arc::new(RwLock::new(Context::default())));
81

            
82
111
    *generator_model.as_submodel_mut() = generator_submodel;
83

            
84
    // only branch on the quantified variables.
85
111
    generator_model.search_order = Some(comprehension.quantified_vars.clone());
86

            
87
111
    let extra_rule_sets = &["Base", "Constant", "Bubble"];
88

            
89
    // Minion unrolling expects quantified variables in the generator model as find declarations.
90
    // Keep this conversion local to the temporary model used for solving.
91
111
    let _temp_finds = temporarily_materialise_quantified_vars_as_finds(
92
111
        generator_model.as_submodel(),
93
111
        &comprehension.quantified_vars,
94
    );
95

            
96
111
    let rule_sets = resolve_rule_sets(SolverFamily::Minion, extra_rule_sets).unwrap();
97

            
98
111
    let generator_model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
99
        rewrite_morph(generator_model, &rule_sets, false)
100
    } else {
101
111
        rewrite_naive(&generator_model, &rule_sets, false, false).unwrap()
102
    };
103

            
104
111
    let minion = Solver::new(Minion::new());
105
111
    let minion = minion.load_model(generator_model.clone());
106

            
107
111
    let minion = match minion {
108
3
        Err(e) => {
109
3
            warn!(why=%e,model=%generator_model,"Loading generator model failed, failing solver-backed AC comprehension expansion rule");
110
3
            return Err(e);
111
        }
112
108
        Ok(minion) => minion,
113
    };
114

            
115
    // REWRITE RETURN EXPRESSION
116
    // =========================
117

            
118
108
    let return_expression_submodel = comprehension.return_expression_submodel.clone();
119
108
    let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
120
108
    *return_expression_model.as_submodel_mut() = return_expression_submodel;
121

            
122
108
    let return_expression_model =
123
108
        if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
124
            rewrite_morph(return_expression_model, &rule_sets, false)
125
        } else {
126
108
            rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
127
        };
128

            
129
108
    let values = Arc::new(Mutex::new(Vec::new()));
130
108
    let values_ptr = Arc::clone(&values);
131

            
132
    // SOLVE FOR THE QUANTIFIED VARIABLES, AND SUBSTITUTE INTO THE REWRITTEN RETURN EXPRESSION
133
    // ======================================================================================
134

            
135
    tracing::debug!(model=%generator_model,comprehension=%comprehension,"Minion solving comprehnesion (ac mode)");
136

            
137
300
    minion.solve(Box::new(move |sols| {
138
        // TODO: deal with represented names if quantified variables are abslits.
139
300
        let values = &mut *values_ptr.lock().unwrap();
140
300
        values.push(sols);
141
300
        true
142
300
    }))?;
143

            
144
108
    let values = values.lock().unwrap().clone();
145

            
146
108
    let mut return_expressions = vec![];
147

            
148
300
    for value in values {
149
        // convert back to an expression
150

            
151
300
        let return_expression_submodel = return_expression_model.as_submodel().clone();
152
300
        let child_symtab = return_expression_submodel.symbols().clone();
153
300
        let return_expression = return_expression_submodel.into_single_expression();
154

            
155
        // we only want to substitute quantified variables.
156
        // (definitely not machine names, as they mean something different in this scope!)
157
300
        let value: HashMap<_, _> = value
158
300
            .into_iter()
159
849
            .filter(|(n, _)| comprehension.quantified_vars.contains(n))
160
300
            .collect();
161

            
162
300
        let value_ptr = Arc::new(value);
163
300
        let value_ptr_2 = Arc::clone(&value_ptr);
164

            
165
        // substitute in the values for the quantified variables
166
2433
        let return_expression = return_expression.transform_bi(&move |x: Atom| {
167
2433
            let Atom::Reference(ref ptr) = x else {
168
627
                return x;
169
            };
170

            
171
            // is this referencing a quantified var?
172
1806
            let Some(lit) = value_ptr_2.get(&ptr.name()) else {
173
855
                return x;
174
            };
175

            
176
951
            Atom::Literal(lit.clone())
177
2433
        });
178

            
179
        // Copy the return expression's symbols into parent scope.
180

            
181
        // For variables in the return expression with machine names, create new declarations
182
        // for them in the parent symbol table, so that the machine names used are unique.
183
        //
184
        // Store the declaration translations in `machine_name_translations`.
185
        // These are stored as a map of (old declaration id) -> (new declaration ptr), as
186
        // declaration pointers do not implement hash.
187
        //
188
300
        let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
189

            
190
        // Populate `machine_name_translations`
191
327
        for (name, decl) in child_symtab.into_iter_local() {
192
            // do not add quantified declarations for quantified vars to the parent symbol table.
193
327
            if value_ptr.get(&name).is_some()
194
                && matches!(
195
318
                    &decl.kind() as &DeclarationKind,
196
                    DeclarationKind::Given(_) | DeclarationKind::Quantified(_)
197
                )
198
            {
199
318
                continue;
200
9
            }
201

            
202
9
            let Name::Machine(_) = &name else {
203
                bug!(
204
                    "the symbol table of the return expression of a comprehension should only contain machine names"
205
                );
206
            };
207

            
208
9
            let id = decl.id();
209
9
            let new_decl = symtab.gensym(&decl.domain().unwrap());
210

            
211
9
            machine_name_translations.insert(id, new_decl);
212
        }
213

            
214
        // Update references to use the new delcarations.
215
        #[allow(clippy::arc_with_non_send_sync)]
216
2433
        let return_expression = return_expression.transform_bi(&move |atom: Atom| {
217
2433
            if let Atom::Reference(ref decl) = atom
218
855
                && let id = decl.id()
219
855
                && let Some(new_decl) = machine_name_translations.get(&id)
220
            {
221
18
                Atom::Reference(conjure_cp::ast::Reference::new(new_decl.clone()))
222
            } else {
223
2415
                atom
224
            }
225
2433
        });
226

            
227
300
        return_expressions.push(return_expression);
228
    }
229

            
230
108
    Ok(return_expressions)
231
111
}
232

            
233
/// Guard that temporarily converts quantified declarations to find declarations.
234
struct TempQuantifiedFindGuard {
235
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
236
}
237

            
238
impl Drop for TempQuantifiedFindGuard {
239
111
    fn drop(&mut self) {
240
123
        for (mut decl, kind) in self.originals.drain(..) {
241
123
            let _ = decl.replace_kind(kind);
242
123
        }
243
111
    }
244
}
245

            
246
/// Converts quantified declarations in `submodel` to temporary find declarations.
247
111
fn temporarily_materialise_quantified_vars_as_finds(
248
111
    submodel: &SubModel,
249
111
    quantified_vars: &[Name],
250
111
) -> TempQuantifiedFindGuard {
251
111
    let symbols = submodel.symbols().clone();
252
111
    let mut originals = Vec::new();
253

            
254
123
    for name in quantified_vars {
255
123
        let Some(mut decl) = symbols.lookup_local(name) else {
256
            continue;
257
        };
258

            
259
123
        let old_kind = decl.kind().clone();
260
123
        let Some(domain) = decl.domain() else {
261
            continue;
262
        };
263

            
264
123
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
265
123
        let _ = decl.replace_kind(new_kind);
266
123
        originals.push((decl, old_kind));
267
    }
268

            
269
111
    TempQuantifiedFindGuard { originals }
270
111
}
271

            
272
/// Eliminate all references to non-quantified variables by introducing dummy variables to the
273
/// return expression. This modified return expression is added to the generator model, which is
274
/// returned.
275
///
276
/// Dummy variables must be the same type as the AC operators identity value.
277
///
278
/// To reduce the number of dummy variables, we turn the largest expression containing only
279
/// non-quantified variables and of the correct type into a dummy variable.
280
///
281
/// If there is no such expression, (e.g. and[(a<i) | i: int(1..10)]) , we use the smallest
282
/// expression of the correct type that contains a non-quantified variable. This ensures that
283
/// we lose as few references to quantified variables as possible.
284
111
fn add_return_expression_to_generator_model(
285
111
    mut generator_submodel: SubModel,
286
111
    return_expression: Expression,
287
111
    ac_operator: &ACOperatorKind,
288
111
) -> SubModel {
289
111
    let mut zipper = Zipper::new(return_expression);
290
111
    let mut symtab = generator_submodel.symbols_mut();
291

            
292
    // for sum/product we want to put integer expressions into dummy variables,
293
    // for and/or we want to put boolean expressions into dummy variables.
294
111
    let dummy_var_type = ac_operator.identity().return_type();
295

            
296
    'outer: loop {
297
243
        let focus: &mut Expression = zipper.focus_mut();
298

            
299
243
        let (non_quantified_vars, quantified_vars) = partition_variables(focus, &symtab);
300

            
301
        // an expression or its descendants needs to be turned into a dummy variable if it
302
        // contains non-quantified variables.
303
243
        let has_non_quantified_vars = !non_quantified_vars.is_empty();
304

            
305
        // does this expression contain quantified variables?
306
243
        let has_quantified_vars = !quantified_vars.is_empty();
307

            
308
        // can this expression be turned into a dummy variable?
309
243
        let can_be_dummy_var = can_be_dummy_variable(focus, &dummy_var_type);
310

            
311
        // The expression and its descendants don't need a dummy variables, so we don't
312
        // need to descend into its children.
313
243
        if !has_non_quantified_vars {
314
            // go to next node or quit
315
42
            while zipper.go_right().is_none() {
316
24
                let Some(()) = zipper.go_up() else {
317
                    // visited all nodes
318
24
                    break 'outer;
319
                };
320
            }
321
18
            continue;
322
201
        }
323

            
324
        // The expression contains non-quantified variables:
325

            
326
        // does this expression have any children that can be turned into dummy variables?
327
1260
        let has_eligible_child = focus.universe().iter().skip(1).any(|expr| {
328
            // eligible if it can be turned into a dummy variable, and turning it into a
329
            // dummy variable removes a non-quantified variable from the model.
330
1260
            can_be_dummy_variable(expr, &dummy_var_type)
331
153
                && contains_non_quantified_variables(expr, &symtab)
332
1260
        });
333

            
334
        // This expression has no child that can be turned into a dummy variable, but can
335
        // be a dummy variable => turn it into a dummy variable and continue.
336
201
        if !has_eligible_child && can_be_dummy_var {
337
            // introduce dummy var and continue
338
102
            let dummy_domain = focus.domain_of().unwrap();
339
102
            let dummy_decl = symtab.gensym(&dummy_domain);
340
102
            *focus = Expression::Atomic(
341
102
                Metadata::new(),
342
102
                Atom::Reference(conjure_cp::ast::Reference::new(dummy_decl)),
343
102
            );
344

            
345
            // go to next node
346
165
            while zipper.go_right().is_none() {
347
135
                let Some(()) = zipper.go_up() else {
348
                    // visited all nodes
349
72
                    break 'outer;
350
                };
351
            }
352
30
            continue;
353
        }
354
        // This expression has no child that can be turned into a dummy variable, and
355
        // cannot be a dummy variable => backtrack upwards to find a parent that can be a
356
        // dummy variable, and make it a dummy variable.
357
99
        else if !has_eligible_child && !can_be_dummy_var {
358
            // TODO: remove this case, make has_eligible_child check better?
359

            
360
            // go upwards until we find something that can be a dummy variable, make it
361
            // a dummy variable, then continue.
362
21
            while let Some(()) = zipper.go_up() {
363
21
                let focus = zipper.focus_mut();
364
21
                if can_be_dummy_variable(focus, &dummy_var_type) {
365
                    // TODO: this expression we are rewritng might already contain
366
                    // dummy vars - we might need a pass to get rid of the unused
367
                    // ones!
368
                    //
369
                    // introduce dummy var and continue
370
18
                    let dummy_domain = focus.domain_of().unwrap();
371
18
                    let dummy_decl = symtab.gensym(&dummy_domain);
372
18
                    *focus = Expression::Atomic(
373
18
                        Metadata::new(),
374
18
                        Atom::Reference(conjure_cp::ast::Reference::new(dummy_decl)),
375
18
                    );
376

            
377
                    // go to next node
378
18
                    while zipper.go_right().is_none() {
379
15
                        let Some(()) = zipper.go_up() else {
380
                            // visited all nodes
381
15
                            break 'outer;
382
                        };
383
                    }
384
3
                    continue;
385
3
                }
386
            }
387
            unreachable!();
388
        }
389
        // If the expression contains quantified variables as well as non-quantified
390
        // variables, try to retain the quantified variables by finding a child that can be
391
        // made a dummy variable which has only non-quantified variables.
392
84
        else if has_eligible_child && has_quantified_vars {
393
84
            zipper
394
84
                .go_down()
395
84
                .expect("we know the focus has a child, so zipper.go_down() should succeed");
396
84
        }
397
        // This expression contains no quantified variables, so no point trying to turn a
398
        // child into a dummy variable.
399
        else if has_eligible_child && !has_quantified_vars {
400
            // introduce dummy var and continue
401
            let dummy_domain = focus.domain_of().unwrap();
402
            let dummy_decl = symtab.gensym(&dummy_domain);
403
            *focus = Expression::Atomic(
404
                Metadata::new(),
405
                Atom::Reference(conjure_cp::ast::Reference::new(dummy_decl)),
406
            );
407

            
408
            // go to next node
409
            while zipper.go_right().is_none() {
410
                let Some(()) = zipper.go_up() else {
411
                    // visited all nodes
412
                    break 'outer;
413
                };
414
            }
415
        } else {
416
            unreachable!()
417
        }
418
    }
419

            
420
111
    let new_return_expression = Expression::Neq(
421
111
        Metadata::new(),
422
111
        Moo::new(Expression::Atomic(
423
111
            Metadata::new(),
424
111
            ac_operator.identity().into(),
425
111
        )),
426
111
        Moo::new(zipper.rebuild_root()),
427
111
    );
428

            
429
    // double check that the above transformation didn't miss any stray non-quantified vars
430
111
    assert!(
431
111
        Biplate::<DeclarationPtr>::universe_bi(&new_return_expression)
432
111
            .iter()
433
204
            .all(|x| symtab.lookup_local(&x.name()).is_some()),
434
        "generator model should only contain references to variables in its symbol table."
435
    );
436

            
437
111
    std::mem::drop(symtab);
438

            
439
111
    generator_submodel.add_constraint(new_return_expression);
440

            
441
111
    generator_submodel
442
111
}
443

            
444
/// Returns a tuple of non-quantified decision variables and quantified variables inside the expression.
445
///
446
/// As lettings, givens, etc. will eventually be subsituted for constants, this only returns
447
/// non-quantified _decision_ variables.
448
#[inline]
449
243
fn partition_variables(
450
243
    expr: &Expression,
451
243
    symtab: &SymbolTable,
452
243
) -> (VecDeque<Name>, VecDeque<Name>) {
453
    // doing this as two functions non_quantified_variables and quantified_variables might've been
454
    // easier to read.
455
    //
456
    // However, doing this in one function avoids an extra universe call...
457
243
    let (non_quantified_vars, quantified_vars): (VecDeque<Name>, VecDeque<Name>) =
458
243
        Biplate::<Name>::universe_bi(expr)
459
243
            .into_iter()
460
1350
            .partition(|x| symtab.lookup_local(x).is_none());
461

            
462
243
    (non_quantified_vars, quantified_vars)
463
243
}
464

            
465
/// Returns `true` if `expr` can be turned into a dummy variable.
466
#[inline]
467
1524
fn can_be_dummy_variable(expr: &Expression, dummy_variable_type: &ReturnType) -> bool {
468
    // do not put root expression in a dummy variable or things go wrong.
469
1524
    if matches!(expr, Expression::Root(_, _)) {
470
        return false;
471
1524
    };
472

            
473
    // is the expression the same type as the dummy variable?
474
1524
    expr.return_type() == *dummy_variable_type
475
1524
}
476

            
477
/// Returns `true` if `expr` or its descendants contains non-quantified variables.
478
#[inline]
479
153
fn contains_non_quantified_variables(expr: &Expression, symtab: &SymbolTable) -> bool {
480
153
    let names_referenced: VecDeque<Name> = expr.universe_bi();
481
    // a name is a non-quantified variable if its definition is not in the local scope of the
482
    // comprehension's generators.
483
153
    names_referenced
484
153
        .iter()
485
204
        .any(|x| symtab.lookup_local(x).is_none())
486
153
}