1
use std::{
2
    collections::{HashMap, VecDeque},
3
    rc::Rc,
4
    sync::{Arc, Mutex, RwLock, atomic::Ordering},
5
};
6

            
7
use conjure_cp::{
8
    ast::{
9
        Atom, DeclarationKind, DeclarationPtr, Expression, Metadata, Model, Moo, Name, ReturnType,
10
        SubModel, SymbolTable, Typeable as _,
11
        ac_operators::ACOperatorKind,
12
        comprehension::{Comprehension, USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS},
13
        serde::{HasId as _, ObjId},
14
    },
15
    bug,
16
    context::Context,
17
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
18
    solver::{Solver, SolverError, SolverFamily, adaptors::Minion},
19
};
20
use tracing::warn;
21
use uniplate::{Biplate, Uniplate as _, zipper::Zipper};
22

            
23
/// Expands the comprehension using Minion, returning the resulting expressions.
24
///
25
/// This method is only suitable for comprehensions inside an AC operator. The AC operator that
26
/// contains this comprehension should be passed into the `ac_operator` argument.
27
///
28
/// This method performs additional pruning of "uninteresting" values, only possible when the
29
/// comprehension is inside an AC operator.
30
///
31
/// If successful, this modifies the symbol table given to add aux-variables needed inside the
32
/// expanded expressions.
33
pub fn expand_ac(
34
    comprehension: Comprehension,
35
    symtab: &mut SymbolTable,
36
    ac_operator: ACOperatorKind,
37
) -> Result<Vec<Expression>, SolverError> {
38
    // ADD RETURN EXPRESSION TO GENERATOR MODEL AS CONSTRAINT
39
    // ======================================================
40

            
41
    // References to induction variables in the return expression point to entries in the
42
    // return_expression symbol table.
43
    //
44
    // Change these to point to the corresponding entry in the generator symbol table instead.
45
    //
46
    // In the generator symbol-table, induction variables are decision variables (as we are
47
    // solving for them), but in the return expression symbol table they are givens.
48
    let induction_vars_2 = comprehension.induction_vars.clone();
49
    let generator_symtab_ptr = Rc::clone(comprehension.generator_submodel.symbols_ptr_unchecked());
50
    let return_expression =
51
        comprehension
52
            .clone()
53
            .return_expression()
54
            .transform_bi(&move |decl: DeclarationPtr| {
55
                // if this variable is an induction var...
56
                if induction_vars_2.contains(&decl.name()) {
57
                    // ... use the generator symbol tables version of it
58

            
59
                    (*generator_symtab_ptr)
60
                        .borrow()
61
                        .lookup_local(&decl.name())
62
                        .unwrap()
63
                } else {
64
                    decl
65
                }
66
            });
67

            
68
    // Replace all boolean expressions referencing non-induction variables in the return
69
    // expression with dummy variables. This allows us to add it as a constraint to the
70
    // generator model.
71
    let generator_submodel = add_return_expression_to_generator_model(
72
        comprehension.generator_submodel.clone(),
73
        return_expression,
74
        &ac_operator,
75
    );
76

            
77
    // REWRITE GENERATOR MODEL AND PASS TO MINION
78
    // ==========================================
79

            
80
    let mut generator_model = Model::new(Arc::new(RwLock::new(Context::default())));
81

            
82
    *generator_model.as_submodel_mut() = generator_submodel;
83

            
84
    // only branch on the induction variables.
85
    generator_model.search_order = Some(comprehension.induction_vars.clone());
86

            
87
    let extra_rule_sets = &[
88
        "Base",
89
        "Constant",
90
        "Bubble",
91
        "Better_AC_Comprehension_Expansion",
92
    ];
93

            
94
    let rule_sets = resolve_rule_sets(SolverFamily::Minion, extra_rule_sets).unwrap();
95

            
96
    let generator_model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
97
        rewrite_morph(generator_model, &rule_sets, false)
98
    } else {
99
        rewrite_naive(&generator_model, &rule_sets, false, false).unwrap()
100
    };
101

            
102
    let minion = Solver::new(Minion::new());
103
    let minion = minion.load_model(generator_model.clone());
104

            
105
    let minion = match minion {
106
        Err(e) => {
107
            warn!(why=%e,model=%generator_model,"Loading generator model failed, failing expand_ac rule");
108
            return Err(e);
109
        }
110
        Ok(minion) => minion,
111
    };
