1
//! Comprehension expansion rules
2

            
3
mod expand_native;
4
mod expand_via_solver;
5
mod expand_via_solver_ac;
6
mod via_solver_common;
7

            
8
pub use expand_native::expand_native;
9
pub use expand_via_solver::expand_via_solver;
10
pub use expand_via_solver_ac::expand_via_solver_ac;
11

            
12
use conjure_cp::{
13
    ast::{
14
        DeclarationPtr, Domain, DomainPtr, Expression as Expr, IntVal, Moo, Name, Range, Reference,
15
        SymbolTable, UnresolvedDomain,
16
        comprehension::{Comprehension, ComprehensionQualifier},
17
        serde::{HasId, ObjId},
18
    },
19
    bug, into_matrix_expr,
20
    rule_engine::{
21
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
22
    },
23
    settings::{QuantifiedExpander, comprehension_expander},
24
};
25
use std::collections::HashMap;
26
use uniplate::{Biplate, Uniplate};
27

            
28
/// Rewrite top-level `exists` comprehensions into constraints over fresh machine `find`s.
29
///
30
/// `exists` is represented as `or([comprehension])`.
31
#[register_rule("Base", 2003, [Root])]
32
326892
fn exists_quantified_to_finds(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
33
326892
    let Expr::Root(metadata, constraints) = expr else {
34
314521
        return Err(RuleNotApplicable);
35
    };
36

            
37
12371
    let mut new_symbols = symbols.clone();
38
12371
    let mut new_constraints = Vec::with_capacity(constraints.len());
39
12371
    let mut changed = false;
40

            
41
57231
    for constraint in constraints {
42
57231
        let Some(comprehension) = as_exists_comprehension(constraint) else {
43
56991
            new_constraints.push(constraint.clone());
44
56991
            continue;
45
        };
46

            
47
240
        let Some(new_constraints_for_exists) =
48
240
            rewrite_exists_comprehension_to_constraints(&comprehension, &mut new_symbols)
49
        else {
50
            new_constraints.push(constraint.clone());
51
            continue;
52
        };
53

            
54
240
        new_constraints.extend(new_constraints_for_exists);
55
240
        changed = true;
56
    }
57

            
58
12371
    if changed {
59
192
        Ok(Reduction::with_symbols(
60
192
            Expr::Root(metadata.clone(), new_constraints),
61
192
            new_symbols,
62
192
        ))
63
    } else {
64
12179
        Err(RuleNotApplicable)
65
    }
66
326892
}
67

            
68
/// Expand comprehensions using `--comprehension-expander native`.
69
#[register_rule("Base", 2000, [Comprehension])]
70
168892
fn expand_comprehension_native(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
71
168892
    if comprehension_expander() != QuantifiedExpander::Native {
72
116120
        return Err(RuleNotApplicable);
73
52772
    }
74

            
75
52772
    let Expr::Comprehension(_, comprehension) = expr else {
76
52348
        return Err(RuleNotApplicable);
77
    };
78

            
79
424
    let comprehension = comprehension.as_ref().clone();
80

            
81
742
    for qual in &comprehension.qualifiers {
82
742
        if let ComprehensionQualifier::ExpressionGenerator { .. } = qual {
83
28
            return Err(RuleNotApplicable);
84
714
        }
85
    }
86

            
87
396
    let mut symbols = symbols.clone();
88
396
    let results = expand_native(comprehension, &mut symbols)
89
        .unwrap_or_else(|e| bug!("native comprehension expansion failed: {e}"));
90
396
    Ok(Reduction::with_symbols(into_matrix_expr!(results), symbols))
91
168892
}
92

            
93
/// Expand comprehensions using `--comprehension-expander via-solver`.
94
///
95
/// Algorithm sketch:
96
/// 1. Match one comprehension node.
97
/// 2. Build a temporary generator submodel from its qualifiers/guards.
98
/// 3. Materialise quantified declarations as temporary `find` declarations.
99
/// 4. Wrap that submodel as a standalone temporary model, with search order restricted to the
100
///    quantified names.
101
/// 5. Rewrite the temporary model using the configured rewriter and Minion-oriented rules.
102
/// 6. Solve the rewritten temporary model with Minion and keep only quantified assignments from
103
///    each solution.
104
/// 7. Instantiate the original return expression under each quantified assignment.
105
/// 8. Replace the comprehension by a matrix literal containing all instantiated return values.
106
#[register_rule("Base", 2000, [Comprehension])]
107
168892
fn expand_comprehension_via_solver(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
108
52772
    if !matches!(
109
168892
        comprehension_expander(),
110
        QuantifiedExpander::ViaSolver | QuantifiedExpander::ViaSolverAc
111
    ) {
112
52772
        return Err(RuleNotApplicable);
113
116120
    }
114

            
115
116120
    let Expr::Comprehension(_, comprehension) = expr else {
116
115928
        return Err(RuleNotApplicable);
117
    };
118

            
119
192
    let comprehension = comprehension.as_ref().clone();
120

            
121
264
    for qual in &comprehension.qualifiers {
122
264
        if let ComprehensionQualifier::ExpressionGenerator { .. } = qual {
123
6
            return Err(RuleNotApplicable);
124
258
        }
125
    }
126

            
127
186
    let results = expand_via_solver(comprehension)
128
        .unwrap_or_else(|e| bug!("via-solver comprehension expansion failed: {e}"));
129
186
    Ok(Reduction::with_symbols(
130
186
        into_matrix_expr!(results),
131
186
        symbols.clone(),
132
186
    ))
133
168892
}
134

            
135
/// Expand comprehensions inside AC operators using `--comprehension-expander via-solver-ac`.
136
///
137
/// Algorithm sketch:
138
/// 1. Match an AC operator whose single child is a comprehension.
139
/// 2. Build a temporary generator submodel from the comprehension qualifiers/guards.
140
/// 3. Add a derived constraint from the return expression to this generator model:
141
///    localise non-local references, and replace non-quantified fragments with dummy variables so
142
///    the constraint depends only on locally solvable symbols.
143
/// 4. Materialise quantified declarations as temporary `find` declarations in the temporary model.
144
/// 5. Rewrite and solve the temporary model with Minion; keep only quantified assignments.
145
/// 6. Instantiate the original return expression under those assignments.
146
/// 7. Rebuild the same AC operator around the instantiated matrix literal.
147
#[register_rule("Base", 2002)]
148
318309
fn expand_comprehension_via_solver_ac(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
149
318309
    if comprehension_expander() != QuantifiedExpander::ViaSolverAc {
150
131035
        return Err(RuleNotApplicable);
151
187274
    }
152

            
153
    // Is this an ac expression?
154
187274
    let ac_operator_kind = expr.to_ac_operator_kind().ok_or(RuleNotApplicable)?;
155

            
156
11474
    debug_assert_eq!(
157
11474
        expr.children().len(),
158
        1,
159
        "AC expressions should have exactly one child."
160
    );
161

            
162
11474
    let comprehension = as_single_comprehension(&expr.children()[0]).ok_or(RuleNotApplicable)?;
163

            
164
675
    for qual in &comprehension.qualifiers {
165
675
        if let ComprehensionQualifier::ExpressionGenerator { .. } = qual {
166
6
            return Err(RuleNotApplicable);
167
669
        }
168
    }
169

            
170
447
    let results =
171
603
        expand_via_solver_ac(comprehension, ac_operator_kind).or(Err(RuleNotApplicable))?;
172

            
173
447
    let new_expr = ac_operator_kind.as_expression(into_matrix_expr!(results));
174
447
    Ok(Reduction::with_symbols(new_expr, symbols.clone()))
175
318309
}
176

            
177
21946
fn as_single_comprehension(expr: &Expr) -> Option<Comprehension> {
178
21946
    if let Expr::Comprehension(_, comprehension) = expr {
179
849
        return Some(comprehension.as_ref().clone());
180
21097
    }
181

            
182
21097
    let exprs = expr.clone().unwrap_list()?;
183
17212
    let [Expr::Comprehension(_, comprehension)] = exprs.as_slice() else {
184
17212
        return None;
185
    };
186

            
187
    Some(comprehension.as_ref().clone())
188
21946
}
189

            
190
57231
fn as_exists_comprehension(expr: &Expr) -> Option<Comprehension> {
191
57231
    let Expr::Or(_, or_child) = expr else {
192
46759
        return None;
193
    };
194

            
195
10472
    as_single_comprehension(or_child.as_ref())
196
57231
}
197

            
198
240
fn rewrite_exists_comprehension_to_constraints(
199
240
    comprehension: &Comprehension,
200
240
    symbols: &mut SymbolTable,
201
240
) -> Option<Vec<Expr>> {
202
240
    let quantified_declarations = quantified_declarations(comprehension)?;
203

            
204
240
    let mut replacements_by_id: HashMap<ObjId, DeclarationPtr> = HashMap::new();
205
240
    let mut replacements_by_name: HashMap<Name, DeclarationPtr> = HashMap::new();
206

            
207
336
    for decl in quantified_declarations {
208
336
        let domain = decl.domain()?;
209
336
        let rewritten_domain =
210
336
            replace_declaration_ptrs_in_domain(domain, &replacements_by_id, &replacements_by_name);
211
336
        let fresh_decl = symbols.gen_find(&rewritten_domain);
212
336
        replacements_by_id.insert(decl.id(), fresh_decl.clone());
213
336
        replacements_by_name.insert(decl.name().clone(), fresh_decl);
214
    }
215

            
216
240
    let mut conjuncts = Vec::new();
217
336
    for qualifier in &comprehension.qualifiers {
218
336
        if let ComprehensionQualifier::Condition(condition) = qualifier {
219
            conjuncts.push(replace_declaration_ptrs_in_expr(
220
                condition.clone(),
221
                &replacements_by_id,
222
                &replacements_by_name,
223
            ));
224
336
        }
225
    }
226
240
    conjuncts.push(replace_declaration_ptrs_in_expr(
227
240
        comprehension.return_expression.clone(),
228
240
        &replacements_by_id,
229
240
        &replacements_by_name,
230
    ));
231

            
232
240
    Some(conjuncts)
233
240
}
234

            
235
240
/// Look up the declaration of each quantified variable in the comprehension's own (local)
/// symbol table, in the order `quantified_vars()` yields them.
///
/// Returns `None` as soon as any quantified name is not declared locally.
fn quantified_declarations(comprehension: &Comprehension) -> Option<Vec<DeclarationPtr>> {
    let names = comprehension.quantified_vars();
    let local_symbols = comprehension.symbols();

    let mut declarations = Vec::new();
    for name in names {
        declarations.push(local_symbols.lookup_local(&name)?);
    }
    Some(declarations)
}
243

            
244
336
fn replace_declaration_ptrs_in_expr(
245
336
    expr: Expr,
246
336
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
247
336
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
248
336
) -> Expr {
249
828
    expr.transform_bi(&|decl: DeclarationPtr| {
250
828
        if let Some(replacement) = replacements_by_id.get(&decl.id()) {
251
504
            return replacement.clone();
252
324
        }
253

            
254
324
        let name = decl.name().clone();
255
324
        replacements_by_name.get(&name).cloned().unwrap_or(decl)
256
828
    })
257
336
}
258

            
259
336
fn replace_declaration_ptrs_in_domain(
260
336
    domain: DomainPtr,
261
336
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
262
336
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
263
336
) -> DomainPtr {
264
336
    let mut rewritten = domain
265
336
        .transform_bi(&|expr: Expr| {
266
            replace_declaration_ptrs_in_expr(expr, replacements_by_id, replacements_by_name)
267
        })
268
336
        .transform_bi(&|reference: Reference| {
269
            replace_reference(reference, replacements_by_id, replacements_by_name)
270
        });
271

            
272
    // `Range<T>` does not participate in the generic biplate traversal, so recurse through
273
    // unresolved domain structure once to rewrite symbolic integer bounds.
274
336
    rewrite_int_ranges_in_domain_ptr(&mut rewritten, replacements_by_id, replacements_by_name);
275

            
276
336
    rewritten
277
336
}
278

            
279
fn replace_reference(
280
    reference: Reference,
281
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
282
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
283
) -> Reference {
284
    let replacement = replacements_by_id
285
        .get(&reference.ptr().id())
286
        .cloned()
287
        .or_else(|| {
288
            let name = reference.name().clone();
289
            replacements_by_name.get(&name).cloned()
290
        });
291

            
292
    replacement.map(Reference::new).unwrap_or(reference)
293
}
294

            
295
336
fn rewrite_int_ranges_in_domain_ptr(
296
336
    domain: &mut DomainPtr,
297
336
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
298
336
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
299
336
) {
300
336
    let mut rewritten = domain.as_ref().clone();
301
336
    rewrite_int_ranges_in_domain(&mut rewritten, replacements_by_id, replacements_by_name);
302
336
    *domain = Moo::new(rewritten);
303
336
}
304

            
305
336
fn rewrite_int_ranges_in_domain(
306
336
    domain: &mut Domain,
307
336
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
308
336
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
309
336
) {
310
336
    let Domain::Unresolved(unresolved) = domain else {
311
240
        return;
312
    };
313

            
314
96
    rewrite_int_ranges_in_unresolved_domain(
315
96
        Moo::make_mut(unresolved),
316
96
        replacements_by_id,
317
96
        replacements_by_name,
318
    );
319
336
}
320

            
321
96
fn rewrite_int_ranges_in_unresolved_domain(
322
96
    unresolved: &mut UnresolvedDomain,
323
96
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
324
96
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
325
96
) {
326
96
    match unresolved {
327
96
        UnresolvedDomain::Int(ranges) => {
328
96
            for range in ranges {
329
96
                rewrite_int_range(range, replacements_by_id, replacements_by_name);
330
96
            }
331
        }
332
        UnresolvedDomain::Set(attr, inner) => {
333
            rewrite_int_range(&mut attr.size, replacements_by_id, replacements_by_name);
334
            rewrite_int_ranges_in_domain_ptr(inner, replacements_by_id, replacements_by_name);
335
        }
336
        UnresolvedDomain::MSet(attr, inner) => {
337
            rewrite_int_range(&mut attr.size, replacements_by_id, replacements_by_name);
338
            rewrite_int_range(
339
                &mut attr.occurrence,
340
                replacements_by_id,
341
                replacements_by_name,
342
            );
343
            rewrite_int_ranges_in_domain_ptr(inner, replacements_by_id, replacements_by_name);
344
        }
345
        UnresolvedDomain::Matrix(inner, index_domains) => {
346
            rewrite_int_ranges_in_domain_ptr(inner, replacements_by_id, replacements_by_name);
347
            for index_domain in index_domains {
348
                rewrite_int_ranges_in_domain_ptr(
349
                    index_domain,
350
                    replacements_by_id,
351
                    replacements_by_name,
352
                );
353
            }
354
        }
355
        UnresolvedDomain::Tuple(inner_domains) => {
356
            for inner_domain in inner_domains {
357
                rewrite_int_ranges_in_domain_ptr(
358
                    inner_domain,
359
                    replacements_by_id,
360
                    replacements_by_name,
361
                );
362
            }
363
        }
364
        UnresolvedDomain::Reference(_) => {}
365
        UnresolvedDomain::Record(entries) => {
366
            for entry in entries {
367
                rewrite_int_ranges_in_domain_ptr(
368
                    &mut entry.domain,
369
                    replacements_by_id,
370
                    replacements_by_name,
371
                );
372
            }
373
        }
374
        UnresolvedDomain::Function(attr, domain, codomain) => {
375
            rewrite_int_range(&mut attr.size, replacements_by_id, replacements_by_name);
376
            rewrite_int_ranges_in_domain_ptr(domain, replacements_by_id, replacements_by_name);
377
            rewrite_int_ranges_in_domain_ptr(codomain, replacements_by_id, replacements_by_name);
378
        }
379
    }
380
96
}
381

            
382
96
fn rewrite_int_range(
383
96
    range: &mut Range<IntVal>,
384
96
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
385
96
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
386
96
) {
387
96
    match range {
388
        Range::Single(value) | Range::UnboundedL(value) | Range::UnboundedR(value) => {
389
            rewrite_int_value(value, replacements_by_id, replacements_by_name);
390
        }
391
96
        Range::Bounded(lower, upper) => {
392
96
            rewrite_int_value(lower, replacements_by_id, replacements_by_name);
393
96
            rewrite_int_value(upper, replacements_by_id, replacements_by_name);
394
96
        }
395
        Range::Unbounded => {}
396
    }
397
96
}
398

            
399
192
fn rewrite_int_value(
400
192
    int_val: &mut IntVal,
401
192
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
402
192
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
403
192
) {
404
192
    if let IntVal::Expr(expr) = int_val {
405
96
        let rewritten = replace_declaration_ptrs_in_expr(
406
96
            (**expr).clone(),
407
96
            replacements_by_id,
408
96
            replacements_by_name,
409
96
        );
410
96
        *expr = Moo::new(rewritten);
411
96
    }
412
192
}