1
//! Comprehension expansion rules
2

            
3
mod expand_native;
4
mod expand_via_solver;
5
mod expand_via_solver_ac;
6
mod via_solver_common;
7

            
8
pub use expand_native::expand_native;
9
pub use expand_via_solver::expand_via_solver;
10
pub use expand_via_solver_ac::expand_via_solver_ac;
11

            
12
use conjure_cp::{
13
    ast::{
14
        DeclarationPtr, Domain, DomainPtr, Expression as Expr, IntVal, Moo, Name, Range, Reference,
15
        SymbolTable, UnresolvedDomain,
16
        comprehension::{Comprehension, ComprehensionQualifier},
17
        serde::{HasId, ObjId},
18
    },
19
    bug, into_matrix_expr,
20
    rule_engine::{
21
        ApplicationError::RuleNotApplicable, ApplicationResult, Reduction, register_rule,
22
    },
23
    settings::{QuantifiedExpander, comprehension_expander},
24
};
25
use std::collections::HashMap;
26
use uniplate::{Biplate, Uniplate};
27

            
28
/// Rewrite top-level `exists` comprehensions into constraints over fresh machine `find`s.
29
///
30
/// `exists` is represented as `or([comprehension])`.
31
#[register_rule(("Base", 2003))]
32
185952
fn exists_quantified_to_finds(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
33
185952
    let Expr::Root(metadata, constraints) = expr else {
34
176373
        return Err(RuleNotApplicable);
35
    };
36

            
37
9579
    let mut new_symbols = symbols.clone();
38
9579
    let mut new_constraints = Vec::with_capacity(constraints.len());
39
9579
    let mut changed = false;
40

            
41
40536
    for constraint in constraints {
42
40536
        let Some(comprehension) = as_exists_comprehension(constraint) else {
43
40374
            new_constraints.push(constraint.clone());
44
40374
            continue;
45
        };
46

            
47
162
        let Some(new_constraints_for_exists) =
48
162
            rewrite_exists_comprehension_to_constraints(&comprehension, &mut new_symbols)
49
        else {
50
            new_constraints.push(constraint.clone());
51
            continue;
52
        };
53

            
54
162
        new_constraints.extend(new_constraints_for_exists);
55
162
        changed = true;
56
    }
57

            
58
9579
    if changed {
59
117
        Ok(Reduction::with_symbols(
60
117
            Expr::Root(metadata.clone(), new_constraints),
61
117
            new_symbols,
62
117
        ))
63
    } else {
64
9462
        Err(RuleNotApplicable)
65
    }
66
185952
}
67

            
68
/// Expand comprehensions using `--comprehension-expander native`.
69
#[register_rule(("Base", 2000))]
70
35243
fn expand_comprehension_native(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
71
35243
    if comprehension_expander() != QuantifiedExpander::Native {
72
33939
        return Err(RuleNotApplicable);
73
1304
    }
74

            
75
1304
    let Expr::Comprehension(_, comprehension) = expr else {
76
1250
        return Err(RuleNotApplicable);
77
    };
78

            
79
54
    let comprehension = comprehension.as_ref().clone();
80
54
    let mut symbols = symbols.clone();
81
54
    let results = expand_native(comprehension, &mut symbols)
82
        .unwrap_or_else(|e| bug!("native comprehension expansion failed: {e}"));
83
54
    Ok(Reduction::with_symbols(into_matrix_expr!(results), symbols))
84
35243
}
85

            
86
/// Expand comprehensions using `--comprehension-expander via-solver`.
87
///
88
/// Algorithm sketch:
89
/// 1. Match one comprehension node.
90
/// 2. Build a temporary generator submodel from its qualifiers/guards.
91
/// 3. Materialise quantified declarations as temporary `find` declarations.
92
/// 4. Wrap that submodel as a standalone temporary model, with search order restricted to the
93
///    quantified names.
94
/// 5. Rewrite the temporary model using the configured rewriter and Minion-oriented rules.
95
/// 6. Solve the rewritten temporary model with Minion and keep only quantified assignments from
96
///    each solution.
97
/// 7. Instantiate the original return expression under each quantified assignment.
98
/// 8. Replace the comprehension by a matrix literal containing all instantiated return values.
99
#[register_rule(("Base", 2000))]
100
35243
fn expand_comprehension_via_solver(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
101
1304
    if !matches!(
102
35243
        comprehension_expander(),
103
        QuantifiedExpander::ViaSolver | QuantifiedExpander::ViaSolverAc
104
    ) {
105
1304
        return Err(RuleNotApplicable);
106
33939
    }
107

            
108
33939
    let Expr::Comprehension(_, comprehension) = expr else {
109
33876
        return Err(RuleNotApplicable);
110
    };
111

            
112
63
    let comprehension = comprehension.as_ref().clone();
113
63
    let results = expand_via_solver(comprehension)
114
        .unwrap_or_else(|e| bug!("via-solver comprehension expansion failed: {e}"));
115
63
    Ok(Reduction::with_symbols(
116
63
        into_matrix_expr!(results),
117
63
        symbols.clone(),
118
63
    ))
119
35243
}
120

            
121
/// Expand comprehensions inside AC operators using `--comprehension-expander via-solver-ac`.
122
///
123
/// Algorithm sketch:
124
/// 1. Match an AC operator whose single child is a comprehension.
125
/// 2. Build a temporary generator submodel from the comprehension qualifiers/guards.
126
/// 3. Add a derived constraint from the return expression to this generator model:
127
///    localise non-local references, and replace non-quantified fragments with dummy variables so
128
///    the constraint depends only on locally solvable symbols.
129
/// 4. Materialise quantified declarations as temporary `find` declarations in the temporary model.
130
/// 5. Rewrite and solve the temporary model with Minion; keep only quantified assignments.
131
/// 6. Instantiate the original return expression under those assignments.
132
/// 7. Rebuild the same AC operator around the instantiated matrix literal.
133
#[register_rule(("Base", 2002))]
134
183144
fn expand_comprehension_via_solver_ac(expr: &Expr, symbols: &SymbolTable) -> ApplicationResult {
135
183144
    if comprehension_expander() != QuantifiedExpander::ViaSolverAc {
136
9822
        return Err(RuleNotApplicable);
137
173322
    }
138

            
139
    // Is this an ac expression?
140
173322
    let ac_operator_kind = expr.to_ac_operator_kind().ok_or(RuleNotApplicable)?;
141

            
142
13455
    debug_assert_eq!(
143
13455
        expr.children().len(),
144
        1,
145
        "AC expressions should have exactly one child."
146
    );
147

            
148
13455
    let comprehension = as_single_comprehension(&expr.children()[0]).ok_or(RuleNotApplicable)?;
149

            
150
477
    let results =
151
729
        expand_via_solver_ac(comprehension, ac_operator_kind).or(Err(RuleNotApplicable))?;
152

            
153
477
    let new_expr = ac_operator_kind.as_expression(into_matrix_expr!(results));
154
477
    Ok(Reduction::with_symbols(new_expr, symbols.clone()))
155
183144
}
156

            
157
15732
fn as_single_comprehension(expr: &Expr) -> Option<Comprehension> {
158
15732
    if let Expr::Comprehension(_, comprehension) = expr {
159
891
        return Some(comprehension.as_ref().clone());
160
14841
    }
161

            
162
14841
    let exprs = expr.clone().unwrap_list()?;
163
8352
    let [Expr::Comprehension(_, comprehension)] = exprs.as_slice() else {
164
8352
        return None;
165
    };
166

            
167
    Some(comprehension.as_ref().clone())
168
15732
}
169

            
170
40536
fn as_exists_comprehension(expr: &Expr) -> Option<Comprehension> {
171
40536
    let Expr::Or(_, or_child) = expr else {
172
38259
        return None;
173
    };
174

            
175
2277
    as_single_comprehension(or_child.as_ref())
176
40536
}
177

            
178
162
fn rewrite_exists_comprehension_to_constraints(
179
162
    comprehension: &Comprehension,
180
162
    symbols: &mut SymbolTable,
181
162
) -> Option<Vec<Expr>> {
182
162
    let quantified_declarations = quantified_declarations(comprehension)?;
183

            
184
162
    let mut replacements_by_id: HashMap<ObjId, DeclarationPtr> = HashMap::new();
185
162
    let mut replacements_by_name: HashMap<Name, DeclarationPtr> = HashMap::new();
186

            
187
198
    for decl in quantified_declarations {
188
198
        let domain = decl.domain()?;
189
198
        let rewritten_domain =
190
198
            replace_declaration_ptrs_in_domain(domain, &replacements_by_id, &replacements_by_name);
191
198
        let fresh_decl = symbols.gensym(&rewritten_domain);
192
198
        replacements_by_id.insert(decl.id(), fresh_decl.clone());
193
198
        replacements_by_name.insert(decl.name().clone(), fresh_decl);
194
    }
195

            
196
162
    let mut conjuncts = Vec::new();
197
198
    for qualifier in &comprehension.qualifiers {
198
198
        if let ComprehensionQualifier::Condition(condition) = qualifier {
199
            conjuncts.push(replace_declaration_ptrs_in_expr(
200
                condition.clone(),
201
                &replacements_by_id,
202
                &replacements_by_name,
203
            ));
204
198
        }
205
    }
206
162
    conjuncts.push(replace_declaration_ptrs_in_expr(
207
162
        comprehension.return_expression.clone(),
208
162
        &replacements_by_id,
209
162
        &replacements_by_name,
210
    ));
211

            
212
162
    Some(conjuncts)
213
162
}
214

            
215
162
/// Looks up the declaration of every quantified variable of `comprehension`.
///
/// Names come from `comprehension.quantified_vars()` and are resolved with `lookup_local` on
/// `comprehension.symbols()`. The result preserves the order of the names. Returns `None` as
/// soon as one name has no local declaration.
fn quantified_declarations(comprehension: &Comprehension) -> Option<Vec<DeclarationPtr>> {
    let names = comprehension.quantified_vars();
    let scope = comprehension.symbols();

    let mut declarations = Vec::new();
    for name in names {
        declarations.push(scope.lookup_local(&name)?);
    }
    Some(declarations)
}
223

            
224
198
fn replace_declaration_ptrs_in_expr(
225
198
    expr: Expr,
226
198
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
227
198
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
228
198
) -> Expr {
229
504
    expr.transform_bi(&|decl: DeclarationPtr| {
230
504
        if let Some(replacement) = replacements_by_id.get(&decl.id()) {
231
288
            return replacement.clone();
232
216
        }
233

            
234
216
        let name = decl.name().clone();
235
216
        replacements_by_name.get(&name).cloned().unwrap_or(decl)
236
504
    })
237
198
}
238

            
239
198
fn replace_declaration_ptrs_in_domain(
240
198
    domain: DomainPtr,
241
198
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
242
198
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
243
198
) -> DomainPtr {
244
198
    let mut rewritten = domain
245
198
        .transform_bi(&|expr: Expr| {
246
            replace_declaration_ptrs_in_expr(expr, replacements_by_id, replacements_by_name)
247
        })
248
198
        .transform_bi(&|reference: Reference| {
249
            replace_reference(reference, replacements_by_id, replacements_by_name)
250
        });
251

            
252
    // `Range<T>` does not participate in the generic biplate traversal, so recurse through
253
    // unresolved domain structure once to rewrite symbolic integer bounds.
254
198
    rewrite_int_ranges_in_domain_ptr(&mut rewritten, replacements_by_id, replacements_by_name);
255

            
256
198
    rewritten
257
198
}
258

            
259
fn replace_reference(
260
    reference: Reference,
261
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
262
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
263
) -> Reference {
264
    let replacement = replacements_by_id
265
        .get(&reference.ptr().id())
266
        .cloned()
267
        .or_else(|| {
268
            let name = reference.name().clone();
269
            replacements_by_name.get(&name).cloned()
270
        });
271

            
272
    replacement.map(Reference::new).unwrap_or(reference)
273
}
274

            
275
198
fn rewrite_int_ranges_in_domain_ptr(
276
198
    domain: &mut DomainPtr,
277
198
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
278
198
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
279
198
) {
280
198
    let mut rewritten = domain.as_ref().clone();
281
198
    rewrite_int_ranges_in_domain(&mut rewritten, replacements_by_id, replacements_by_name);
282
198
    *domain = Moo::new(rewritten);
283
198
}
284

            
285
198
fn rewrite_int_ranges_in_domain(
286
198
    domain: &mut Domain,
287
198
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
288
198
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
289
198
) {
290
198
    let Domain::Unresolved(unresolved) = domain else {
291
126
        return;
292
    };
293

            
294
72
    rewrite_int_ranges_in_unresolved_domain(
295
72
        Moo::make_mut(unresolved),
296
72
        replacements_by_id,
297
72
        replacements_by_name,
298
    );
299
198
}
300

            
301
72
fn rewrite_int_ranges_in_unresolved_domain(
302
72
    unresolved: &mut UnresolvedDomain,
303
72
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
304
72
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
305
72
) {
306
72
    match unresolved {
307
72
        UnresolvedDomain::Int(ranges) => {
308
72
            for range in ranges {
309
72
                rewrite_int_range(range, replacements_by_id, replacements_by_name);
310
72
            }
311
        }
312
        UnresolvedDomain::Set(attr, inner) => {
313
            rewrite_int_range(&mut attr.size, replacements_by_id, replacements_by_name);
314
            rewrite_int_ranges_in_domain_ptr(inner, replacements_by_id, replacements_by_name);
315
        }
316
        UnresolvedDomain::MSet(attr, inner) => {
317
            rewrite_int_range(&mut attr.size, replacements_by_id, replacements_by_name);
318
            rewrite_int_range(
319
                &mut attr.occurrence,
320
                replacements_by_id,
321
                replacements_by_name,
322
            );
323
            rewrite_int_ranges_in_domain_ptr(inner, replacements_by_id, replacements_by_name);
324
        }
325
        UnresolvedDomain::Matrix(inner, index_domains) => {
326
            rewrite_int_ranges_in_domain_ptr(inner, replacements_by_id, replacements_by_name);
327
            for index_domain in index_domains {
328
                rewrite_int_ranges_in_domain_ptr(
329
                    index_domain,
330
                    replacements_by_id,
331
                    replacements_by_name,
332
                );
333
            }
334
        }
335
        UnresolvedDomain::Tuple(inner_domains) => {
336
            for inner_domain in inner_domains {
337
                rewrite_int_ranges_in_domain_ptr(
338
                    inner_domain,
339
                    replacements_by_id,
340
                    replacements_by_name,
341
                );
342
            }
343
        }
344
        UnresolvedDomain::Reference(_) => {}
345
        UnresolvedDomain::Record(entries) => {
346
            for entry in entries {
347
                rewrite_int_ranges_in_domain_ptr(
348
                    &mut entry.domain,
349
                    replacements_by_id,
350
                    replacements_by_name,
351
                );
352
            }
353
        }
354
        UnresolvedDomain::Function(attr, domain, codomain) => {
355
            rewrite_int_range(&mut attr.size, replacements_by_id, replacements_by_name);
356
            rewrite_int_ranges_in_domain_ptr(domain, replacements_by_id, replacements_by_name);
357
            rewrite_int_ranges_in_domain_ptr(codomain, replacements_by_id, replacements_by_name);
358
        }
359
    }
360
72
}
361

            
362
72
fn rewrite_int_range(
363
72
    range: &mut Range<IntVal>,
364
72
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
365
72
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
366
72
) {
367
72
    match range {
368
        Range::Single(value) | Range::UnboundedL(value) | Range::UnboundedR(value) => {
369
            rewrite_int_value(value, replacements_by_id, replacements_by_name);
370
        }
371
72
        Range::Bounded(lower, upper) => {
372
72
            rewrite_int_value(lower, replacements_by_id, replacements_by_name);
373
72
            rewrite_int_value(upper, replacements_by_id, replacements_by_name);
374
72
        }
375
        Range::Unbounded => {}
376
    }
377
72
}
378

            
379
144
fn rewrite_int_value(
380
144
    int_val: &mut IntVal,
381
144
    replacements_by_id: &HashMap<ObjId, DeclarationPtr>,
382
144
    replacements_by_name: &HashMap<Name, DeclarationPtr>,
383
144
) {
384
144
    if let IntVal::Expr(expr) = int_val {
385
36
        let rewritten = replace_declaration_ptrs_in_expr(
386
36
            (**expr).clone(),
387
36
            replacements_by_id,
388
36
            replacements_by_name,
389
36
        );
390
36
        *expr = Moo::new(rewritten);
391
108
    }
392
144
}