1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, Mutex, RwLock, atomic::Ordering},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DecisionVariable, DeclarationKind, DeclarationPtr, Expression, Model, Name, SubModel,
9
        SymbolTable,
10
        comprehension::{Comprehension, USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS},
11
        serde::{HasId as _, ObjId},
12
    },
13
    bug,
14
    context::Context,
15
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
16
    settings::SolverFamily,
17
    solver::{Solver, SolverError, adaptors::Minion},
18
};
19
use uniplate::Biplate as _;
20

            
21
/// Expands the comprehension by solving quantified variables with Minion.
22
///
23
/// This returns one expression per assignment to quantified variables that satisfies the static
24
/// guards of the comprehension.
25
///
26
/// If successful, this modifies the symbol table given to add aux-variables needed inside the
27
/// expanded expressions.
28
12
pub fn expand_via_solver(
29
12
    comprehension: Comprehension,
30
12
    symtab: &mut SymbolTable,
31
12
) -> Result<Vec<Expression>, SolverError> {
32
12
    let minion = Solver::new(Minion::new());
33
    // FIXME: weave proper context through
34
12
    let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
35

            
36
    // only branch on the quantified variables.
37
12
    model.search_order = Some(comprehension.quantified_vars.clone());
38
12
    *model.as_submodel_mut() = comprehension.generator_submodel.clone();
39

            
40
    // call rewrite here as well as in expand_via_solver_ac, just to be consistent
41
12
    let extra_rule_sets = &["Base", "Constant", "Bubble"];
42

            
43
12
    let rule_sets = resolve_rule_sets(SolverFamily::Minion, extra_rule_sets).unwrap();
44

            
45
12
    let model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
46
        rewrite_morph(model, &rule_sets, false)
47
    } else {
48
12
        rewrite_naive(&model, &rule_sets, false, false).unwrap()
49
    };
50

            
51
    // Call the rewriter to rewrite inside the comprehension
52
    //
53
    // The original idea was to let the top level rewriter rewrite the return expression model
54
    // and the generator model. The comprehension wouldn't be expanded until the generator
55
    // model is in valid minion that can be ran, at which point the return expression model
56
    // should also be in valid minion.
57
    //
58
    // By calling the rewriter inside the rule, we no longer wait for the generator model to be
59
    // valid Minion, so we don't get the simplified return model either...
60
    //
61
    // We need to do this as we want to modify the generator model (add the dummy Z's) then
62
    // solve and return in one go.
63
    //
64
    // Comprehensions need a big rewrite soon, as theres lots of sharp edges such as this in
65
    // my original implementation, and I don't think we can fit our new optimisation into it.
66
    // If we wanted to avoid calling the rewriter, we would need to run the first half the rule
67
    // up to adding the return expr to the generator model, yield, then come back later to
68
    // actually solve it?
69

            
70
12
    let return_expression_submodel = comprehension.return_expression_submodel.clone();
71
12
    let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
72
12
    *return_expression_model.as_submodel_mut() = return_expression_submodel;
73

            
74
12
    let return_expression_model =
75
12
        if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
76
            rewrite_morph(return_expression_model, &rule_sets, false)
77
        } else {
78
12
            rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
79
        };
80

            
81
12
    let solver_model = model.clone();
82

            
83
    // Minion expects quantified variables in the temporary generator model as find declarations.
84
    // Keep this conversion local to the model passed into Minion.
85
12
    let _temp_finds = temporarily_materialise_quantified_vars_as_finds(
86
12
        solver_model.as_submodel(),
87
12
        &comprehension.quantified_vars,
88
    );
89

            
90
12
    let minion = minion.load_model(solver_model)?;
91

            
92
12
    let values = Arc::new(Mutex::new(Vec::new()));
93
12
    let values_ptr = Arc::clone(&values);
94

            
95
    tracing::debug!(model=%model,comprehension=%comprehension,"Minion solving comprehension (solver mode)");
96
57
    minion.solve(Box::new(move |sols| {
97
        // TODO: deal with represented names if quantified variables are abslits.
98
57
        let values = &mut *values_ptr.lock().unwrap();
99
57
        values.push(sols);
100
57
        true
101
57
    }))?;
102

            
103
12
    let values = values.lock().unwrap().clone();
104

            
105
12
    let mut return_expressions = vec![];
