1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, RwLock},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Literal, Metadata,
9
        Model, Name, Reference, SymbolTable, eval_constant, run_partial_evaluator,
10
        serde::{HasId as _, ObjId},
11
    },
12
    context::Context,
13
    rule_engine::{
14
        RuleSet,
15
        rewrite_model_with_configured_rewriter as rewrite_model_with_configured_rewriter_core,
16
    },
17
    settings::Rewriter,
18
};
19
1232
use uniplate::Biplate as _;
20
1232

            
21
/// Configures a temporary model for solver-based comprehension expansion.
22
1852
pub(super) fn with_temporary_model(model: Model, search_order: Option<Vec<Name>>) -> Model {
23
1852
    let mut model = model;
24
1852
    model.context = Arc::new(RwLock::new(Context::default()));
25
620
    model.search_order = search_order;
26
620
    model
27
1236
}
28
616

            
29
/// Rewrites a model using the currently configured rewriter and Minion-oriented rule sets.
30
926
pub(super) fn rewrite_model_with_configured_rewriter<'a>(
31
926
    model: Model,
32
926
    rule_sets: &Vec<&'a RuleSet<'a>>,
33
310
    configured_rewriter: Rewriter,
34
926
) -> Model {
35
310
    rewrite_model_with_configured_rewriter_core(model, rule_sets, configured_rewriter).unwrap()
36
926
}
37

            
38
/// Instantiates rewritten return expressions with quantified assignments.
39
///
40
/// This does not mutate any parent symbol table.
41
674
pub(super) fn instantiate_return_expressions_from_values(
42
674
    values: Vec<HashMap<Name, Literal>>,
43
674
    return_expression_model: &Model,
44
674
    quantified_vars: &[Name],
45
674
) -> Vec<Expression> {
46
674
    let mut return_expressions = vec![];
47

            
48
11154
    for value in values {
49
11154
        let return_expression_model = return_expression_model.clone();
50
11154
        let child_symtab = return_expression_model.symbols().clone();
51
11154
        let mut return_expression = return_expression_model.into_single_expression();
52

            
53
        // We only bind quantified variables.
54
11154
        let value: HashMap<_, _> = value
55
11154
            .into_iter()
56
29439
            .filter(|(name, _)| quantified_vars.contains(name))
57
11154
            .collect();
58

            
59
        // Bind quantified references by updating declaration targets, then simplify.
60
11154
        let _temp_value_bindings =
61
11154
            temporarily_bind_quantified_vars_to_values(&child_symtab, &return_expression, &value);
62
11154
        return_expression = concretise_resolved_reference_atoms(return_expression);
63
11154
        return_expression = simplify_expression(return_expression);
64

            
65
11154
        return_expressions.push(return_expression);
66
    }
67

            
68
674
    return_expressions
69
674
}
70

            
71
11154
/// Filters a solution down to its quantified variables.
///
/// Consumes `values` and returns the same map with every entry whose name is not listed in
/// `quantified_vars` removed.
pub(super) fn retain_quantified_solution_values(
    values: HashMap<Name, Literal>,
    quantified_vars: &[Name],
) -> HashMap<Name, Literal> {
    let mut kept = values;
    kept.retain(|name, _| quantified_vars.iter().any(|quantified| quantified == name));
    kept
}
78

            
79
21342
pub(super) fn simplify_expression(mut expr: Expression) -> Expression {
80
    // Keep applying evaluators to a fixed point, or until no changes are made.
81
21342
    for _ in 0..128 {
82
875619
        let next = expr.clone().transform_bi(&|subexpr: Expression| {
83
875619
            if let Some(lit) = eval_constant(&subexpr) {
84
533865
                return Expression::Atomic(Metadata::new(), Atom::Literal(lit));
85
341754
            }
86
341754
            if let Ok(reduction) = run_partial_evaluator(&subexpr) {
87
18291
                return reduction.new_expression;
88
323463
            }
89
323463
            subexpr
90
875619
        });
91

            
92
40488
        if next == expr {
93
21342
            break;
94
19146
        }
95
19146
        expr = next;
96
    }
97

            
98
21342
    expr
99
21342
}
100

            
101
11154
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
102
206436
    expr.transform_bi(&|atom: Atom| match atom {
103
178167
        Atom::Reference(reference) => reference
104
178167
            .resolve_constant()
105
178167
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
106
28269
        other => other,
107
206436
    })
