1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, RwLock},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Literal, Metadata,
9
        Model, Name, Reference, SymbolTable, eval_constant, run_partial_evaluator,
10
        serde::{HasId as _, ObjId},
11
    },
12
    context::Context,
13
    rule_engine::{
14
        RuleSet,
15
        rewrite_model_with_configured_rewriter as rewrite_model_with_configured_rewriter_core,
16
    },
17
    settings::Rewriter,
18
};
19
use uniplate::Biplate as _;
20

            
21
/// Configures a temporary model for solver-based comprehension expansion.
22
1578
pub(super) fn with_temporary_model(model: Model, search_order: Option<Vec<Name>>) -> Model {
23
1578
    let mut model = model;
24
1578
    model.context = Arc::new(RwLock::new(Context::default()));
25
1578
    model.search_order = search_order;
26
1578
    model
27
1578
}
28

            
29
/// Rewrites a model using the currently configured rewriter and Minion-oriented rule sets.
30
789
pub(super) fn rewrite_model_with_configured_rewriter<'a>(
31
789
    model: Model,
32
789
    rule_sets: &Vec<&'a RuleSet<'a>>,
33
789
    configured_rewriter: Rewriter,
34
789
) -> Model {
35
789
    rewrite_model_with_configured_rewriter_core(model, rule_sets, configured_rewriter).unwrap()
36
789
}
37

            
38
/// Instantiates rewritten return expressions with quantified assignments.
39
///
40
/// This does not mutate any parent symbol table.
41
633
pub(super) fn instantiate_return_expressions_from_values(
42
633
    values: Vec<HashMap<Name, Literal>>,
43
633
    return_expression_model: &Model,
44
633
    quantified_vars: &[Name],
45
633
) -> Vec<Expression> {
46
633
    let mut return_expressions = vec![];
47

            
48
14153
    for value in values {
49
14153
        let return_expression_model = return_expression_model.clone();
50
14153
        let child_symtab = return_expression_model.symbols().clone();
51
14153
        let mut return_expression = return_expression_model.into_single_expression();
52

            
53
        // We only bind quantified variables.
54
14153
        let value: HashMap<_, _> = value
55
14153
            .into_iter()
56
38403
            .filter(|(name, _)| quantified_vars.contains(name))
57
14153
            .collect();
58

            
59
        // Bind quantified references by updating declaration targets, then simplify.
60
14153
        let _temp_value_bindings =
61
14153
            temporarily_bind_quantified_vars_to_values(&child_symtab, &return_expression, &value);
62
14153
        return_expression = concretise_resolved_reference_atoms(return_expression);
63
14153
        return_expression = simplify_expression(return_expression);
64

            
65
14153
        return_expressions.push(return_expression);
66
    }
67

            
68
633
    return_expressions
69
633
}
70

            
71
14153
/// Filters a solution map down to the entries named in `quantified_vars`.
///
/// The map is filtered in place and handed back to the caller.
pub(super) fn retain_quantified_solution_values(
    values: HashMap<Name, Literal>,
    quantified_vars: &[Name],
) -> HashMap<Name, Literal> {
    let mut kept = values;
    kept.retain(|var_name, _literal| quantified_vars.contains(var_name));
    kept
}
78

            
79
28497
pub(super) fn simplify_expression(mut expr: Expression) -> Expression {
80
    // Keep applying evaluators to a fixed point, or until no changes are made.
81
28497
    for _ in 0..128 {
82
1141649
        let next = expr.clone().transform_bi(&|subexpr: Expression| {
83
1141649
            if let Some(lit) = eval_constant(&subexpr) {
84
661753
                return Expression::Atomic(Metadata::new(), Atom::Literal(lit));
85
479896
            }
86
479896
            if let Ok(reduction) = run_partial_evaluator(&subexpr) {
87
29056
                return reduction.new_expression;
88
450840
            }
89
450840
            subexpr
90
1141649
        });
91

            
92
54232
        if next == expr {
93
28497
            break;
94
25735
        }
95
25735
        expr = next;
96
    }
97

            
98
28497
    expr
99
28497
}
100

            
101
14153
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
102
274247
    expr.transform_bi(&|atom: Atom| match atom {
103
236456
        Atom::Reference(reference) => reference
104
236456
            .resolve_constant()
105
236456
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
106
37791
        other => other,
107
274247
    })
