1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, RwLock},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Literal, Metadata,
9
        Model, Name, Reference, SymbolTable, eval_constant, run_partial_evaluator,
10
        serde::{HasId as _, ObjId},
11
    },
12
    context::Context,
13
    rule_engine::{RuleSet, rewrite_morph, rewrite_naive},
14
    settings::Rewriter,
15
};
16
use uniplate::Biplate as _;
17

            
18
/// Configures a temporary model for solver-based comprehension expansion.
19
1836
pub(super) fn with_temporary_model(model: Model, search_order: Option<Vec<Name>>) -> Model {
20
1836
    let mut model = model;
21
1836
    model.context = Arc::new(RwLock::new(Context::default()));
22
1836
    model.search_order = search_order;
23
1836
    model
24
1836
}
25

            
26
/// Rewrites a model using the currently configured rewriter and Minion-oriented rule sets.
27
918
pub(super) fn rewrite_model_with_configured_rewriter<'a>(
28
918
    model: Model,
29
918
    rule_sets: &Vec<&'a RuleSet<'a>>,
30
918
    configured_rewriter: Rewriter,
31
918
) -> Model {
32
918
    match configured_rewriter {
33
        Rewriter::Morph => rewrite_morph(model, rule_sets, false),
34
918
        Rewriter::Naive => rewrite_naive(&model, rule_sets, false).unwrap(),
35
    }
36
918
}
37

            
38
/// Instantiates rewritten return expressions with quantified assignments.
39
///
40
/// This does not mutate any parent symbol table.
41
666
pub(super) fn instantiate_return_expressions_from_values(
42
666
    values: Vec<HashMap<Name, Literal>>,
43
666
    return_expression_model: &Model,
44
666
    quantified_vars: &[Name],
45
666
) -> Vec<Expression> {
46
666
    let mut return_expressions = vec![];
47

            
48
11142
    for value in values {
49
11142
        let return_expression_model = return_expression_model.clone();
50
11142
        let child_symtab = return_expression_model.symbols().clone();
51
11142
        let mut return_expression = return_expression_model.into_single_expression();
52

            
53
        // We only bind quantified variables.
54
11142
        let value: HashMap<_, _> = value
55
11142
            .into_iter()
56
29403
            .filter(|(name, _)| quantified_vars.contains(name))
57
11142
            .collect();
58

            
59
        // Bind quantified references by updating declaration targets, then simplify.
60
11142
        let _temp_value_bindings =
61
11142
            temporarily_bind_quantified_vars_to_values(&child_symtab, &return_expression, &value);
62
11142
        return_expression = concretise_resolved_reference_atoms(return_expression);
63
11142
        return_expression = simplify_expression(return_expression);
64

            
65
11142
        return_expressions.push(return_expression);
66
    }
67

            
68
666
    return_expressions
69
666
}
70

            
71
11142
/// Drops every entry of `values` whose name is not listed in `quantified_vars`.
///
/// The filtered map (the same map, modified in place) is returned.
pub(super) fn retain_quantified_solution_values(
    mut values: HashMap<Name, Literal>,
    quantified_vars: &[Name],
) -> HashMap<Name, Literal> {
    values.retain(|var_name, _| quantified_vars.iter().any(|quantified| quantified == var_name));
    values
}
78

            
79
21330
pub(super) fn simplify_expression(mut expr: Expression) -> Expression {
80
    // Keep applying evaluators to a fixed point, or until no changes are made.
81
21330
    for _ in 0..128 {
82
874791
        let next = expr.clone().transform_bi(&|subexpr: Expression| {
83
874791
            if let Some(lit) = eval_constant(&subexpr) {
84
533457
                return Expression::Atomic(Metadata::new(), Atom::Literal(lit));
85
341334
            }
86
341334
            if let Ok(reduction) = run_partial_evaluator(&subexpr) {
87
18279
                return reduction.new_expression;
88
323055
            }
89
323055
            subexpr
90
874791
        });
91

            
92
40464
        if next == expr {
93
21330
            break;
94
19134
        }
95
19134
        expr = next;
96
    }
97

            
98
21330
    expr
99
21330
}
100

            
101
11142
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
102
206172
    expr.transform_bi(&|atom: Atom| match atom {
103
177939
        Atom::Reference(reference) => reference
104
177939
            .resolve_constant()
105
177939
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
106
28233
        other => other,
107
206172
    })
