1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, RwLock},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Literal, Metadata,
9
        Model, Name, Reference, SymbolTable, eval_constant, run_partial_evaluator,
10
        serde::{HasId as _, ObjId},
11
    },
12
    context::Context,
13
    rule_engine::{RuleSet, rewrite_morph, rewrite_naive},
14
    settings::Rewriter,
15
};
16
use uniplate::Biplate as _;
17

            
18
/// Configures a temporary model for solver-based comprehension expansion.
19
1240
pub(super) fn with_temporary_model(model: Model, search_order: Option<Vec<Name>>) -> Model {
20
1240
    let mut model = model;
21
1240
    model.context = Arc::new(RwLock::new(Context::default()));
22
1240
    model.search_order = search_order;
23
1240
    model
24
1240
}
25

            
26
/// Rewrites a model using the currently configured rewriter and Minion-oriented rule sets.
27
620
pub(super) fn rewrite_model_with_configured_rewriter<'a>(
28
620
    model: Model,
29
620
    rule_sets: &Vec<&'a RuleSet<'a>>,
30
620
    configured_rewriter: Rewriter,
31
620
) -> Model {
32
620
    match configured_rewriter {
33
        Rewriter::Morph => rewrite_morph(model, rule_sets, false),
34
620
        Rewriter::Naive => rewrite_naive(&model, rule_sets, false).unwrap(),
35
    }
36
620
}
37

            
38
/// Instantiates rewritten return expressions with quantified assignments.
39
///
40
/// This does not mutate any parent symbol table.
41
452
pub(super) fn instantiate_return_expressions_from_values(
42
452
    values: Vec<HashMap<Name, Literal>>,
43
452
    return_expression_model: &Model,
44
452
    quantified_vars: &[Name],
45
452
) -> Vec<Expression> {
46
452
    let mut return_expressions = vec![];
47

            
48
7440
    for value in values {
49
7440
        let return_expression_model = return_expression_model.clone();
50
7440
        let child_symtab = return_expression_model.symbols().clone();
51
7440
        let mut return_expression = return_expression_model.into_single_expression();
52

            
53
        // We only bind quantified variables.
54
7440
        let value: HashMap<_, _> = value
55
7440
            .into_iter()
56
19638
            .filter(|(name, _)| quantified_vars.contains(name))
57
7440
            .collect();
58

            
59
        // Bind quantified references by updating declaration targets, then simplify.
60
7440
        let _temp_value_bindings =
61
7440
            temporarily_bind_quantified_vars_to_values(&child_symtab, &return_expression, &value);
62
7440
        return_expression = concretise_resolved_reference_atoms(return_expression);
63
7440
        return_expression = simplify_expression(return_expression);
64

            
65
7440
        return_expressions.push(return_expression);
66
    }
67

            
68
452
    return_expressions
69
452
}
70

            
71
7440
/// Drops every entry of `values` whose name is not listed in `quantified_vars`.
///
/// The map is filtered in place and then handed back to the caller.
pub(super) fn retain_quantified_solution_values(
    mut values: HashMap<Name, Literal>,
    quantified_vars: &[Name],
) -> HashMap<Name, Literal> {
    values.retain(|candidate, _| {
        quantified_vars
            .iter()
            .any(|quantified| quantified == candidate)
    });
    values
}
78

            
79
14232
pub(super) fn simplify_expression(mut expr: Expression) -> Expression {
80
    // Keep applying evaluators to a fixed point, or until no changes are made.
81
14232
    for _ in 0..128 {
82
584022
        let next = expr.clone().transform_bi(&|subexpr: Expression| {
83
584022
            if let Some(lit) = eval_constant(&subexpr) {
84
356046
                return Expression::Atomic(Metadata::new(), Atom::Literal(lit));
85
227976
            }
86
227976
            if let Ok(reduction) = run_partial_evaluator(&subexpr) {
87
12198
                return reduction.new_expression;
88
215778
            }
89
215778
            subexpr
90
584022
        });
91

            
92
27000
        if next == expr {
93
14232
            break;
94
12768
        }
95
12768
        expr = next;
96
    }
97

            
98
14232
    expr
99
14232
}
100

            
101
7440
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
102
137712
    expr.transform_bi(&|atom: Atom| match atom {
103
118854
        Atom::Reference(reference) => reference
104
118854
            .resolve_constant()
105
118854
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
106
18858
        other => other,
107
137712
    })