112

            
113
    // REWRITE RETURN EXPRESSION
114
    // =========================
115

            
116
    let return_expression_submodel = comprehension.return_expression_submodel.clone();
117
    let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
118
    *return_expression_model.as_submodel_mut() = return_expression_submodel;
119

            
120
    let return_expression_model =
121
        if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
122
            rewrite_morph(return_expression_model, &rule_sets, false)
123
        } else {
124
            rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
125
        };
126

            
127
    let values = Arc::new(Mutex::new(Vec::new()));
128
    let values_ptr = Arc::clone(&values);
129

            
130
    // SOLVE FOR THE INDUCTION VARIABLES, AND SUBSTITUTE INTO THE REWRITTEN RETURN EXPRESSION
131
    // ======================================================================================
132

            
133
    tracing::debug!(model=%generator_model,comprehension=%comprehension,"Minion solving comprehnesion (ac mode)");
134

            
135
    minion.solve(Box::new(move |sols| {
136
        // TODO: deal with represented names if induction variables are abslits.
137
        let values = &mut *values_ptr.lock().unwrap();
138
        values.push(sols);
139
        true
140
    }))?;
141

            
142
    let values = values.lock().unwrap().clone();
143

            
144
    let mut return_expressions = vec![];
145

            
146
    for value in values {
147
        // convert back to an expression
148

            
149
        let return_expression_submodel = return_expression_model.as_submodel().clone();
150
        let child_symtab = return_expression_submodel.symbols().clone();
151
        let return_expression = return_expression_submodel.into_single_expression();
152

            
153
        // we only want to substitute induction variables.
154
        // (definitely not machine names, as they mean something different in this scope!)
155
        let value: HashMap<_, _> = value
156
            .into_iter()
157
            .filter(|(n, _)| comprehension.induction_vars.contains(n))
158
            .collect();
159

            
160
        let value_ptr = Arc::new(value);
161
        let value_ptr_2 = Arc::clone(&value_ptr);
162

            
163
        // substitute in the values for the induction variables
164
        let return_expression = return_expression.transform_bi(&move |x: Atom| {
165
            let Atom::Reference(ref ptr) = x else {
166
                return x;
167
            };
168

            
169
            // is this referencing an induction var?
170
            let Some(lit) = value_ptr_2.get(&ptr.name()) else {
171
                return x;
172
            };
173

            
174
            Atom::Literal(lit.clone())
175
        });
176

            
177
        // Copy the return expression's symbols into parent scope.
178

            
179
        // For variables in the return expression with machine names, create new declarations
180
        // for them in the parent symbol table, so that the machine names used are unique.
181
        //
182
        // Store the declaration translations in `machine_name_translations`.
183
        // These are stored as a map of (old declaration id) -> (new declaration ptr), as
184
        // declaration pointers do not implement hash.
185
        //
186
        let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
187

            
188
        // Populate `machine_name_translations`
189
        for (name, decl) in child_symtab.into_iter_local() {
190
            // do not add givens for induction vars to the parent symbol table.
191
            if value_ptr.get(&name).is_some()
192
                && matches!(&decl.kind() as &DeclarationKind, DeclarationKind::Given(_))
193
            {
194
                continue;
195
            }
196

            
197
            let Name::Machine(_) = &name else {
198
                bug!(
199
                    "the symbol table of the return expression of a comprehension should only contain machine names"
200
                );
201
            };
202

            
203
            let id = decl.id();
204
            let new_decl = symtab.gensym(&decl.domain().unwrap());
205

            
206
            machine_name_translations.insert(id, new_decl);
207
        }
208

            
209
        // Update references to use the new delcarations.
210
        #[allow(clippy::arc_with_non_send_sync)]
211
        let return_expression = return_expression.transform_bi(&move |atom: Atom| {
212
            if let Atom::Reference(ref decl) = atom
213
                && let id = decl.id()
214
                && let Some(new_decl) = machine_name_translations.get(&id)
215
            {
216
                Atom::Reference(conjure_cp::ast::Reference::new(new_decl.clone()))
217
            } else {
218
                atom
219
            }
220
        });
221

            
222
        return_expressions.push(return_expression);
223
    }
224

            
225
    Ok(return_expressions)