106

            
107
57
    for value in values {
108
        // convert back to an expression
109

            
110
57
        let return_expression_submodel = return_expression_model.as_submodel().clone();
111
57
        let child_symtab = return_expression_submodel.symbols().clone();
112
57
        let return_expression = return_expression_submodel.into_single_expression();
113

            
114
        // we only want to substitute quantified variables.
115
        // (definitely not machine names, as they mean something different in this scope!)
116
57
        let value: HashMap<_, _> = value
117
57
            .into_iter()
118
93
            .filter(|(n, _)| comprehension.quantified_vars.contains(n))
119
57
            .collect();
120

            
121
57
        let value_ptr = Arc::new(value);
122
57
        let value_ptr_2 = Arc::clone(&value_ptr);
123

            
124
        // substitute in the values for the quantified variables
125
159
        let return_expression = return_expression.transform_bi(&move |x: Atom| {
126
159
            let Atom::Reference(ref ptr) = x else {
127
18
                return x;
128
            };
129

            
130
            // is this referencing a quantified var?
131
141
            let Some(lit) = value_ptr_2.get(&ptr.name()) else {
132
42
                return x;
133
            };
134

            
135
99
            Atom::Literal(lit.clone())
136
159
        });
137

            
138
        // Copy the return expression's symbols into parent scope.
139

            
140
        // For variables in the return expression with machine names, create new declarations
141
        // for them in the parent symbol table, so that the machine names used are unique.
142
        //
143
        // Store the declaration translations in `machine_name_translations`.
144
        // These are stored as a map of (old declaration id) -> (new declaration ptr), as
145
        // declaration pointers do not implement hash.
146
        //
147
57
        let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
148

            
149
        // Populate `machine_name_translations`
150
57
        for (name, decl) in child_symtab.into_iter_local() {
151
            // do not add quantified declarations for quantified vars to the parent symbol table.
152
57
            if value_ptr.get(&name).is_some()
153
                && matches!(
154
57
                    &decl.kind() as &DeclarationKind,
155
                    DeclarationKind::Given(_) | DeclarationKind::Quantified(_)
156
                )
157
            {
158
57
                continue;
159
            }
160

            
161
            let Name::Machine(_) = &name else {
162
                bug!(
163
                    "the symbol table of the return expression of a comprehension should only contain machine names"
164
                );
165
            };
166

            
167
            let id = decl.id();
168
            let new_decl = symtab.gensym(&decl.domain().unwrap());
169

            
170
            machine_name_translations.insert(id, new_decl);
171
        }
172

            
173
        // Update references to use the new delcarations.
174
        #[allow(clippy::arc_with_non_send_sync)]
175
159
        let return_expression = return_expression.transform_bi(&move |atom: Atom| {
176
159
            if let Atom::Reference(ref decl) = atom
177
42
                && let id = decl.id()
178
42
                && let Some(new_decl) = machine_name_translations.get(&id)
179
            {
180
                Atom::Reference(conjure_cp::ast::Reference::new(new_decl.clone()))
181
            } else {
182
159
                atom
183
            }
184
159
        });
185

            
186
57
        return_expressions.push(return_expression);
187
    }
188

            
189
12
    Ok(return_expressions)
190
12
}
191

            
192
/// Guard that temporarily converts quantified declarations to find declarations.
193
struct TempQuantifiedFindGuard {
194
    originals: Vec<(DeclarationPtr, DeclarationKind)>,
195
}
196

            
197
impl Drop for TempQuantifiedFindGuard {
198
12
    fn drop(&mut self) {
199
12
        for (mut decl, kind) in self.originals.drain(..) {
200
12
            let _ = decl.replace_kind(kind);
201
12
        }
202
12
    }
203
}
204

            
205
/// Converts quantified declarations in `submodel` to temporary find declarations.
206
12
fn temporarily_materialise_quantified_vars_as_finds(
207
12
    submodel: &SubModel,
208
12
    quantified_vars: &[Name],
209
12
) -> TempQuantifiedFindGuard {
210
12
    let symbols = submodel.symbols().clone();
211
12
    let mut originals = Vec::new();
212

            
213
12
    for name in quantified_vars {
214
12
        let Some(mut decl) = symbols.lookup_local(name) else {
215
            continue;
216
        };
217

            
218
12
        let old_kind = decl.kind().clone();
219
12
        let Some(domain) = decl.domain() else {
220
            continue;
221
        };
222

            
223
12
        let new_kind = DeclarationKind::Find(DecisionVariable::new(domain));
224
12
        let _ = decl.replace_kind(new_kind);
225
12
        originals.push((decl, old_kind));
226
    }
227

            
228
12
    TempQuantifiedFindGuard { originals }
229
12
}