108
7440
}
109

            
110
6672
pub(super) fn lift_machine_references_into_parent_scope(
111
6672
    expr: Expression,
112
6672
    child_symtab: &SymbolTable,
113
6672
    parent_symtab: &mut SymbolTable,
114
6672
) -> Expression {
115
6672
    let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
116

            
117
19356
    for (name, decl) in child_symtab.clone().into_iter_local() {
118
        // Do not add quantified declarations for quantified vars to the parent symbol table.
119
19356
        if matches!(
120
19356
            &decl.kind() as &DeclarationKind,
121
            DeclarationKind::Quantified(_)
122
        ) {
123
            continue;
124
19356
        }
125

            
126
19356
        if !matches!(&name, Name::Machine(_)) {
127
19356
            continue;
128
        }
129

            
130
        let id = decl.id();
131
        let new_decl = parent_symtab.gensym(&decl.domain().unwrap());
132
        machine_name_translations.insert(id, new_decl);
133
    }
134

            
135
7452
    expr.transform_bi(&|atom: Atom| {
136
7452
        if let Atom::Reference(ref decl) = atom
137
528
            && let id = decl.id()
138
528
            && let Some(new_decl) = machine_name_translations.get(&id)
139
        {
140
            Atom::Reference(Reference::new(new_decl.clone()))
141
        } else {
142
7452
            atom
143
        }
144
7452
    })
145
6672
}
146

            
147
/// Guard that temporarily converts quantified declarations to temporary value-lettings.
148
struct TempQuantifiedValueLettingGuard {
149
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
150
}
151

            
152
impl Drop for TempQuantifiedValueLettingGuard {
153
7440
    fn drop(&mut self) {
154
19638
        for (mut decl, kind) in self.originals.drain(..) {
155
19638
            let _ = decl.replace_kind(kind);
156
19638
        }
157
7440
    }
158
}
159

            
160
100368
fn maybe_bind_temp_value_letting(
161
100368
    originals: &mut Vec<(DeclarationPtr, DeclarationKind)>,
162
100368
    decl: &DeclarationPtr,
163
100368
    lit: &Literal,
164
100368
) {
165
100368
    if originals
166
100368
        .iter()
167
177744
        .any(|(existing, _)| existing.id() == decl.id())
168
    {
169
80730
        return;
170
19638
    }
171

            
172
19638
    let mut decl = decl.clone();
173
19638
    let old_kind = decl.kind().clone();
174
19638
    let temp_kind = DeclarationKind::TemporaryValueLetting(Expression::Atomic(
175
19638
        Metadata::new(),
176
19638
        Atom::Literal(lit.clone()),
177
19638
    ));
178
19638
    let _ = decl.replace_kind(temp_kind);
179
19638
    originals.push((decl, old_kind));
180
100368
}
181

            
182
7440
fn temporarily_bind_quantified_vars_to_values(
183
7440
    symbols: &SymbolTable,
184
7440
    expr: &Expression,
185
7440
    values: &HashMap<Name, Literal>,
186
7440
) -> TempQuantifiedValueLettingGuard {
187
7440
    let mut originals = Vec::new();
188

            
189
19638
    for (name, lit) in values {
190
19638
        let Some(decl) = symbols.lookup_local(name) else {
191
            continue;
192
        };
193

            
194
19638
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
195

            
196
19638
        let kind = decl.kind();
197
19638
        if let DeclarationKind::Quantified(inner) = &*kind
198
            && let Some(generator) = inner.generator()
199
        {
200
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
201
19638
        }
202
    }
203

            
204
    // Some expressions can still reference quantified declarations from an earlier scope
205
    // (e.g. after comprehension rewrites that rebuild generator declarations). Bind those
206
    // declaration pointers directly as well.
207
118854
    for decl in uniplate::Biplate::<DeclarationPtr>::universe_bi(expr) {
208
118854
        let name = decl.name().clone();
209
118854
        let Some(lit) = values.get(&name) else {
210
38124
            continue;
211
        };
212

            
213
80730
        maybe_bind_temp_value_letting(&mut originals, &decl, lit);
214

            
215
80730
        let kind = decl.kind();
216
80730
        if let DeclarationKind::Quantified(inner) = &*kind
217
            && let Some(generator) = inner.generator()
218
        {
219
            maybe_bind_temp_value_letting(&mut originals, generator, lit);
220
80730
        }
221
    }
222

            
223
7440
    TempQuantifiedValueLettingGuard { originals }
224
7440
}
225

            
226
/// Guard that temporarily converts quantified declarations to find declarations.
227
pub(super) struct TempQuantifiedFindGuard {
228
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
229
}
230

            
231
impl Drop for TempQuantifiedFindGuard {
232
494
    fn drop(&mut self) {
233
534
        for (mut decl, kind) in self.originals.drain(..) {
234
534
            let _ = decl.replace_kind(kind);
235
534
        }
236
494
    }
237
}
238

            
239
/// Converts quantified declarations in `model` to temporary find declarations.
240
494
pub(super) fn temporarily_materialise_quantified_vars_as_finds(
241
494
    model: &Model,
242
494
    quantified_vars: &[Name],
243
494
) -> TempQuantifiedFindGuard {
244
494
    let symbols = model.symbols().clone();
245
494
    let mut originals = Vec::new();
246

            
247
534
    for name in quantified_vars {
248
534
        let Some(mut decl) = symbols.lookup_local(name) else {
249
            continue;
250
        };
251

            
252
534
        let old_kind = decl.kind().clone();
253
534
        let Some(domain) = decl.domain() else {
254
            continue;
255
        };
256

            
257
534
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
258
534
        let _ = decl.replace_kind(new_kind);
259
534
        originals.push((decl, old_kind));
260
    }
261

            
262
494
    TempQuantifiedFindGuard { originals }
263
494
}