1
use std::{
2
    collections::HashMap,
3
    sync::{Arc, Mutex, RwLock, atomic::Ordering},
4
};
5

            
6
use conjure_cp::{
7
    ast::{
8
        Atom, DeclarationKind, DeclarationPtr, Expression, Model, Name, SymbolTable,
9
        comprehension::{Comprehension, USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS},
10
        serde::{HasId as _, ObjId},
11
    },
12
    bug,
13
    context::Context,
14
    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
15
    solver::{Solver, SolverError, SolverFamily, adaptors::Minion},
16
};
17
use uniplate::Biplate as _;
18

            
19
/// Expands the comprehension using Minion, returning the resulting expressions.
20
///
21
/// This method performs simple pruning of the induction variables: an expression is returned
22
/// for each assignment to the induction variables that satisfy the static guards of the
23
/// comprehension. If the comprehension is inside an associative-commutative operation, use
24
/// [`expand_ac`] instead, as this performs further pruning of "uninteresting" return values.
25
///
26
/// If successful, this modifies the symbol table given to add aux-variables needed inside the
27
/// expanded expressions.
28
pub fn expand_simple(
29
    comprehension: Comprehension,
30
    symtab: &mut SymbolTable,
31
) -> Result<Vec<Expression>, SolverError> {
32
    let minion = Solver::new(Minion::new());
33
    // FIXME: weave proper context through
34
    let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
35

            
36
    // only branch on the induction variables.
37
    model.search_order = Some(comprehension.induction_vars.clone());
38
    *model.as_submodel_mut() = comprehension.generator_submodel.clone();
39

            
40
    // TODO:  if expand_ac is enabled, add Better_AC_Comprehension_Expansion here.
41

            
42
    // call rewrite here as well as in expand_ac, just to be consistent
43
    let extra_rule_sets = &["Base", "Constant", "Bubble"];
44

            
45
    let rule_sets = resolve_rule_sets(SolverFamily::Minion, extra_rule_sets).unwrap();
46

            
47
    let model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
48
        rewrite_morph(model, &rule_sets, false)
49
    } else {
50
        rewrite_naive(&model, &rule_sets, false, false).unwrap()
51
    };
52

            
53
    // Call the rewriter to rewrite inside the comprehension
54
    //
55
    // The original idea was to let the top level rewriter rewrite the return expression model
56
    // and the generator model. The comprehension wouldn't be expanded until the generator
57
    // model is in valid minion that can be ran, at which point the return expression model
58
    // should also be in valid minion.
59
    //
60
    // By calling the rewriter inside the rule, we no longer wait for the generator model to be
61
    // valid Minion, so we don't get the simplified return model either...
62
    //
63
    // We need to do this as we want to modify the generator model (add the dummy Z's) then
64
    // solve and return in one go.
65
    //
66
    // Comprehensions need a big rewrite soon, as theres lots of sharp edges such as this in
67
    // my original implementation, and I don't think we can fit our new optimisation into it.
68
    // If we wanted to avoid calling the rewriter, we would need to run the first half the rule
69
    // up to adding the return expr to the generator model, yield, then come back later to
70
    // actually solve it?
71

            
72
    let return_expression_submodel = comprehension.return_expression_submodel.clone();
73
    let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
74
    *return_expression_model.as_submodel_mut() = return_expression_submodel;
75

            
76
    let return_expression_model =
77
        if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
78
            rewrite_morph(return_expression_model, &rule_sets, false)
79
        } else {
80
            rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
81
        };
82

            
83
    let minion = minion.load_model(model.clone())?;
84

            
85
    let values = Arc::new(Mutex::new(Vec::new()));
86
    let values_ptr = Arc::clone(&values);
87

            
88
    tracing::debug!(model=%model,comprehension=%comprehension,"Minion solving comprehension (simple mode)");
89
    minion.solve(Box::new(move |sols| {
90
        // TODO: deal with represented names if induction variables are abslits.
91
        let values = &mut *values_ptr.lock().unwrap();
92
        values.push(sols);
93
        true
94
    }))?;
95

            
96
    let values = values.lock().unwrap().clone();
97

            
98
    let mut return_expressions = vec![];
99

            
100
    for value in values {
101
        // convert back to an expression
102

            
103
        let return_expression_submodel = return_expression_model.as_submodel().clone();
104
        let child_symtab = return_expression_submodel.symbols().clone();
105
        let return_expression = return_expression_submodel.into_single_expression();
106

            
107
        // we only want to substitute induction variables.
108
        // (definitely not machine names, as they mean something different in this scope!)
109
        let value: HashMap<_, _> = value
110
            .into_iter()
111
            .filter(|(n, _)| comprehension.induction_vars.contains(n))
112
            .collect();
113

            
114
        let value_ptr = Arc::new(value);
115
        let value_ptr_2 = Arc::clone(&value_ptr);
116

            
117
        // substitute in the values for the induction variables
118
        let return_expression = return_expression.transform_bi(&move |x: Atom| {
119
            let Atom::Reference(ref ptr) = x else {
120
                return x;
121
            };
122

            
123
            // is this referencing an induction var?
124
            let Some(lit) = value_ptr_2.get(&ptr.name()) else {
125
                return x;
126
            };
127

            
128
            Atom::Literal(lit.clone())
129
        });
130

            
131
        // Copy the return expression's symbols into parent scope.
132

            
133
        // For variables in the return expression with machine names, create new declarations
134
        // for them in the parent symbol table, so that the machine names used are unique.
135
        //
136
        // Store the declaration translations in `machine_name_translations`.
137
        // These are stored as a map of (old declaration id) -> (new declaration ptr), as
138
        // declaration pointers do not implement hash.
139
        //
140
        let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
141

            
142
        // Populate `machine_name_translations`
143
        for (name, decl) in child_symtab.into_iter_local() {
144
            // do not add givens for induction vars to the parent symbol table.
145
            if value_ptr.get(&name).is_some()
146
                && matches!(&decl.kind() as &DeclarationKind, DeclarationKind::Given(_))
147
            {
148
                continue;
149
            }
150

            
151
            let Name::Machine(_) = &name else {
152
                bug!(
153
                    "the symbol table of the return expression of a comprehension should only contain machine names"
154
                );
155
            };
156

            
157
            let id = decl.id();
158
            let new_decl = symtab.gensym(&decl.domain().unwrap());
159

            
160
            machine_name_translations.insert(id, new_decl);
161
        }
162

            
163
        // Update references to use the new delcarations.
164
        #[allow(clippy::arc_with_non_send_sync)]
165
        let return_expression = return_expression.transform_bi(&move |atom: Atom| {
166
            if let Atom::Reference(ref decl) = atom
167
                && let id = decl.id()
168
                && let Some(new_decl) = machine_name_translations.get(&id)
169
            {
170
                Atom::Reference(conjure_cp::ast::Reference::new(new_decl.clone()))
171
            } else {
172
                atom
173
            }
174
        });
175

            
176
        return_expressions.push(return_expression);
177
    }
178

            
179
    Ok(return_expressions)
180
}