108
11142
}
109

            
110
10008
pub(super) fn lift_machine_references_into_parent_scope(
111
10008
    expr: Expression,
112
10008
    child_symtab: &SymbolTable,
113
10008
    parent_symtab: &mut SymbolTable,
114
10008
) -> Expression {
115
10008
    let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
116

            
117
29034
    for (name, decl) in child_symtab.clone().into_iter_local() {
118
        // Do not add quantified declarations for quantified vars to the parent symbol table.
119
29034
        if matches!(
120
29034
            &decl.kind() as &DeclarationKind,
121
            DeclarationKind::Quantified(_)
122
        ) {
123
            continue;
124
29034
        }
125

            
126
29034
        if !matches!(&name, Name::Machine(_)) {
127
29034
            continue;
128
        }
129

            
130
        let id = decl.id();
131
        let new_decl = parent_symtab.gensym(&decl.domain().unwrap());
132
        machine_name_translations.insert(id, new_decl);
133
    }
134

            
135
11178
    expr.transform_bi(&|atom: Atom| {
136
11178
        if let Atom::Reference(ref decl) = atom
137
792
            && let id = decl.id()
138
792
            && let Some(new_decl) = machine_name_translations.get(&id)
139
        {
140
            Atom::Reference(Reference::new(new_decl.clone()))
141
        } else {
142
11178
            atom
143
        }
144
11178
    })
145
10008
}
146

            
147
/// Guard that temporarily converts quantified declarations to temporary value-lettings.
148
struct TempQuantifiedValueLettingGuard {
149
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
150
}
151

            
152
impl Drop for TempQuantifiedValueLettingGuard {
153
11142
    fn drop(&mut self) {
154
29403
        for (mut decl, kind) in self.originals.drain(..) {
155
29403
            let _ = decl.replace_kind(kind);
156
29403
        }
157
11142
    }
158
}
159

            
160
150264
fn maybe_bind_temp_value_letting(
161
150264
    originals: &mut Vec<(DeclarationPtr, DeclarationKind)>,
162
150264
    decl: &DeclarationPtr,
163
150264
    lit: &Literal,
164
150264
) {
165
150264
    if originals
166
150264
        .iter()
167
265962
        .any(|(existing, _)| existing.id() == decl.id())
168
    {
169
120861
        return;
170
29403
    }
171

            
172
29403
    let mut decl = decl.clone();
173
29403
    let old_kind = decl.kind().clone();
174
29403
    let temp_kind = DeclarationKind::TemporaryValueLetting(Expression::Atomic(
175
29403
        Metadata::new(),
176
29403
        Atom::Literal(lit.clone()),
177
29403
    ));
178
29403
    let _ = decl.replace_kind(temp_kind);
179
29403
    originals.push((decl, old_kind));
180
150264
}
181

            
182
11142
fn temporarily_bind_quantified_vars_to_values(
183
11142
    symbols: &SymbolTable,
184
11142
    expr: &Expression,
185
11142
    values: &HashMap<Name, Literal>,
186
11142
) -> TempQuantifiedValueLettingGuard {
187
11142
    let mut originals = Vec::new();
188

            
189
29403
    for (name, lit) in values {
190
29403
        let Some(decl) = symbols.lookup_local(name) else {
191
            continue;
192
        };
193

            
194
29403
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
195

            
196
29403
        let kind = decl.kind();
197
29403
        if let DeclarationKind::Quantified(inner) = &*kind
198
            && let Some(generator) = inner.generator()
199
        {
200
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
201
29403
        }
202
    }
203

            
204
    // Some expressions can still reference quantified declarations from an earlier scope
205
    // (e.g. after comprehension rewrites that rebuild generator declarations). Bind those
206
    // declaration pointers directly as well.
207
177939
    for decl in uniplate::Biplate::<DeclarationPtr>::universe_bi(expr) {
208
177939
        let name = decl.name().clone();
209
177939
        let Some(lit) = values.get(&name) else {
210
57078
            continue;
211
        };
212

            
213
120861
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
214

            
215
120861
        let kind = decl.kind();
216
120861
        if let DeclarationKind::Quantified(inner) = &*kind
217
            && let Some(generator) = inner.generator()
218
        {
219
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
220
120861
        }
221
    }
222

            
223
11142
    TempQuantifiedValueLettingGuard { originals }
224
11142
}
225

            
226
/// Guard that temporarily converts quantified declarations to find declarations.
227
pub(super) struct TempQuantifiedFindGuard {
228
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
229
}
230

            
231
impl Drop for TempQuantifiedFindGuard {
232
729
    fn drop(&mut self) {
233
765
        for (mut decl, kind) in self.originals.drain(..) {
234
765
            let _ = decl.replace_kind(kind);
235
765
        }
236
729
    }
237
}
238

            
239
/// Converts quantified declarations in `model` to temporary find declarations.
240
729
pub(super) fn temporarily_materialise_quantified_vars_as_finds(
241
729
    model: &Model,
242
729
    quantified_vars: &[Name],
243
729
) -> TempQuantifiedFindGuard {
244
729
    let symbols = model.symbols().clone();
245
729
    let mut originals = Vec::new();
246

            
247
765
    for name in quantified_vars {
248
765
        let Some(mut decl) = symbols.lookup_local(name) else {
249
            continue;
250
        };
251

            
252
765
        let old_kind = decl.kind().clone();
253
765
        let Some(domain) = decl.domain() else {
254
            continue;
255
        };
256

            
257
765
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
258
765
        let _ = decl.replace_kind(new_kind);
259
765
        originals.push((decl, old_kind));
260
    }
261

            
262
729
    TempQuantifiedFindGuard { originals }
263
729
}