108
14153
}
109

            
110
13698
pub(super) fn lift_machine_references_into_parent_scope(
111
13698
    expr: Expression,
112
13698
    child_symtab: &SymbolTable,
113
13698
    parent_symtab: &mut SymbolTable,
114
13698
) -> Expression {
115
13698
    let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
116

            
117
38676
    for (name, decl) in child_symtab.clone().into_iter_local() {
118
        // Do not add quantified declarations for quantified vars to the parent symbol table.
119
38676
        if matches!(
120
38676
            &decl.kind() as &DeclarationKind,
121
            DeclarationKind::Quantified(_)
122
        ) {
123
            continue;
124
38676
        }
125

            
126
38676
        if !matches!(&name, Name::Machine(_)) {
127
38676
            continue;
128
        }
129

            
130
        let id = decl.id();
131
        let new_decl = parent_symtab.gen_find(&decl.domain().unwrap());
132
        machine_name_translations.insert(id, new_decl);
133
    }
134

            
135
19920
    expr.transform_bi(&|atom: Atom| {
136
19920
        if let Atom::Reference(ref decl) = atom
137
2773
            && let id = decl.id()
138
2773
            && let Some(new_decl) = machine_name_translations.get(&id)
139
        {
140
            Atom::Reference(Reference::new(new_decl.clone()))
141
        } else {
142
19920
            atom
143
        }
144
19920
    })
145
13698
}
146

            
147
/// Guard that temporarily converts quantified declarations to temporary value-lettings.
148
struct TempQuantifiedValueLettingGuard {
149
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
150
}
151

            
152
impl Drop for TempQuantifiedValueLettingGuard {
153
14153
    fn drop(&mut self) {
154
38403
        for (mut decl, kind) in self.originals.drain(..) {
155
38403
            let _ = decl.replace_kind(kind);
156
38403
        }
157
14153
    }
158
}
159

            
160
198968
fn maybe_bind_temp_value_letting(
161
198968
    originals: &mut Vec<(DeclarationPtr, DeclarationKind)>,
162
198968
    decl: &DeclarationPtr,
163
198968
    lit: &Literal,
164
198968
) {
165
198968
    if originals
166
198968
        .iter()
167
353732
        .any(|(existing, _)| existing.id() == decl.id())
168
    {
169
160565
        return;
170
38403
    }
171

            
172
38403
    let mut decl = decl.clone();
173
38403
    let old_kind = decl.kind().clone();
174
38403
    let temp_kind = DeclarationKind::TemporaryValueLetting(Expression::Atomic(
175
38403
        Metadata::new(),
176
38403
        Atom::Literal(lit.clone()),
177
38403
    ));
178
38403
    let _ = decl.replace_kind(temp_kind);
179
38403
    originals.push((decl, old_kind));
180
198968
}
181

            
182
14153
fn temporarily_bind_quantified_vars_to_values(
183
14153
    symbols: &SymbolTable,
184
14153
    expr: &Expression,
185
14153
    values: &HashMap<Name, Literal>,
186
14153
) -> TempQuantifiedValueLettingGuard {
187
14153
    let mut originals = Vec::new();
188

            
189
38403
    for (name, lit) in values {
190
38403
        let Some(decl) = symbols.lookup_local(name) else {
191
            continue;
192
        };
193

            
194
38403
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
195

            
196
38403
        let kind = decl.kind();
197
38403
        if let DeclarationKind::Quantified(inner) = &*kind
198
            && let Some(generator) = inner.generator()
199
        {
200
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
201
38403
        }
202
    }
203

            
204
    // Some expressions can still reference quantified declarations from an earlier scope
205
    // (e.g. after comprehension rewrites that rebuild generator declarations). Bind those
206
    // declaration pointers directly as well.
207
236456
    for decl in uniplate::Biplate::<DeclarationPtr>::universe_bi(expr) {
208
236456
        let name = decl.name().clone();
209
236456
        let Some(lit) = values.get(&name) else {
210
75891
            continue;
211
        };
212

            
213
160565
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
214

            
215
160565
        let kind = decl.kind();
216
160565
        if let DeclarationKind::Quantified(inner) = &*kind
217
            && let Some(generator) = inner.generator()
218
        {
219
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
220
160565
        }
221
    }
222

            
223
14153
    TempQuantifiedValueLettingGuard { originals }
224
14153
}
225

            
226
/// Guard that temporarily converts quantified declarations to find declarations.
227
pub(super) struct TempQuantifiedFindGuard {
228
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
229
}
230

            
231
impl Drop for TempQuantifiedFindGuard {
232
603
    fn drop(&mut self) {
233
651
        for (mut decl, kind) in self.originals.drain(..) {
234
651
            let _ = decl.replace_kind(kind);
235
651
        }
236
603
    }
237
}
238

            
239
/// Converts quantified declarations in `model` to temporary find declarations.
240
603
pub(super) fn temporarily_materialise_quantified_vars_as_finds(
241
603
    model: &Model,
242
603
    quantified_vars: &[Name],
243
603
) -> TempQuantifiedFindGuard {
244
603
    let symbols = model.symbols().clone();
245
603
    let mut originals = Vec::new();
246

            
247
651
    for name in quantified_vars {
248
651
        let Some(mut decl) = symbols.lookup_local(name) else {
249
            continue;
250
        };
251

            
252
651
        let old_kind = decl.kind().clone();
253
651
        let Some(domain) = decl.domain() else {
254
            continue;
255
        };
256

            
257
651
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
258
651
        let _ = decl.replace_kind(new_kind);
259
651
        originals.push((decl, old_kind));
260
    }
261

            
262
603
    TempQuantifiedFindGuard { originals }
263
603
}