1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, RwLock},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Literal, Metadata,
9
        Model, Name, Reference, SymbolTable, eval_constant, run_partial_evaluator,
10
        serde::{HasId as _, ObjId},
11
    },
12
    context::Context,
13
    rule_engine::{
14
        RuleSet,
15
        rewrite_model_with_configured_rewriter as rewrite_model_with_configured_rewriter_core,
16
    },
17
    settings::Rewriter,
18
};
19
use uniplate::{Biplate as _, Uniplate as _};
20

            
21
/// Configures a temporary model for solver-based comprehension expansion.
22
1578
pub(super) fn with_temporary_model(model: Model, search_order: Option<Vec<Name>>) -> Model {
23
1578
    let mut model = model;
24
1578
    model.context = Arc::new(RwLock::new(Context::default()));
25
1578
    model.search_order = search_order;
26
1578
    model
27
1578
}
28

            
29
/// Rewrites a model using the currently configured rewriter and Minion-oriented rule sets.
30
789
pub(super) fn rewrite_model_with_configured_rewriter<'a>(
31
789
    model: Model,
32
789
    rule_sets: &Vec<&'a RuleSet<'a>>,
33
789
    configured_rewriter: Rewriter,
34
789
) -> Model {
35
789
    rewrite_model_with_configured_rewriter_core(model, rule_sets, configured_rewriter).unwrap()
36
789
}
37

            
38
/// Instantiates rewritten return expressions with quantified assignments.
39
///
40
/// This does not mutate any parent symbol table.
41
633
pub(super) fn instantiate_return_expressions_from_values(
42
633
    values: Vec<HashMap<Name, Literal>>,
43
633
    return_expression_model: &Model,
44
633
    quantified_vars: &[Name],
45
633
) -> Vec<Expression> {
46
633
    let mut return_expressions = vec![];
47

            
48
14153
    for value in values {
49
14153
        let return_expression_model = return_expression_model.clone();
50
14153
        let child_symtab = return_expression_model.symbols().clone();
51
14153
        let mut return_expression = return_expression_model.into_single_expression();
52

            
53
        // We only bind quantified variables.
54
14153
        let value: HashMap<_, _> = value
55
14153
            .into_iter()
56
38403
            .filter(|(name, _)| quantified_vars.contains(name))
57
14153
            .collect();
58

            
59
        // Bind quantified references by updating declaration targets, then simplify.
60
14153
        let _temp_value_bindings =
61
14153
            temporarily_bind_quantified_vars_to_values(&child_symtab, &return_expression, &value);
62
14153
        return_expression = concretise_resolved_reference_atoms(return_expression);
63
14153
        let Some(mut return_expression) = strip_guarded_safe_index_conditions(return_expression)
64
        else {
65
            continue;
66
        };
67
14153
        return_expression = simplify_expression(return_expression);
68

            
69
14153
        return_expressions.push(return_expression);
70
    }
71

            
72
633
    return_expressions
73
633
}
74

            
75
14153
/// Returns `values` with every entry whose name is not in `quantified_vars` removed.
pub(super) fn retain_quantified_solution_values(
    values: HashMap<Name, Literal>,
    quantified_vars: &[Name],
) -> HashMap<Name, Literal> {
    let mut kept = values;
    kept.retain(|var_name, _| quantified_vars.contains(var_name));
    kept
}
82

            
83
28545
pub(super) fn simplify_expression(mut expr: Expression) -> Expression {
84
    // Keep applying evaluators to a fixed point, or until no changes are made.
85
28545
    for _ in 0..128 {
86
1141249
        let next = expr.clone().transform_bi(&|subexpr: Expression| {
87
1141249
            if let Some(lit) = eval_constant(&subexpr) {
88
661673
                return Expression::Atomic(Metadata::new(), Atom::Literal(lit));
89
479576
            }
90
479576
            if let Ok(reduction) = run_partial_evaluator(&subexpr) {
91
28976
                return reduction.new_expression;
92
450600
            }
93
450600
            subexpr
94
1141249
        });
95

            
96
54328
        if next == expr {
97
28545
            break;
98
25783
        }
99
25783
        expr = next;
100
    }
101

            
102
28545
    expr
103
28545
}
104

            
105
/// Strips internal `InDomain` guards that were introduced by bubbling a boolean `SafeIndex`
106
/// inside a comprehension return expression.
107
///
108
/// When a source comprehension already has a guard that filters out dummy/out-of-domain values,
109
/// earlier rewrites can turn that filter into a conjunction like
110
/// `and([SafeIndex(...), __inDomain(index, domain)])`. If we instantiate that directly, a
111
/// false `__inDomain` becomes a literal `false` element, which changes the comprehension from
112
/// "skip this element" to "include false".
113
///
114
/// We recover the original filtering behaviour only for this narrow internal pattern:
115
/// a top-level conjunction with exactly one non-guard term and one or more `InDomain` guards
116
/// that constrain indices used by that term. If any such guard is false after instantiation,
117
/// the element is skipped entirely.
118
27851
pub(super) fn strip_guarded_safe_index_conditions(expr: Expression) -> Option<Expression> {
119
27851
    let mut conjuncts = Vec::new();
120
27851
    collect_top_level_and_terms(expr.clone(), &mut conjuncts);
121

            
122
27851
    if conjuncts.len() == 1 && conjuncts[0] == expr {
123
27552
        return Some(expr);
124
299
    }
125

            
126
299
    let (guards, mut non_guards): (Vec<_>, Vec<_>) =
127
299
        conjuncts.into_iter().partition(is_indomain_guard);
128

            
129
299
    if guards.is_empty() || non_guards.len() != 1 {
130
219
        return Some(expr);
131
80
    }
132

            
133
80
    let guarded_term = non_guards.pop().expect("length checked above");
134

            
135
80
    if !guards
136
80
        .iter()
137
80
        .all(|guard| guard_targets_safe_index_index(guard, &guarded_term))
138
    {
139
        return Some(expr);
140
80
    }
141

            
142
80
    for guard in &guards {
143
80
        let simplified_guard = simplify_expression(guard.clone());
144
80
        match eval_constant(&simplified_guard) {
145
48
            Some(Literal::Bool(true)) => {}
146
32
            Some(Literal::Bool(false)) => return None,
147
            _ => return Some(expr),
148
        }
149
    }
150

            
151
48
    Some(guarded_term)
152
27851
}
153

            
154
29961
fn collect_top_level_and_terms(expr: Expression, out: &mut Vec<Expression>) {
155
29961
    if let Expression::And(_, ref children) = expr
156
619
        && let Some(children) = children.as_ref().clone().unwrap_list()
157
    {
158
2110
        for child in children {
159
2110
            collect_top_level_and_terms(child, out);
160
2110
        }
161
29662
    } else {
162
29662
        out.push(expr);
163
29662
    }
164
29961
}
165

            
166
2110
fn is_indomain_guard(expr: &Expression) -> bool {
167
2110
    matches!(expr, Expression::InDomain(_, _, _))
168
2110
}
169

            
170
80
fn guard_targets_safe_index_index(guard: &Expression, expr: &Expression) -> bool {
171
80
    let Expression::InDomain(_, guarded_index, _) = guard else {
172
        return false;
173
    };
174

            
175
80
    expr.universe().into_iter().any(|subexpr| {
176
80
        let Expression::SafeIndex(_, _, indices) = subexpr else {
177
            return false;
178
        };
179

            
180
160
        indices.iter().any(|index| index == guarded_index.as_ref())
181
80
    })
182
80
}
183

            
184
14153
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
185
274247
    expr.transform_bi(&|atom: Atom| match atom {
186
236456
        Atom::Reference(reference) => reference
187
236456
            .resolve_constant()
188
236456
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
189
37791
        other => other,
190
274247
    })
