1
use crate::RecoverableParseError;
2
use crate::errors::FatalParseError;
3
use crate::expression::parse_expression;
4
use crate::field;
5
use crate::parser::ParseContext;
6
use crate::parser::domain::parse_domain;
7
use crate::util::named_children;
8
use conjure_cp_core::ast::ac_operators::ACOperatorKind;
9
use conjure_cp_core::ast::comprehension::ComprehensionBuilder;
10
use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo, Name};
11
use std::vec;
12
use tree_sitter::Node;
13

            
14
711
pub fn parse_comprehension(
15
711
    ctx: &mut ParseContext,
16
711
    node: &Node,
17
711
) -> Result<Option<Expression>, FatalParseError> {
18
    // If we're in a set context, add error and return early since comprehensions don't produce sets
19
711
    if ctx.typechecking_context == crate::util::TypecheckingContext::Set {
20
        ctx.record_error(crate::errors::RecoverableParseError::new(
21
            format!(
22
                "Type error: {}\n\tExpected: set\n\tGot: comprehension",
23
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
24
            ),
25
            Some(node.range()),
26
        ));
27
711
    }
28

            
29
    // Comprehensions require a symbol table passed in
30
711
    let symbols_ptr = match ctx.symbols.clone() {
31
711
        Some(s) => s,
32
        None => {
33
            ctx.record_error(RecoverableParseError::new(
34
                "Comprehensions require a symbol table".to_string(),
35
                Some(node.range()),
36
            ));
37
            return Ok(None);
38
        }
39
    };
40

            
41
711
    let mut builder = ComprehensionBuilder::new(symbols_ptr);
42

            
43
    // We need to track the return expression node separately since it appears first in syntax
44
    // but we need to parse generators first (to get variables in scope)
45
711
    let mut return_expr_node: Option<Node> = None;
46

            
47
    // set return expression node and parse generators/conditions
48
1894
    for child in named_children(node) {
49
1894
        match child.kind() {
50
1894
            "arithmetic_expr" | "bool_expr" | "comparison_expr" | "atom" => {
51
                // Store the return expression node to parse later
52
711
                return_expr_node = Some(child);
53
711
            }
54
1183
            "generator" => {
55
                // Parse the generator variable
56
787
                let Some(var_node) = field!(recover, ctx, child, "variable") else {
57
                    return Ok(None);
58
                };
59
787
                let var_name_str = &ctx.source_code[var_node.start_byte()..var_node.end_byte()];
60
787
                let var_name = Name::user(var_name_str);
61

            
62
                // Parse the domain
63
787
                let Some(domain_node) = field!(recover, ctx, child, "domain") else {
64
                    return Ok(None);
65
                };
66

            
67
                // Parse with a new context using the generator symbol table
68
787
                let mut domain_ctx = ctx.with_new_symbols(Some(builder.generator_symboltable()));
69
787
                let Some(var_domain) = parse_domain(&mut domain_ctx, domain_node)? else {
70
                    return Ok(None);
71
                };
72

            
73
                // Add generator using the builder
74
787
                let decl = DeclarationPtr::new_find(var_name, var_domain);
75
787
                builder = builder.generator(decl);
76
            }
77
396
            "condition" => {
78
                // Parse the condition expression
79
396
                let Some(expr_node) = field!(recover, ctx, child, "expression") else {
80
                    return Ok(None);
81
                };
82
396
                let generator_symboltable = builder.generator_symboltable();
83

            
84
                // Parse with a new context using the generator symbol table
85
396
                let mut guard_ctx = ctx.with_new_symbols(Some(generator_symboltable));
86
396
                let Some(guard_expr) = parse_expression(&mut guard_ctx, expr_node)? else {
87
                    return Ok(None);
88
                };
89

            
90
                // Add the condition as a guard
91
396
                builder = builder.guard(guard_expr);
92
            }
93
            _ => {
94
                // Skip other nodes (like punctuation)
95
            }
96
        }
97
    }
98

            
99
    // parse the return expression
100
711
    let return_expr_node = match return_expr_node {
101
711
        Some(node) => node,
102
        None => {
103
            ctx.record_error(RecoverableParseError::new(
104
                "Comprehension missing return expression".to_string(),
105
                Some(node.range()),
106
            ));
107
            return Ok(None);
108
        }
109
    };
110

            
111
    // Use the return expression symbol table which already has quantified variables (as Given) and parent as parent
112
711
    let mut return_ctx = ctx.with_new_symbols(Some(builder.return_expr_symboltable()));
113
711
    let Some(return_expr) = parse_expression(&mut return_ctx, return_expr_node)? else {
114
        return Ok(None);
115
    };
116

            
117
    // Build the comprehension with the return expression and default ACOperatorKind::And
118
711
    let comprehension = builder.with_return_value(return_expr, Some(ACOperatorKind::And));
119

            
120
711
    Ok(Some(Expression::Comprehension(
121
711
        Metadata::new(),
122
711
        Moo::new(comprehension),
123
711
    )))
124
711
}
125

            
126
/// Parse comprehension-style expressions
127
/// - `forAll vars : domain . expr` → `And(Comprehension(...))`
128
/// - `sum vars : domain . expr` → `Sum(Comprehension(...))`
129
2080
pub fn parse_quantifier_or_aggregate_expr(
130
2080
    ctx: &mut ParseContext,
131
2080
    node: &Node,
132
2080
) -> Result<Option<Expression>, FatalParseError> {
133
    // Quantifier and aggregate expressions require a symbol table
134
2080
    let symbols_ptr = match ctx.symbols.clone() {
135
2080
        Some(s) => s,
136
        None => {
137
            ctx.record_error(RecoverableParseError::new(
138
                "Quantifier and aggregate expressions require a symbol table".to_string(),
139
                Some(node.range()),
140
            ));
141
            return Ok(None);
142
        }
143
    };
144

            
145
    // Create the comprehension builder
146
2080
    let mut builder = ComprehensionBuilder::new(symbols_ptr);
147

            
148
    // First pass: collect domain/collection, variables
149
2080
    let mut domain = None;
150
2080
    let mut collection_node = None;
151
2080
    let mut variables = vec![];
152

            
153
6422
    for child in named_children(node) {
154
6422
        match child.kind() {
155
6422
            "identifier" => {
156
2262
                let var_name_str = &ctx.source_code[child.start_byte()..child.end_byte()];
157
2262
                let var_name = Name::user(var_name_str);
158
2262
                variables.push(var_name);
159
2262
            }
160
4160
            "domain" => {
161
                // Parse with the current symbol table (no need for a new context)
162
2080
                let Some(parsed_domain) = parse_domain(ctx, child)? else {
163
                    return Ok(None);
164
                };
165
2080
                domain = Some(parsed_domain);
166
            }
167
2080
            "set_literal" | "matrix" | "tuple" | "record" => {
168
                // Store the collection node to parse later
169
                collection_node = Some(child);
170
            }
171
2080
            _ => continue,
172
        }
173
    }
174

            
175
    // We need either a domain or a collection
176
2080
    if domain.is_none() && collection_node.is_none() {
177
        ctx.record_error(RecoverableParseError::new(
178
            "Quantifier and aggregate expressions require a domain or collection".to_string(),
179
            Some(node.range()),
180
        ));
181
        return Ok(None);
182
2080
    }
183

            
184
2080
    if variables.is_empty() {
185
        ctx.record_error(RecoverableParseError::new(
186
            "Quantifier and aggregate expressions require variables".to_string(),
187
            Some(node.range()),
188
        ));
189
        return Ok(None);
190
2080
    }
191

            
192
    // Get the operator type
193
2080
    let Some(operator_node) = field!(recover, ctx, node, "operator") else {
194
        return Ok(None);
195
    };
196
2080
    let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
197

            
198
2080
    let (ac_operator_kind, wrapper) = match operator_str {
199
2080
        "forAll" => (ACOperatorKind::And, "And"),
200
988
        "exists" => (ACOperatorKind::Or, "Or"),
201
234
        "sum" => (ACOperatorKind::Sum, "Sum"),
202
        "min" => (ACOperatorKind::Sum, "Min"), // AC operator doesn't matter for non-boolean aggregates
203
        "max" => (ACOperatorKind::Sum, "Max"),
204
        _ => {
205
            ctx.record_error(RecoverableParseError::new(
206
                format!("Unknown operator: {}", operator_str),
207
                Some(operator_node.range()),
208
            ));
209
            return Ok(None);
210
        }
211
    };
212

            
213
    // Add variables as generators
214
2080
    if let Some(dom) = domain {
215
2262
        for var_name in variables {
216
2262
            let decl = DeclarationPtr::new_find(var_name, dom.clone());
217
2262
            builder = builder.generator(decl);
218
2262
        }
219
    } else if let Some(_coll_node) = collection_node {
220
        // TODO: support collection domains
221
        ctx.record_error(RecoverableParseError::new(
222
            "Collection domains in quantifier and aggregate expressions".to_string(),
223
            Some(_coll_node.range()),
224
        ));
225
        return Ok(None);
226
    }
227

            
228
    // Parse the expression (after variables are in the symbol table)
229
2080
    let Some(expression_node) = field!(recover, ctx, node, "expression") else {
230
        return Ok(None);
231
    };
232

            
233
    // Parse with a new context using the return expression symbol table
234
2080
    let mut expr_ctx = ctx.with_new_symbols(Some(builder.return_expr_symboltable()));
235
2080
    let Some(expression) = parse_expression(&mut expr_ctx, expression_node)? else {
236
        return Ok(None);
237
    };
238

            
239
    // Build the comprehension
240
2080
    let comprehension = builder.with_return_value(expression, Some(ac_operator_kind));
241
2080
    let wrapped_comprehension = Expression::Comprehension(Metadata::new(), Moo::new(comprehension));
242

            
243
    // Wrap in the appropriate expression type
244
2080
    match wrapper {
245
2080
        "And" => Ok(Some(Expression::And(
246
1092
            Metadata::new(),
247
1092
            Moo::new(wrapped_comprehension),
248
1092
        ))),
249
988
        "Or" => Ok(Some(Expression::Or(
250
754
            Metadata::new(),
251
754
            Moo::new(wrapped_comprehension),
252
754
        ))),
253
234
        "Sum" => Ok(Some(Expression::Sum(
254
234
            Metadata::new(),
255
234
            Moo::new(wrapped_comprehension),
256
234
        ))),
257
        "Min" => Ok(Some(Expression::Min(
258
            Metadata::new(),
259
            Moo::new(wrapped_comprehension),
260
        ))),
261
        "Max" => Ok(Some(Expression::Max(
262
            Metadata::new(),
263
            Moo::new(wrapped_comprehension),
264
        ))),
265
        _ => unreachable!(),
266
    }
267
2080
}