1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, RwLock},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Literal, Metadata,
9
        Model, Name, Reference, SymbolTable, eval_constant, run_partial_evaluator,
10
        serde::{HasId as _, ObjId},
11
    },
12
    context::Context,
13
    rule_engine::{RuleSet, rewrite_morph, rewrite_naive},
14
    settings::Rewriter,
15
};
16
use uniplate::Biplate as _;
17

            
18
/// Configures a temporary model for solver-based comprehension expansion.
19
612
pub(super) fn with_temporary_model(model: Model, search_order: Option<Vec<Name>>) -> Model {
20
612
    let mut model = model;
21
612
    model.context = Arc::new(RwLock::new(Context::default()));
22
612
    model.search_order = search_order;
23
612
    model
24
612
}
25

            
26
/// Rewrites a model using the currently configured rewriter and Minion-oriented rule sets.
27
306
pub(super) fn rewrite_model_with_configured_rewriter<'a>(
28
306
    model: Model,
29
306
    rule_sets: &Vec<&'a RuleSet<'a>>,
30
306
    configured_rewriter: Rewriter,
31
306
) -> Model {
32
306
    match configured_rewriter {
33
        Rewriter::Morph => rewrite_morph(model, rule_sets, false),
34
306
        Rewriter::Naive => rewrite_naive(&model, rule_sets, false).unwrap(),
35
    }
36
306
}
37

            
38
/// Instantiates rewritten return expressions with quantified assignments.
39
///
40
/// This does not mutate any parent symbol table.
41
222
pub(super) fn instantiate_return_expressions_from_values(
42
222
    values: Vec<HashMap<Name, Literal>>,
43
222
    return_expression_model: &Model,
44
222
    quantified_vars: &[Name],
45
222
) -> Vec<Expression> {
46
222
    let mut return_expressions = vec![];
47

            
48
3714
    for value in values {
49
3714
        let return_expression_model = return_expression_model.clone();
50
3714
        let child_symtab = return_expression_model.symbols().clone();
51
3714
        let mut return_expression = return_expression_model.into_single_expression();
52

            
53
        // We only bind quantified variables.
54
3714
        let value: HashMap<_, _> = value
55
3714
            .into_iter()
56
9801
            .filter(|(name, _)| quantified_vars.contains(name))
57
3714
            .collect();
58

            
59
        // Bind quantified references by updating declaration targets, then simplify.
60
3714
        let _temp_value_bindings =
61
3714
            temporarily_bind_quantified_vars_to_values(&child_symtab, &return_expression, &value);
62
3714
        return_expression = concretise_resolved_reference_atoms(return_expression);
63
3714
        return_expression = simplify_expression(return_expression);
64

            
65
3714
        return_expressions.push(return_expression);
66
    }
67

            
68
222
    return_expressions
69
222
}
70

            
71
3714
/// Keeps only the solution values whose variable name appears in `quantified_vars`.
///
/// The map is filtered in place and handed back to the caller.
pub(super) fn retain_quantified_solution_values(
    values: HashMap<Name, Literal>,
    quantified_vars: &[Name],
) -> HashMap<Name, Literal> {
    let mut kept = values;
    kept.retain(|var_name, _literal| quantified_vars.contains(var_name));
    kept
}
78

            
79
7110
pub(super) fn simplify_expression(mut expr: Expression) -> Expression {
80
    // Keep applying evaluators to a fixed point, or until no changes are made.
81
7110
    for _ in 0..128 {
82
291597
        let next = expr.clone().transform_bi(&|subexpr: Expression| {
83
291597
            if let Some(lit) = eval_constant(&subexpr) {
84
177819
                return Expression::Atomic(Metadata::new(), Atom::Literal(lit));
85
113778
            }
86
113778
            if let Ok(reduction) = run_partial_evaluator(&subexpr) {
87
6093
                return reduction.new_expression;
88
107685
            }
89
107685
            subexpr
90
291597
        });
91

            
92
13488
        if next == expr {
93
7110
            break;
94
6378
        }
95
6378
        expr = next;
96
    }
97

            
98
7110
    expr
99
7110
}
100

            
101
3714
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
102
68724
    expr.transform_bi(&|atom: Atom| match atom {
103
59313
        Atom::Reference(reference) => reference
104
59313
            .resolve_constant()
105
59313
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
106
9411
        other => other,
107
68724
    })
