1
use conjure_cp::{
2
    ast::{
3
        Atom, DeclarationKind, DeclarationPtr, DomainPtr, Expression, Literal, Metadata, Name,
4
        SymbolTable,
5
        comprehension::{Comprehension, ComprehensionQualifier},
6
        eval_constant,
7
    },
8
    bug,
9
    solver::SolverError,
10
};
11
use uniplate::Biplate as _;
12

            
13
use super::via_solver_common::{
14
    lift_machine_references_into_parent_scope, simplify_expression,
15
    strip_guarded_safe_index_conditions,
16
};
17

            
18
/// Expands the comprehension without calling an external solver.
19
///
20
/// Algorithm:
21
/// 1. Recurse qualifiers left-to-right.
22
/// 2. For each generator value, temporarily bind the quantified declaration to a
23
///    `TemporaryValueLetting` and recurse.
24
/// 3. For each condition, evaluate and recurse only if true.
25
/// 4. At the leaf, evaluate the return expression under the active bindings.
26
396
pub fn expand_native(
27
396
    comprehension: Comprehension,
28
396
    parent_symbols: &mut SymbolTable,
29
396
) -> Result<Vec<Expression>, SolverError> {
30
396
    let mut expanded = Vec::new();
31
396
    expand_qualifiers(&comprehension, 0, &mut expanded, parent_symbols)?;
32
396
    Ok(expanded)
33
396
}
34

            
35
16450
fn expand_qualifiers(
36
16450
    comprehension: &Comprehension,
37
16450
    qualifier_index: usize,
38
16450
    expanded: &mut Vec<Expression>,
39
16450
    parent_symbols: &mut SymbolTable,
40
16450
) -> Result<(), SolverError> {
41
16450
    if qualifier_index == comprehension.qualifiers.len() {
42
13698
        let child_symbols = comprehension.symbols().clone();
43
13698
        let return_expression =
44
13698
            concretise_resolved_reference_atoms(comprehension.return_expression.clone());
45
13698
        let Some(return_expression) = strip_guarded_safe_index_conditions(return_expression) else {
46
32
            return Ok(());
47
        };
48
13666
        let return_expression = simplify_expression(return_expression);
49
13666
        let return_expression = lift_machine_references_into_parent_scope(
50
13666
            return_expression,
51
13666
            &child_symbols,
52
13666
            parent_symbols,
53
        );
54
13666
        expanded.push(return_expression);
55
13666
        return Ok(());
56
2752
    }
57

            
58
2752
    match &comprehension.qualifiers[qualifier_index] {
59
2106
        ComprehensionQualifier::Generator { ptr } => {
60
2106
            let name = ptr.name().clone();
61
2106
            let domain = ptr.domain().expect("generator declaration has domain");
62
2106
            let values = resolve_generator_values(&name, &domain)?;
63

            
64
15454
            for literal in values {
65
15454
                with_temporary_quantified_binding(ptr, &literal, || {
66
15454
                    expand_qualifiers(comprehension, qualifier_index + 1, expanded, parent_symbols)
67
15454
                })?;
68
            }
69
        }
70
646
        ComprehensionQualifier::Condition(condition) => {
71
646
            if evaluate_guard(condition)? {
72
600
                expand_qualifiers(comprehension, qualifier_index + 1, expanded, parent_symbols)?;
73
46
            }
74
        }
75
        ComprehensionQualifier::ExpressionGenerator { .. } => {
76
            bug!(
77
                "Comprehension expander should not be called on comprehensions containing ExpressionGenerator"
78
            );
79
        }
80
    }
81

            
82
2752
    Ok(())
83
16450
}
84

            
85
2106
fn resolve_generator_values(name: &Name, domain: &DomainPtr) -> Result<Vec<Literal>, SolverError> {
86
2106
    let resolved = domain.resolve().ok_or_else(|| {
87
        SolverError::ModelFeatureNotSupported(format!(
88
            "quantified variable '{name}' has unresolved domain after assigning previous generators: {domain}"
89
        ))
90
    })?;
91

            
92
2106
    resolved.values().map(|iter| iter.collect()).map_err(|err| {
93
        SolverError::ModelFeatureNotSupported(format!(
94
            "quantified variable '{name}' has non-enumerable domain: {err}"
95
        ))
96
    })
97
2106
}
98

            
99
15454
fn with_temporary_quantified_binding<T>(
100
15454
    quantified: &DeclarationPtr,
101
15454
    value: &Literal,
102
15454
    f: impl FnOnce() -> Result<T, SolverError>,
103
15454
) -> Result<T, SolverError> {
104
15454
    let mut targets = vec![quantified.clone()];
105
15454
    if let DeclarationKind::Quantified(inner) = &*quantified.kind()
106
15454
        && let Some(generator) = inner.generator()
107
    {
108
        targets.push(generator.clone());
109
15454
    }
110

            
111
15454
    let mut originals = Vec::with_capacity(targets.len());
112
15454
    for mut target in targets {
113
15454
        let old_kind = target.replace_kind(DeclarationKind::TemporaryValueLetting(
114
15454
            Expression::Atomic(Metadata::new(), Atom::Literal(value.clone())),
115
15454
        ));
116
15454
        originals.push((target, old_kind));
117
15454
    }
118

            
119
15454
    let result = f();
120

            
121
15454
    for (mut target, old_kind) in originals.into_iter().rev() {
122
15454
        let _ = target.replace_kind(old_kind);
123
15454
    }
124

            
125
15454
    result
126
15454
}
127

            
128
646
fn evaluate_guard(guard: &Expression) -> Result<bool, SolverError> {
129
646
    let simplified = simplify_expression(guard.clone());
130
646
    match eval_constant(&simplified) {
131
646
        Some(Literal::Bool(value)) => Ok(value),
132
        Some(other) => Err(SolverError::ModelInvalid(format!(
133
            "native comprehension guard must evaluate to Bool, got {other}: {guard}"
134
        ))),
135
        None => Err(SolverError::ModelInvalid(format!(
136
            "native comprehension expansion could not evaluate guard: {guard}"
137
        ))),
138
    }
139
646
}
140

            
141
13698
fn concretise_resolved_reference_atoms(expr: Expression) -> Expression {
142
280050
    expr.transform_bi(&|atom: Atom| match atom {
143
240691
        Atom::Reference(reference) => reference
144
240691
            .resolve_constant()
145
240691
            .map_or_else(|| Atom::Reference(reference), Atom::Literal),
146
39359
        other => other,
147
280050
    })
148
13698
}