use std::collections::HashMap;
use std::hash::Hash;

use conjure_cp::{
    ast::{
        Atom, DeclarationKind, DeclarationPtr, Expression, Literal, Metadata, Name, SymbolTable,
        comprehension::Comprehension,
        eval_constant, run_partial_evaluator,
        serde::{HasId as _, ObjId},
    },
    bug,
    solver::SolverError,
};
use uniplate::Biplate as _;

            
15
/// Expands the comprehension without calling an external solver.
16
///
17
/// Quantified variables are enumerated with native Rust loops over their finite domains. Guards
18
/// are evaluated using constant and partial evaluators after substitution.
19
pub fn expand_native(
20
    comprehension: Comprehension,
21
    symtab: &mut SymbolTable,
22
) -> Result<Vec<Expression>, SolverError> {
23
    let generator_symbols = comprehension.generator_submodel.symbols().clone();
24
    let quantified_vars = comprehension.quantified_vars.clone();
25

            
26
    let mut quantified_domains = Vec::with_capacity(quantified_vars.len());
27
    for name in &quantified_vars {
28
        let decl = generator_symbols.lookup_local(name).ok_or_else(|| {
29
            SolverError::ModelInvalid(format!(
30
                "quantified variable '{name}' is missing from generator symbol table"
31
            ))
32
        })?;
33

            
34
        let domain = decl.domain().ok_or_else(|| {
35
            SolverError::ModelInvalid(format!("quantified variable '{name}' has no domain"))
36
        })?;
37
        let resolved = domain.resolve().ok_or_else(|| {
38
            SolverError::ModelFeatureNotSupported(format!(
39
                "quantified variable '{name}' has unresolved domain: {domain}"
40
            ))
41
        })?;
42

            
43
        let values: Vec<Literal> = resolved
44
            .values()
45
            .map_err(|err| {
46
                SolverError::ModelFeatureNotSupported(format!(
47
                    "quantified variable '{name}' has non-enumerable domain: {err}"
48
                ))
49
            })?
50
            .collect();
51

            
52
        quantified_domains.push(values);
53
    }
54

            
55
    let mut assignments = HashMap::new();
56
    let mut expanded = Vec::new();
57

            
58
    enumerate_assignments(
59
        0,
60
        &quantified_vars,
61
        &quantified_domains,
62
        &mut assignments,
63
        &mut |assignment| {
64
            for guard in comprehension.generator_submodel.constraints() {
65
                match evaluate_guard(guard, assignment) {
66
                    Some(true) => {}
67
                    Some(false) => return Ok(()),
68
                    None => {
69
                        return Err(SolverError::ModelInvalid(format!(
70
                            "native comprehension expansion could not evaluate guard: {guard}"
71
                        )));
72
                    }
73
                }
74
            }
75

            
76
            let return_expression_submodel = comprehension.return_expression_submodel.clone();
77
            let child_symtab = return_expression_submodel.symbols().clone();
78
            let return_expression = return_expression_submodel.into_single_expression();
79

            
80
            let return_expression = substitute_quantified_literals(return_expression, assignment);
81
            let return_expression = simplify_expression(return_expression);
82

            
83
            // Copy machine-name declarations from comprehension-local return-expression scope.
84
            let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
85
            for (name, decl) in child_symtab.into_iter_local() {
86
                if assignment.get(&name).is_some()
87
                    && matches!(
88
                        &decl.kind() as &DeclarationKind,
89
                        DeclarationKind::Given(_) | DeclarationKind::Quantified(_)
90
                    )
91
                {
92
                    continue;
93
                }
94

            
95
                let Name::Machine(_) = &name else {
96
                    bug!(
97
                        "the symbol table of the return expression of a comprehension should only contain machine names"
98
                    );
99
                };
100

            
101
                let id = decl.id();
102
                let new_decl = symtab.gensym(&decl.domain().unwrap());
103
                machine_name_translations.insert(id, new_decl);
104
            }
105

            
106
            #[allow(clippy::arc_with_non_send_sync)]
107
            let return_expression = return_expression.transform_bi(&move |atom: Atom| {
108
                if let Atom::Reference(ref decl) = atom
109
                    && let id = decl.id()
110
                    && let Some(new_decl) = machine_name_translations.get(&id)
111
                {
112
                    Atom::Reference(conjure_cp::ast::Reference::new(new_decl.clone()))
113
                } else {
114
                    atom
115
                }
116
            });
117

            
118
            expanded.push(return_expression);
119
            Ok(())
120
        },
121
    )?;
122

            
123
    Ok(expanded)
124
}
125

            
/// Invokes `on_assignment` once for every combination of values of `quantified_vars[index..]`.
///
/// `quantified_vars[i]` takes each value of `quantified_domains[i]` in order, and earlier
/// variables vary more slowly than later ones (lexicographic order). Bindings for variables
/// before `index` are whatever the caller already placed in `assignment`. When no variables
/// remain (`index >= quantified_vars.len()`), `on_assignment` is invoked exactly once; when any
/// remaining domain is empty, it is never invoked.
///
/// Enumeration stops at the first `Err` returned by `on_assignment`, and that error is
/// propagated. On both success and failure, the bindings this call inserted are removed from
/// `assignment` before returning.
///
/// # Panics
///
/// Panics if enumeration reaches a position that has no entry in `quantified_domains` (i.e.
/// `quantified_domains` is shorter than `quantified_vars`).
fn enumerate_assignments<K, V, E>(
    index: usize,
    quantified_vars: &[K],
    quantified_domains: &[Vec<V>],
    assignment: &mut HashMap<K, V>,
    on_assignment: &mut impl FnMut(&HashMap<K, V>) -> Result<(), E>,
) -> Result<(), E>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    let Some(name) = quantified_vars.get(index) else {
        // Every variable is bound: report the complete assignment.
        return on_assignment(assignment);
    };

    let outcome = quantified_domains[index].iter().try_for_each(|value| {
        assignment.insert(name.clone(), value.clone());
        enumerate_assignments(
            index + 1,
            quantified_vars,
            quantified_domains,
            &mut *assignment,
            &mut *on_assignment,
        )
    });

    // Undo this level's binding even when the callback failed part-way through.
    assignment.remove(name);
    outcome
}

            
152
fn evaluate_guard(guard: &Expression, assignment: &HashMap<Name, Literal>) -> Option<bool> {
153
    let substituted = substitute_quantified_literals(guard.clone(), assignment);
154
    let simplified = simplify_expression(substituted);
155
    match eval_constant(&simplified)? {
156
        Literal::Bool(value) => Some(value),
157
        _ => None,
158
    }
159
}
160

            
161
fn substitute_quantified_literals(
162
    expr: Expression,
163
    assignment: &HashMap<Name, Literal>,
164
) -> Expression {
165
    expr.transform_bi(&|atom: Atom| {
166
        let Atom::Reference(ref decl) = atom else {
167
            return atom;
168
        };
169

            
170
        let Some(lit) = assignment.get(&decl.name()) else {
171
            return atom;
172
        };
173

            
174
        Atom::Literal(lit.clone())
175
    })
176
}
177

            
178
fn simplify_expression(mut expr: Expression) -> Expression {
179
    // Keep applying evaluators to a fixed point, or until no changes are made.
180
    for _ in 0..128 {
181
        let next = expr.clone().transform_bi(&|subexpr: Expression| {
182
            if let Some(lit) = eval_constant(&subexpr) {
183
                return Expression::Atomic(Metadata::new(), Atom::Literal(lit));
184
            }
185
            if let Ok(reduction) = run_partial_evaluator(&subexpr) {
186
                return reduction.new_expression;
187
            }
188
            subexpr
189
        });
190

            
191
        if next == expr {
192
            break;
193
        }
194
        expr = next;
195
    }
196
    expr
197
}