1
use conjure_cp::{
2
    ast::{
3
        Atom, DeclarationKind, DeclarationPtr, DomainPtr, Expression, Literal, Metadata, Name,
4
        SymbolTable,
5
        comprehension::{Comprehension, ComprehensionQualifier},
6
        eval_constant,
7
    },
8
    solver::SolverError,
9
};
10
use uniplate::Biplate as _;
11

            
12
use super::via_solver_common::{lift_machine_references_into_parent_scope, simplify_expression};
13

            
14
/// Expands the comprehension without calling an external solver.
15
///
16
/// Algorithm:
17
/// 1. Recurse qualifiers left-to-right.
18
/// 2. For each generator value, temporarily bind the quantified declaration to a
19
///    `TemporaryValueLetting` and recurse.
20
/// 3. For each condition, evaluate and recurse only if true.
21
/// 4. At the leaf, evaluate the return expression under the active bindings.
22
162
pub fn expand_native(
23
162
    comprehension: Comprehension,
24
162
    parent_symbols: &mut SymbolTable,
25
162
) -> Result<Vec<Expression>, SolverError> {
26
162
    let mut expanded = Vec::new();
27
162
    expand_qualifiers(&comprehension, 0, &mut expanded, parent_symbols)?;
28
162
    Ok(expanded)
29
162
}
30

            
31
11934
fn expand_qualifiers(
32
11934
    comprehension: &Comprehension,
33
11934
    qualifier_index: usize,
34
11934
    expanded: &mut Vec<Expression>,
35
11934
    parent_symbols: &mut SymbolTable,
36
11934
) -> Result<(), SolverError> {
37
11934
    if qualifier_index == comprehension.qualifiers.len() {
38
10008
        let child_symbols = comprehension.symbols().clone();
39
10008
        let return_expression =
40
10008
            concretise_resolved_reference_atoms(comprehension.return_expression.clone());
41
10008
        let return_expression = simplify_expression(return_expression);
42
10008
        let return_expression = lift_machine_references_into_parent_scope(
43
10008
            return_expression,
44
10008
            &child_symbols,
45
10008
            parent_symbols,
46
        );
47
10008
        expanded.push(return_expression);
48
10008
        return Ok(());
49
1926
    }
50

            
51
1926
    match &comprehension.qualifiers[qualifier_index] {
52
1746
        ComprehensionQualifier::Generator { name, domain } => {
53
1746
            let values = resolve_generator_values(name, domain)?;
54
1746
            let quantified_declaration = lookup_quantified_declaration(comprehension, name)?;
55
1164

            
56
3891
            for literal in values {
57
11673
                with_temporary_quantified_binding(&quantified_declaration, &literal, || {
58
11673
                    expand_qualifiers(comprehension, qualifier_index + 1, expanded, parent_symbols)
59
11673
                })?;
60
7782
            }
61
        }
62
60
        ComprehensionQualifier::Condition(condition) => {
63
180
            if evaluate_guard(condition)? {
64
153
                expand_qualifiers(comprehension, qualifier_index + 1, expanded, parent_symbols)?;
65
93
            }
66
54
        }
67
    }
68

            
69
642
    Ok(())
70
5262
}
71
7956

            
72
582
fn resolve_generator_values(name: &Name, domain: &DomainPtr) -> Result<Vec<Literal>, SolverError> {
73
1746
    let resolved = domain.resolve().ok_or_else(|| {
74
1164
        SolverError::ModelFeatureNotSupported(format!(
75
            "quantified variable '{name}' has unresolved domain after assigning previous generators: {domain}"
76
        ))
77
    })?;
78

            
79
582
    resolved.values().map(|iter| iter.collect()).map_err(|err| {
80
1164
        SolverError::ModelFeatureNotSupported(format!(
81
            "quantified variable '{name}' has non-enumerable domain: {err}"
82
        ))
83
    })
84
582
}
85
1164

            
86
582
fn lookup_quantified_declaration(
87
8364
    comprehension: &Comprehension,
88
8364
    name: &Name,
89
8364
) -> Result<DeclarationPtr, SolverError> {
90
8364
    comprehension.symbols().lookup_local(name).ok_or_else(|| {
91
7782
        SolverError::ModelInvalid(format!(
92
7782
            "quantified variable '{name}' is missing from local comprehension symbol table"
93
7782
        ))
94
7782
    })
95
582
}
96

            
97
11673
fn with_temporary_quantified_binding<T>(
98
3891
    quantified: &DeclarationPtr,
99
11673
    value: &Literal,
100
11673
    f: impl FnOnce() -> Result<T, SolverError>,
101
11673
) -> Result<T, SolverError> {
102
11673
    let mut targets = vec![quantified.clone()];
103
11673
    if let DeclarationKind::Quantified(inner) = &*quantified.kind()
104
11673
        && let Some(generator) = inner.generator()
105
7782
    {
106
        targets.push(generator.clone());
107
11673
    }
108

            
109
11673
    let mut originals = Vec::with_capacity(targets.len());
110
11673
    for mut target in targets {
111
11673
        let old_kind = target.replace_kind(DeclarationKind::TemporaryValueLetting(
112
3891
            Expression::Atomic(Metadata::new(), Atom::Literal(value.clone())),
113
11673
        ));
114
11673
        originals.push((target, old_kind));
115
3891
    }
116
120

            
117
4011
    let result = f();
118
120

            
119
4011
    for (mut target, old_kind) in originals.into_iter().rev() {
120
3891
        let _ = target.replace_kind(old_kind);
121
3891
    }
122

            
123
3891
    result
124
3891
}
125

            
126
60
fn evaluate_guard(guard: &Expression) -> Result<bool, SolverError> {
127
180
    let simplified = simplify_expression(guard.clone());
128
60
    match eval_constant(&simplified) {
129
6732
        Some(Literal::Bool(value)) => Ok(value),
130
133962
        Some(other) => Err(SolverError::ModelInvalid(format!(
131
115872
            "native comprehension guard must evaluate to Bool, got {other}: {guard}"
132
115872
        ))),
133
115872
        None => Err(SolverError::ModelInvalid(format!(
134
18090
            "native comprehension expansion could not evaluate guard: {guard}"
135
133962
        ))),
136
6672
    }
137
60
}
138

            
139
3336
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
140
66981
    expr.transform_bi(&|atom: Atom| match atom {
141
57936
        Atom::Reference(reference) => reference
142
57936
            .resolve_constant()
143
57936
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
144
9045
        other => other,
145
66981
    })
146
3336
}