226
}
227

            
228
/// Eliminate all references to non induction variables by introducing dummy variables to the
229
/// return expression. This modified return expression is added to the generator model, which is
230
/// returned.
231
///
232
/// Dummy variables must be the same type as the AC operators identity value.
233
///
234
/// To reduce the number of dummy variables, we turn the largest expression containing only
235
/// non induction variables and of the correct type into a dummy variable.
236
///
237
/// If there is no such expression, (e.g. and[(a<i) | i: int(1..10)]) , we use the smallest
238
/// expression of the correct type that contains a non induction variable. This ensures that
239
/// we lose as few references to induction variables as possible.
240
fn add_return_expression_to_generator_model(
241
    mut generator_submodel: SubModel,
242
    return_expression: Expression,
243
    ac_operator: &ACOperatorKind,
244
) -> SubModel {
245
    let mut zipper = Zipper::new(return_expression);
246
    let mut symtab = generator_submodel.symbols_mut();
247

            
248
    // for sum/product we want to put integer expressions into dummy variables,
249
    // for and/or we want to put boolean expressions into dummy variables.
250
    let dummy_var_type = ac_operator.identity().return_type();
251

            
252
    'outer: loop {
253
        let focus: &mut Expression = zipper.focus_mut();
254

            
255
        let (non_induction_vars, induction_vars) = partition_variables(focus, &symtab);
256

            
257
        // an expression or its descendants needs to be turned into a dummy variable if it
258
        // contains non-induction variables.
259
        let has_non_induction_vars = !non_induction_vars.is_empty();
260

            
261
        // does this expression contain induction variables?
262
        let has_induction_vars = !induction_vars.is_empty();
263

            
264
        // can this expression be turned into a dummy variable?
265
        let can_be_dummy_var = can_be_dummy_variable(focus, &dummy_var_type);
266

            
267
        // The expression and its descendants don't need a dummy variables, so we don't
268
        // need to descend into its children.
269
        if !has_non_induction_vars {
270
            // go to next node or quit
271
            while zipper.go_right().is_none() {
272
                let Some(()) = zipper.go_up() else {
273
                    // visited all nodes
274
                    break 'outer;
275
                };
276
            }
277
            continue;
278
        }
279

            
280
        // The expression contains non-induction variables:
281

            
282
        // does this expression have any children that can be turned into dummy variables?
283
        let has_eligible_child = focus.universe().iter().skip(1).any(|expr| {
284
            // eligible if it can be turned into a dummy variable, and turning it into a
285
            // dummy variable removes a non-induction variable from the model.
286
            can_be_dummy_variable(expr, &dummy_var_type)
287
                && contains_non_induction_variables(expr, &symtab)
288
        });
289

            
290
        // This expression has no child that can be turned into a dummy variable, but can
291
        // be a dummy variable => turn it into a dummy variable and continue.
292
        if !has_eligible_child && can_be_dummy_var {
293
            // introduce dummy var and continue
294
            let dummy_domain = focus.domain_of().unwrap();
295
            let dummy_decl = symtab.gensym(&dummy_domain);
296
            *focus = Expression::Atomic(
297
                Metadata::new(),
298
                Atom::Reference(conjure_cp::ast::Reference::new(dummy_decl)),
299
            );
300

            
301
            // go to next node
302
            while zipper.go_right().is_none() {
303
                let Some(()) = zipper.go_up() else {
304
                    // visited all nodes
305
                    break 'outer;
306
                };
307
            }
308
            continue;
309
        }
310
        // This expression has no child that can be turned into a dummy variable, and
311
        // cannot be a dummy variable => backtrack upwards to find a parent that can be a
312
        // dummy variable, and make it a dummy variable.
313
        else if !has_eligible_child && !can_be_dummy_var {
314
            // TODO: remove this case, make has_eligible_child check better?
315

            
316
            // go upwards until we find something that can be a dummy variable, make it
317
            // a dummy variable, then continue.
318
            while let Some(()) = zipper.go_up() {
319
                let focus = zipper.focus_mut();
320
                if can_be_dummy_variable(focus, &dummy_var_type) {
321
                    // TODO: this expression we are rewritng might already contain
322
                    // dummy vars - we might need a pass to get rid of the unused
323
                    // ones!
324
                    //
325
                    // introduce dummy var and continue
326
                    let dummy_domain = focus.domain_of().unwrap();
327
                    let dummy_decl = symtab.gensym(&dummy_domain);
328
                    *focus = Expression::Atomic(
329
                        Metadata::new(),
330
                        Atom::Reference(conjure_cp::ast::Reference::new(dummy_decl)),
331
                    );
332

            
333
                    // go to next node
334
                    while zipper.go_right().is_none() {
335
                        let Some(()) = zipper.go_up() else {
336
                            // visited all nodes
337
                            break 'outer;
338
                        };
339
                    }
340
                    continue;
341
                }
342
            }
343
            unreachable!();
344
        }
345
        // If the expression contains induction variables as well as non-induction
346
        // variables, try to retain the induction varables by finding a child that can be
347
        // made a dummy variable which has only non-induction variables.
348
        else if has_eligible_child && has_induction_vars {
349
            zipper
350
                .go_down()
351
                .expect("we know the focus has a child, so zipper.go_down() should succeed");
352
        }
353
        // This expression contains no induction variables, so no point trying to turn a
354
        // child into a dummy variable.
355
        else if has_eligible_child && !has_induction_vars {
356
            // introduce dummy var and continue
357
            let dummy_domain = focus.domain_of().unwrap();
358
            let dummy_decl = symtab.gensym(&dummy_domain);
359
            *focus = Expression::Atomic(
360
                Metadata::new(),
361
                Atom::Reference(conjure_cp::ast::Reference::new(dummy_decl)),
362
            );
363

            
364
            // go to next node
365
            while zipper.go_right().is_none() {
366
                let Some(()) = zipper.go_up() else {
367
                    // visited all nodes
368
                    break 'outer;
369
                };
370
            }
371
        } else {
372
            unreachable!()
373
        }
374
    }