108
3714
}
109

            
110
3336
pub(super) fn lift_machine_references_into_parent_scope(
111
3336
    expr: Expression,
112
3336
    child_symtab: &SymbolTable,
113
3336
    parent_symtab: &mut SymbolTable,
114
3336
) -> Expression {
115
3336
    let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
116

            
117
9678
    for (name, decl) in child_symtab.clone().into_iter_local() {
118
        // Do not add quantified declarations for quantified vars to the parent symbol table.
119
9678
        if matches!(
120
9678
            &decl.kind() as &DeclarationKind,
121
            DeclarationKind::Quantified(_)
122
        ) {
123
            continue;
124
9678
        }
125

            
126
9678
        if !matches!(&name, Name::Machine(_)) {
127
9678
            continue;
128
        }
129

            
130
        let id = decl.id();
131
        let new_decl = parent_symtab.gensym(&decl.domain().unwrap());
132
        machine_name_translations.insert(id, new_decl);
133
    }
134

            
135
3726
    expr.transform_bi(&|atom: Atom| {
136
3726
        if let Atom::Reference(ref decl) = atom
137
264
            && let id = decl.id()
138
264
            && let Some(new_decl) = machine_name_translations.get(&id)
139
        {
140
            Atom::Reference(Reference::new(new_decl.clone()))
141
        } else {
142
3726
            atom
143
        }
144
3726
    })
145
3336
}
146

            
147
/// Guard that temporarily converts quantified declarations to temporary value-lettings.
148
struct TempQuantifiedValueLettingGuard {
149
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
150
}
151

            
152
impl Drop for TempQuantifiedValueLettingGuard {
153
3714
    fn drop(&mut self) {
154
9801
        for (mut decl, kind) in self.originals.drain(..) {
155
9801
            let _ = decl.replace_kind(kind);
156
9801
        }
157
3714
    }
158
}
159

            
160
50088
fn maybe_bind_temp_value_letting(
161
50088
    originals: &mut Vec<(DeclarationPtr, DeclarationKind)>,
162
50088
    decl: &DeclarationPtr,
163
50088
    lit: &Literal,
164
50088
) {
165
50088
    if originals
166
50088
        .iter()
167
88524
        .any(|(existing, _)| existing.id() == decl.id())
168
    {
169
40287
        return;
170
9801
    }
171

            
172
9801
    let mut decl = decl.clone();
173
9801
    let old_kind = decl.kind().clone();
174
9801
    let temp_kind = DeclarationKind::TemporaryValueLetting(Expression::Atomic(
175
9801
        Metadata::new(),
176
9801
        Atom::Literal(lit.clone()),
177
9801
    ));
178
9801
    let _ = decl.replace_kind(temp_kind);
179
9801
    originals.push((decl, old_kind));
180
50088
}
181

            
182
3714
fn temporarily_bind_quantified_vars_to_values(
183
3714
    symbols: &SymbolTable,
184
3714
    expr: &Expression,
185
3714
    values: &HashMap<Name, Literal>,
186
3714
) -> TempQuantifiedValueLettingGuard {
187
3714
    let mut originals = Vec::new();
188

            
189
9801
    for (name, lit) in values {
190
9801
        let Some(decl) = symbols.lookup_local(name) else {
191
            continue;
192
        };
193

            
194
9801
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
195

            
196
9801
        let kind = decl.kind();
197
9801
        if let DeclarationKind::Quantified(inner) = &*kind
198
            && let Some(generator) = inner.generator()
199
        {
200
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
201
9801
        }
202
    }
203

            
204
    // Some expressions can still reference quantified declarations from an earlier scope
205
    // (e.g. after comprehension rewrites that rebuild generator declarations). Bind those
206
    // declaration pointers directly as well.
207
59313
    for decl in uniplate::Biplate::<DeclarationPtr>::universe_bi(expr) {
208
59313
        let name = decl.name().clone();
209
59313
        let Some(lit) = values.get(&name) else {
210
19026
            continue;
211
        };
212

            
213
40287
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
214

            
215
40287
        let kind = decl.kind();
216
40287
        if let DeclarationKind::Quantified(inner) = &*kind
217
            && let Some(generator) = inner.generator()
218
        {
219
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
220
40287
        }
221
    }
222

            
223
3714
    TempQuantifiedValueLettingGuard { originals }
224
3714
}
225

            
226
/// Guard that temporarily converts quantified declarations to find declarations.
227
pub(super) struct TempQuantifiedFindGuard {
228
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
229
}
230

            
231
impl Drop for TempQuantifiedFindGuard {
232
243
    fn drop(&mut self) {
233
255
        for (mut decl, kind) in self.originals.drain(..) {
234
255
            let _ = decl.replace_kind(kind);
235
255
        }
236
243
    }
237
}
238

            
239
/// Converts quantified declarations in `model` to temporary find declarations.
240
243
pub(super) fn temporarily_materialise_quantified_vars_as_finds(
241
243
    model: &Model,
242
243
    quantified_vars: &[Name],
243
243
) -> TempQuantifiedFindGuard {
244
243
    let symbols = model.symbols().clone();
245
243
    let mut originals = Vec::new();
246

            
247
255
    for name in quantified_vars {
248
255
        let Some(mut decl) = symbols.lookup_local(name) else {
249
            continue;
250
        };
251

            
252
255
        let old_kind = decl.kind().clone();
253
255
        let Some(domain) = decl.domain() else {
254
            continue;
255
        };
256

            
257
255
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
258
255
        let _ = decl.replace_kind(new_kind);
259
255
        originals.push((decl, old_kind));
260
    }
261

            
262
243
    TempQuantifiedFindGuard { originals }
263
243
}