1
use crate::RecoverableParseError;
2
use crate::errors::FatalParseError;
3
use crate::expression::parse_expression;
4
use crate::field;
5
use crate::parser::ParseContext;
6
use crate::parser::domain::parse_domain;
7
use crate::util::{TypecheckingContext, named_children};
8
use conjure_cp_core::ast::ac_operators::ACOperatorKind;
9
use conjure_cp_core::ast::comprehension::ComprehensionBuilder;
10
use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo, Name};
11
use std::vec;
12
use tree_sitter::Node;
13

            
14
711
pub fn parse_comprehension(
15
711
    ctx: &mut ParseContext,
16
711
    node: &Node,
17
711
) -> Result<Option<Expression>, FatalParseError> {
18
    // If we're in a set context, add error and return early since comprehensions don't produce sets
19
711
    if ctx.typechecking_context == crate::util::TypecheckingContext::Set {
20
        ctx.record_error(crate::errors::RecoverableParseError::new(
21
            format!(
22
                "Type error: {}\n\tExpected: set\n\tGot: comprehension",
23
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
24
            ),
25
            Some(node.range()),
26
        ));
27
711
    }
28

            
29
    // Comprehensions require a symbol table passed in
30
711
    let symbols_ptr = match ctx.symbols.clone() {
31
711
        Some(s) => s,
32
        None => {
33
            ctx.record_error(RecoverableParseError::new(
34
                "Comprehensions require a symbol table".to_string(),
35
                Some(node.range()),
36
            ));
37
            return Ok(None);
38
        }
39
    };
40

            
41
711
    let mut builder = ComprehensionBuilder::new(symbols_ptr);
42

            
43
    // We need to track the return expression node separately since it appears first in syntax
44
    // but we need to parse generators first (to get variables in scope)
45
711
    let mut return_expr_node: Option<Node> = None;
46

            
47
    // set return expression node and parse generators/conditions
48
1894
    for child in named_children(node) {
49
1894
        match child.kind() {
50
1894
            "arithmetic_expr" | "bool_expr" | "comparison_expr" | "atom" => {
51
                // Store the return expression node to parse later
52
711
                return_expr_node = Some(child);
53
711
            }
54
1183
            "generator" => {
55
                // Parse the generator variable
56
787
                let Some(var_node) = field!(recover, ctx, child, "variable") else {
57
                    return Ok(None);
58
                };
59
787
                let var_name_str = &ctx.source_code[var_node.start_byte()..var_node.end_byte()];
60
787
                let var_name = Name::user(var_name_str);
61

            
62
                // Parse the domain
63
787
                let Some(domain_node) = field!(recover, ctx, child, "domain") else {
64
                    return Ok(None);
65
                };
66

            
67
                // Parse with a new context using the generator symbol table
68
787
                let mut domain_ctx = ctx.with_new_symbols(Some(builder.generator_symboltable()));
69
787
                let Some(var_domain) = parse_domain(&mut domain_ctx, domain_node)? else {
70
                    return Ok(None);
71
                };
72

            
73
                // Add generator using the builder
74
787
                let decl = DeclarationPtr::new_find(var_name, var_domain);
75
787
                builder = builder.generator(decl);
76
            }
77
396
            "condition" => {
78
                // Parse the condition expression
79
396
                let Some(expr_node) = field!(recover, ctx, child, "expression") else {
80
                    return Ok(None);
81
                };
82
396
                let generator_symboltable = builder.generator_symboltable();
83

            
84
                // Parse with a new context using the generator symbol table
85
396
                let mut guard_ctx = ctx.with_new_symbols(Some(generator_symboltable));
86
396
                let Some(guard_expr) = parse_expression(&mut guard_ctx, expr_node)? else {
87
                    return Ok(None);
88
                };
89

            
90
                // Add the condition as a guard
91
396
                builder = builder.guard(guard_expr);
92
            }
93
            _ => {
94
                // Skip other nodes (like punctuation)
95
            }
96
        }
97
    }
98

            
99
    // parse the return expression
100
711
    let return_expr_node = match return_expr_node {
101
711
        Some(node) => node,
102
        None => {
103
            ctx.record_error(RecoverableParseError::new(
104
                "Comprehension missing return expression".to_string(),
105
                Some(node.range()),
106
            ));
107
            return Ok(None);
108
        }
109
    };
110

            
111
    // Use the return expression symbol table which already has quantified variables (as Given) and parent as parent
112
    // Parse using the inner typechecking context
113
711
    let saved_inner_ctx = ctx.inner_typechecking_context;
114
711
    let mut return_ctx = ctx.with_new_symbols(Some(builder.return_expr_symboltable()));
115
711
    return_ctx.typechecking_context = saved_inner_ctx;
116
711
    return_ctx.inner_typechecking_context = TypecheckingContext::Unknown;
117
711
    let Some(return_expr) = parse_expression(&mut return_ctx, return_expr_node)? else {
118
        return Ok(None);
119
    };
120

            
121
    // Build the comprehension with the return expression and default ACOperatorKind::And
122
711
    let comprehension = builder.with_return_value(return_expr, Some(ACOperatorKind::And));
123

            
124
711
    Ok(Some(Expression::Comprehension(
125
711
        Metadata::new(),
126
711
        Moo::new(comprehension),
127
711
    )))
128
711
}
129

            
130
/// Parse comprehension-style expressions
131
/// - `forAll vars : domain . expr` → `And(Comprehension(...))`
132
/// - `sum vars : domain . expr` → `Sum(Comprehension(...))`
133
2080
pub fn parse_quantifier_or_aggregate_expr(
134
2080
    ctx: &mut ParseContext,
135
2080
    node: &Node,
136
2080
) -> Result<Option<Expression>, FatalParseError> {
137
    // Quantifier and aggregate expressions require a symbol table
138
2080
    let symbols_ptr = match ctx.symbols.clone() {
139
2080
        Some(s) => s,
140
        None => {
141
            ctx.record_error(RecoverableParseError::new(
142
                "Quantifier and aggregate expressions require a symbol table".to_string(),
143
                Some(node.range()),
144
            ));
145
            return Ok(None);
146
        }
147
    };
148

            
149
    // Create the comprehension builder
150
2080
    let mut builder = ComprehensionBuilder::new(symbols_ptr);
151

            
152
    // First pass: collect domain/collection, variables
153
2080
    let mut domain = None;
154
2080
    let mut collection_node = None;
155
2080
    let mut variables = vec![];
156

            
157
6422
    for child in named_children(node) {
158
6422
        match child.kind() {
159
6422
            "identifier" => {
160
2262
                let var_name_str = &ctx.source_code[child.start_byte()..child.end_byte()];
161
2262
                let var_name = Name::user(var_name_str);
162
2262
                variables.push(var_name);
163
2262
            }
164
4160
            "domain" => {
165
                // Parse domains under Unknown context so arithmetic bounds in domains
166
                // (e.g. int(1..m-1)) are not rejected by surrounding boolean/arithmetic contexts.
167
2080
                let saved_ctx = ctx.typechecking_context;
168
2080
                let saved_inner_ctx = ctx.inner_typechecking_context;
169
2080
                ctx.typechecking_context = TypecheckingContext::Unknown;
170
2080
                ctx.inner_typechecking_context = TypecheckingContext::Unknown;
171

            
172
2080
                let Some(parsed_domain) = parse_domain(ctx, child)? else {
173
                    ctx.typechecking_context = saved_ctx;
174
                    ctx.inner_typechecking_context = saved_inner_ctx;
175
                    return Ok(None);
176
                };
177

            
178
2080
                ctx.typechecking_context = saved_ctx;
179
2080
                ctx.inner_typechecking_context = saved_inner_ctx;
180
2080
                domain = Some(parsed_domain);
181
            }
182
2080
            "set_literal" | "matrix" | "tuple" | "record" => {
183
                // Store the collection node to parse later
184
                collection_node = Some(child);
185
            }
186
2080
            _ => continue,
187
        }
188
    }
189

            
190
    // We need either a domain or a collection
191
2080
    if domain.is_none() && collection_node.is_none() {
192
        ctx.record_error(RecoverableParseError::new(
193
            "Quantifier and aggregate expressions require a domain or collection".to_string(),
194
            Some(node.range()),
195
        ));
196
        return Ok(None);
197
2080
    }
198

            
199
2080
    if variables.is_empty() {
200
        ctx.record_error(RecoverableParseError::new(
201
            "Quantifier and aggregate expressions require variables".to_string(),
202
            Some(node.range()),
203
        ));
204
        return Ok(None);
205
2080
    }
206

            
207
    // Get the operator type
208
2080
    let Some(operator_node) = field!(recover, ctx, node, "operator") else {
209
        return Ok(None);
210
    };
211
2080
    let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
212

            
213
2080
    let (ac_operator_kind, wrapper) = match operator_str {
214
2080
        "forAll" => (ACOperatorKind::And, "And"),
215
988
        "exists" => (ACOperatorKind::Or, "Or"),
216
234
        "sum" => (ACOperatorKind::Sum, "Sum"),
217
        "min" => (ACOperatorKind::Sum, "Min"), // AC operator doesn't matter for non-boolean aggregates
218
        "max" => (ACOperatorKind::Sum, "Max"),
219
        _ => {
220
            ctx.record_error(RecoverableParseError::new(
221
                format!("Unknown operator: {}", operator_str),
222
                Some(operator_node.range()),
223
            ));
224
            return Ok(None);
225
        }
226
    };
227

            
228
    // Add variables as generators
229
2080
    if let Some(dom) = domain {
230
2262
        for var_name in variables {
231
2262
            let decl = DeclarationPtr::new_find(var_name, dom.clone());
232
2262
            builder = builder.generator(decl);
233
2262
        }
234
    } else if let Some(_coll_node) = collection_node {
235
        // TODO: support collection domains
236
        ctx.record_error(RecoverableParseError::new(
237
            "Collection domains in quantifier and aggregate expressions".to_string(),
238
            Some(_coll_node.range()),
239
        ));
240
        return Ok(None);
241
    }
242

            
243
    // Parse the expression (after variables are in the symbol table)
244
2080
    let Some(expression_node) = field!(recover, ctx, node, "expression") else {
245
        return Ok(None);
246
    };
247

            
248
    // Parse with a new context using the return expression symbol table
249
    // Prase using the inner typechecking context
250
2080
    let saved_inner_ctx = ctx.inner_typechecking_context;
251
2080
    let mut expr_ctx = ctx.with_new_symbols(Some(builder.return_expr_symboltable()));
252
2080
    expr_ctx.typechecking_context = saved_inner_ctx;
253
2080
    expr_ctx.inner_typechecking_context = TypecheckingContext::Unknown;
254
2080
    let Some(expression) = parse_expression(&mut expr_ctx, expression_node)? else {
255
        return Ok(None);
256
    };
257

            
258
    // Build the comprehension
259
2080
    let comprehension = builder.with_return_value(expression, Some(ac_operator_kind));
260
2080
    let wrapped_comprehension = Expression::Comprehension(Metadata::new(), Moo::new(comprehension));
261

            
262
    // Wrap in the appropriate expression type
263
2080
    match wrapper {
264
2080
        "And" => Ok(Some(Expression::And(
265
1092
            Metadata::new(),
266
1092
            Moo::new(wrapped_comprehension),
267
1092
        ))),
268
988
        "Or" => Ok(Some(Expression::Or(
269
754
            Metadata::new(),
270
754
            Moo::new(wrapped_comprehension),
271
754
        ))),
272
234
        "Sum" => Ok(Some(Expression::Sum(
273
234
            Metadata::new(),
274
234
            Moo::new(wrapped_comprehension),
275
234
        ))),
276
        "Min" => Ok(Some(Expression::Min(
277
            Metadata::new(),
278
            Moo::new(wrapped_comprehension),
279
        ))),
280
        "Max" => Ok(Some(Expression::Max(
281
            Metadata::new(),
282
            Moo::new(wrapped_comprehension),
283
        ))),
284
        _ => unreachable!(),
285
    }
286
2080
}