191
14153
}
192

            
193
13666
pub(super) fn lift_machine_references_into_parent_scope(
194
13666
    expr: Expression,
195
13666
    child_symtab: &SymbolTable,
196
13666
    parent_symtab: &mut SymbolTable,
197
13666
) -> Expression {
198
13666
    let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
199

            
200
38644
    for (name, decl) in child_symtab.clone().into_iter_local() {
201
        // Do not add quantified declarations for quantified vars to the parent symbol table.
202
38644
        if matches!(
203
38644
            &decl.kind() as &DeclarationKind,
204
            DeclarationKind::Quantified(_)
205
        ) {
206
            continue;
207
38644
        }
208

            
209
38644
        if !matches!(&name, Name::Machine(_)) {
210
38644
            continue;
211
        }
212

            
213
        let id = decl.id();
214
        let new_decl = parent_symtab.gen_find(&decl.domain().unwrap());
215
        machine_name_translations.insert(id, new_decl);
216
    }
217

            
218
19888
    expr.transform_bi(&|atom: Atom| {
219
19888
        if let Atom::Reference(ref decl) = atom
220
2773
            && let id = decl.id()
221
2773
            && let Some(new_decl) = machine_name_translations.get(&id)
222
        {
223
            Atom::Reference(Reference::new(new_decl.clone()))
224
        } else {
225
19888
            atom
226
        }
227
19888
    })
228
13666
}
229

            
230
/// Guard that temporarily converts quantified declarations to temporary value-lettings.
231
struct TempQuantifiedValueLettingGuard {
232
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
233
}
234

            
235
impl Drop for TempQuantifiedValueLettingGuard {
236
14153
    fn drop(&mut self) {
237
38403
        for (mut decl, kind) in self.originals.drain(..) {
238
38403
            let _ = decl.replace_kind(kind);
239
38403
        }
240
14153
    }
241
}
242

            
243
198968
fn maybe_bind_temp_value_letting(
244
198968
    originals: &mut Vec<(DeclarationPtr, DeclarationKind)>,
245
198968
    decl: &DeclarationPtr,
246
198968
    lit: &Literal,
247
198968
) {
248
198968
    if originals
249
198968
        .iter()
250
353557
        .any(|(existing, _)| existing.id() == decl.id())
251
    {
252
160565
        return;
253
38403
    }
254

            
255
38403
    let mut decl = decl.clone();
256
38403
    let old_kind = decl.kind().clone();
257
38403
    let temp_kind = DeclarationKind::TemporaryValueLetting(Expression::Atomic(
258
38403
        Metadata::new(),
259
38403
        Atom::Literal(lit.clone()),
260
38403
    ));
261
38403
    let _ = decl.replace_kind(temp_kind);
262
38403
    originals.push((decl, old_kind));
263
198968
}
264

            
265
14153
fn temporarily_bind_quantified_vars_to_values(
266
14153
    symbols: &SymbolTable,
267
14153
    expr: &Expression,
268
14153
    values: &HashMap<Name, Literal>,
269
14153
) -> TempQuantifiedValueLettingGuard {
270
14153
    let mut originals = Vec::new();
271

            
272
38403
    for (name, lit) in values {
273
38403
        let Some(decl) = symbols.lookup_local(name) else {
274
            continue;
275
        };
276

            
277
38403
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
278

            
279
38403
        let kind = decl.kind();
280
38403
        if let DeclarationKind::Quantified(inner) = &*kind
281
            && let Some(generator) = inner.generator()
282
        {
283
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
284
38403
        }
285
    }
286

            
287
    // Some expressions can still reference quantified declarations from an earlier scope
288
    // (e.g. after comprehension rewrites that rebuild generator declarations). Bind those
289
    // declaration pointers directly as well.
290
236456
    for decl in uniplate::Biplate::<DeclarationPtr>::universe_bi(expr) {
291
236456
        let name = decl.name().clone();
292
236456
        let Some(lit) = values.get(&name) else {
293
75891
            continue;
294
        };
295

            
296
160565
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
297

            
298
160565
        let kind = decl.kind();
299
160565
        if let DeclarationKind::Quantified(inner) = &*kind
300
            && let Some(generator) = inner.generator()
301
        {
302
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
303
160565
        }
304
    }
305

            
306
14153
    TempQuantifiedValueLettingGuard { originals }
307
14153
}
308

            
309
/// Guard that temporarily converts quantified declarations to find declarations.
310
pub(super) struct TempQuantifiedFindGuard {
311
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
312
}
313

            
314
impl Drop for TempQuantifiedFindGuard {
315
603
    fn drop(&mut self) {
316
651
        for (mut decl, kind) in self.originals.drain(..) {
317
651
            let _ = decl.replace_kind(kind);
318
651
        }
319
603
    }
320
}
321

            
322
/// Converts quantified declarations in `model` to temporary find declarations.
323
603
pub(super) fn temporarily_materialise_quantified_vars_as_finds(
324
603
    model: &Model,
325
603
    quantified_vars: &[Name],
326
603
) -> TempQuantifiedFindGuard {
327
603
    let symbols = model.symbols().clone();
328
603
    let mut originals = Vec::new();
329

            
330
651
    for name in quantified_vars {
331
651
        let Some(mut decl) = symbols.lookup_local(name) else {
332
            continue;
333
        };
334

            
335
651
        let old_kind = decl.kind().clone();
336
651
        let Some(domain) = decl.domain() else {
337
            continue;
338
        };
339

            
340
651
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
341
651
        let _ = decl.replace_kind(new_kind);
342
651
        originals.push((decl, old_kind));
343
    }
344

            
345
603
    TempQuantifiedFindGuard { originals }
346
603
}