108
11154
}
109

            
110
10008
pub(super) fn lift_machine_references_into_parent_scope(
111
10008
    expr: Expression,
112
10008
    child_symtab: &SymbolTable,
113
10008
    parent_symtab: &mut SymbolTable,
114
10008
) -> Expression {
115
10008
    let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
116

            
117
29034
    for (name, decl) in child_symtab.clone().into_iter_local() {
118
        // Do not add quantified declarations for quantified vars to the parent symbol table.
119
29034
        if matches!(
120
29034
            &decl.kind() as &DeclarationKind,
121
            DeclarationKind::Quantified(_)
122
        ) {
123
            continue;
124
29034
        }
125

            
126
29034
        if !matches!(&name, Name::Machine(_)) {
127
29034
            continue;
128
        }
129

            
130
        let id = decl.id();
131
        let new_decl = parent_symtab.gensym(&decl.domain().unwrap());
132
        machine_name_translations.insert(id, new_decl);
133
    }
134

            
135
11178
    expr.transform_bi(&|atom: Atom| {
136
11178
        if let Atom::Reference(ref decl) = atom
137
792
            && let id = decl.id()
138
792
            && let Some(new_decl) = machine_name_translations.get(&id)
139
        {
140
            Atom::Reference(Reference::new(new_decl.clone()))
141
        } else {
142
11178
            atom
143
        }
144
11178
    })
145
10008
}
146

            
147
/// Guard that temporarily converts quantified declarations to temporary value-lettings.
148
struct TempQuantifiedValueLettingGuard {
149
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
150
}
151

            
152
impl Drop for TempQuantifiedValueLettingGuard {
153
11154
    fn drop(&mut self) {
154
29439
        for (mut decl, kind) in self.originals.drain(..) {
155
29439
            let _ = decl.replace_kind(kind);
156
29439
        }
157
11154
    }
158
}
159

            
160
150456
fn maybe_bind_temp_value_letting(
161
150456
    originals: &mut Vec<(DeclarationPtr, DeclarationKind)>,
162
150456
    decl: &DeclarationPtr,
163
150456
    lit: &Literal,
164
150456
) {
165
150456
    if originals
166
150456
        .iter()
167
266188
        .any(|(existing, _)| existing.id() == decl.id())
168
    {
169
121017
        return;
170
29439
    }
171

            
172
29439
    let mut decl = decl.clone();
173
29439
    let old_kind = decl.kind().clone();
174
29439
    let temp_kind = DeclarationKind::TemporaryValueLetting(Expression::Atomic(
175
29439
        Metadata::new(),
176
29439
        Atom::Literal(lit.clone()),
177
29439
    ));
178
29439
    let _ = decl.replace_kind(temp_kind);
179
29439
    originals.push((decl, old_kind));
180
150456
}
181

            
182
11154
fn temporarily_bind_quantified_vars_to_values(
183
11154
    symbols: &SymbolTable,
184
11154
    expr: &Expression,
185
11154
    values: &HashMap<Name, Literal>,
186
11154
) -> TempQuantifiedValueLettingGuard {
187
11154
    let mut originals = Vec::new();
188

            
189
29439
    for (name, lit) in values {
190
29439
        let Some(decl) = symbols.lookup_local(name) else {
191
            continue;
192
        };
193

            
194
29439
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
195

            
196
29439
        let kind = decl.kind();
197
29439
        if let DeclarationKind::Quantified(inner) = &*kind
198
            && let Some(generator) = inner.generator()
199
        {
200
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
201
29439
        }
202
    }
203

            
204
    // Some expressions can still reference quantified declarations from an earlier scope
205
    // (e.g. after comprehension rewrites that rebuild generator declarations). Bind those
206
    // declaration pointers directly as well.
207
178167
    for decl in uniplate::Biplate::<DeclarationPtr>::universe_bi(expr) {
208
178167
        let name = decl.name().clone();
209
178167
        let Some(lit) = values.get(&name) else {
210
57150
            continue;
211
        };
212

            
213
121017
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
214

            
215
121017
        let kind = decl.kind();
216
121017
        if let DeclarationKind::Quantified(inner) = &*kind
217
            && let Some(generator) = inner.generator()
218
        {
219
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
220
121017
        }
221
    }
222

            
223
11154
    TempQuantifiedValueLettingGuard { originals }
224
11154
}
225

            
226
/// Guard that temporarily converts quantified declarations to find declarations.
227
pub(super) struct TempQuantifiedFindGuard {
228
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
229
}
230

            
231
impl Drop for TempQuantifiedFindGuard {
232
737
    fn drop(&mut self) {
233
789
        for (mut decl, kind) in self.originals.drain(..) {
234
789
            let _ = decl.replace_kind(kind);
235
789
        }
236
737
    }
237
}
238

            
239
/// Converts quantified declarations in `model` to temporary find declarations.
240
737
pub(super) fn temporarily_materialise_quantified_vars_as_finds(
241
737
    model: &Model,
242
737
    quantified_vars: &[Name],
243
737
) -> TempQuantifiedFindGuard {
244
737
    let symbols = model.symbols().clone();
245
737
    let mut originals = Vec::new();
246

            
247
789
    for name in quantified_vars {
248
789
        let Some(mut decl) = symbols.lookup_local(name) else {
249
            continue;
250
        };
251

            
252
789
        let old_kind = decl.kind().clone();
253
789
        let Some(domain) = decl.domain() else {
254
            continue;
255
        };
256

            
257
789
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
258
789
        let _ = decl.replace_kind(new_kind);
259
789
        originals.push((decl, old_kind));
260
    }
261

            
262
737
    TempQuantifiedFindGuard { originals }
263
737
}