375

            
376
    let new_return_expression = Expression::Neq(
377
        Metadata::new(),
378
        Moo::new(Expression::Atomic(
379
            Metadata::new(),
380
            ac_operator.identity().into(),
381
        )),
382
        Moo::new(zipper.rebuild_root()),
383
    );
384

            
385
    // double check that the above transformation didn't miss any stray non induction vars
386
    assert!(
387
        Biplate::<DeclarationPtr>::universe_bi(&new_return_expression)
388
            .iter()
389
            .all(|x| symtab.lookup_local(&x.name()).is_some()),
390
        "generator model should only contain references to variables in its symbol table."
391
    );
392

            
393
    std::mem::drop(symtab);
394

            
395
    generator_submodel.add_constraint(new_return_expression);
396

            
397
    generator_submodel
398
}
399

            
400
/// Returns a tuple of non-induction decision variables and induction variables inside the expression.
401
///
402
/// As lettings, givens, etc. will eventually be subsituted for constants, this only returns
403
/// non-induction _decision_ variables.
404
#[inline]
405
fn partition_variables(
406
    expr: &Expression,
407
    symtab: &SymbolTable,
408
) -> (VecDeque<Name>, VecDeque<Name>) {
409
    // doing this as two functions non_induction_variables and induction_variables might've been
410
    // easier to read.
411
    //
412
    // However, doing this in one function avoids an extra universe call...
413
    let (non_induction_vars, induction_vars): (VecDeque<Name>, VecDeque<Name>) =
414
        Biplate::<Name>::universe_bi(expr)
415
            .into_iter()
416
            .partition(|x| symtab.lookup_local(x).is_none());
417

            
418
    (non_induction_vars, induction_vars)
419
}
420

            
421
/// Returns `true` if `expr` can be turned into a dummy variable.
422
#[inline]
423
fn can_be_dummy_variable(expr: &Expression, dummy_variable_type: &ReturnType) -> bool {
424
    // do not put root expression in a dummy variable or things go wrong.
425
    if matches!(expr, Expression::Root(_, _)) {
426
        return false;
427
    };
428

            
429
    // is the expression the same type as the dummy variable?
430
    expr.return_type() == *dummy_variable_type
431
}
432

            
433
/// Returns `true` if `expr` or its descendants contains non-induction variables.
434
#[inline]
435
fn contains_non_induction_variables(expr: &Expression, symtab: &SymbolTable) -> bool {
436
    let names_referenced: VecDeque<Name> = expr.universe_bi();
437
    // a name is a non-induction variable if its definition is not in the local scope of the
438
    // comprehension's generators.
439
    names_referenced
440
        .iter()
441
        .any(|x| symtab.lookup_local(x).is_none())
442
}