1
use crate::errors::FatalParseError;
2
use crate::expression::parse_expression;
3
use crate::field;
4
use crate::parser::ParseContext;
5
use crate::parser::domain::parse_domain;
6
use crate::util::named_children;
7
use conjure_cp_core::ast::ac_operators::ACOperatorKind;
8
use conjure_cp_core::ast::comprehension::ComprehensionBuilder;
9
use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo, Name};
10
use std::vec;
11
use tree_sitter::Node;
12

            
13
pub fn parse_comprehension(
14
    ctx: &mut ParseContext,
15
    node: &Node,
16
) -> Result<Option<Expression>, FatalParseError> {
17
    // Comprehensions require a symbol table passed in
18
    let symbols_ptr = ctx.symbols.clone().ok_or_else(|| {
19
        FatalParseError::internal_error(
20
            "Comprehensions require a symbol table".to_string(),
21
            Some(node.range()),
22
        )
23
    })?;
24

            
25
    let mut builder = ComprehensionBuilder::new(symbols_ptr);
26

            
27
    // We need to track the return expression node separately since it appears first in syntax
28
    // but we need to parse generators first (to get variables in scope)
29
    let mut return_expr_node: Option<Node> = None;
30

            
31
    // set return expression node and parse generators/conditions
32
    for child in named_children(node) {
33
        match child.kind() {
34
            "arithmetic_expr" | "bool_expr" | "comparison_expr" | "atom" => {
35
                // Store the return expression node to parse later
36
                return_expr_node = Some(child);
37
            }
38
            "generator" => {
39
                // Parse the generator variable
40
                let var_node = field!(child, "variable");
41
                let var_name_str = &ctx.source_code[var_node.start_byte()..var_node.end_byte()];
42
                let var_name = Name::user(var_name_str);
43

            
44
                // Parse the domain
45
                let domain_node = field!(child, "domain");
46

            
47
                // Parse with a new context using the generator symbol table
48
                let mut domain_ctx = ctx.with_new_symbols(Some(builder.generator_symboltable()));
49
                let Some(var_domain) = parse_domain(&mut domain_ctx, domain_node)? else {
50
                    return Ok(None);
51
                };
52

            
53
                // Add generator using the builder
54
                let decl = DeclarationPtr::new_find(var_name, var_domain);
55
                builder = builder.generator(decl);
56
            }
57
            "condition" => {
58
                // Parse the condition expression
59
                let expr_node = field!(child, "expression");
60
                let generator_symboltable = builder.generator_symboltable();
61

            
62
                // Parse with a new context using the generator symbol table
63
                let mut guard_ctx = ctx.with_new_symbols(Some(generator_symboltable));
64
                let Some(guard_expr) = parse_expression(&mut guard_ctx, expr_node)? else {
65
                    return Ok(None);
66
                };
67

            
68
                // Add the condition as a guard
69
                builder = builder.guard(guard_expr);
70
            }
71
            _ => {
72
                // Skip other nodes (like punctuation)
73
            }
74
        }
75
    }
76

            
77
    // parse the return expression
78
    let return_expr_node = return_expr_node.ok_or_else(|| {
79
        FatalParseError::internal_error(
80
            "Comprehension missing return expression".to_string(),
81
            Some(node.range()),
82
        )
83
    })?;
84

            
85
    // Use the return expression symbol table which already has quantified variables (as Given) and parent as parent
86
    let mut return_ctx = ctx.with_new_symbols(Some(builder.return_expr_symboltable()));
87
    let Some(return_expr) = parse_expression(&mut return_ctx, return_expr_node)? else {
88
        return Ok(None);
89
    };
90

            
91
    // Build the comprehension with the return expression and default ACOperatorKind::And
92
    let comprehension = builder.with_return_value(return_expr, Some(ACOperatorKind::And));
93

            
94
    Ok(Some(Expression::Comprehension(
95
        Metadata::new(),
96
        Moo::new(comprehension),
97
    )))
98
}
99

            
100
/// Parse comprehension-style expressions
101
/// - `forAll vars : domain . expr` → `And(Comprehension(...))`
102
/// - `sum vars : domain . expr` → `Sum(Comprehension(...))`
103
66
pub fn parse_quantifier_or_aggregate_expr(
104
66
    ctx: &mut ParseContext,
105
66
    node: &Node,
106
66
) -> Result<Option<Expression>, FatalParseError> {
107
    // Quantifier and aggregate expressions require a symbol table
108
66
    let symbols_ptr = ctx.symbols.clone().ok_or_else(|| {
109
        FatalParseError::internal_error(
110
            "Quantifier and aggregate expressions require a symbol table".to_string(),
111
            Some(node.range()),
112
        )
113
    })?;
114

            
115
    // Create the comprehension builder
116
66
    let mut builder = ComprehensionBuilder::new(symbols_ptr);
117

            
118
    // First pass: collect domain/collection, variables
119
66
    let mut domain = None;
120
66
    let mut collection_node = None;
121
66
    let mut variables = vec![];
122

            
123
198
    for child in named_children(node) {
124
198
        match child.kind() {
125
198
            "identifier" => {
126
66
                let var_name_str = &ctx.source_code[child.start_byte()..child.end_byte()];
127
66
                let var_name = Name::user(var_name_str);
128
66
                variables.push(var_name);
129
66
            }
130
132
            "domain" => {
131
                // Parse with the current symbol table (no need for a new context)
132
66
                let Some(parsed_domain) = parse_domain(ctx, child)? else {
133
                    return Ok(None);
134
                };
135
66
                domain = Some(parsed_domain);
136
            }
137
66
            "set_literal" | "matrix" | "tuple" | "record" => {
138
                // Store the collection node to parse later
139
                collection_node = Some(child);
140
            }
141
66
            _ => continue,
142
        }
143
    }
144

            
145
    // We need either a domain or a collection
146
66
    if domain.is_none() && collection_node.is_none() {
147
        return Err(FatalParseError::internal_error(
148
            "Quantifier and aggregate expressions require a domain or collection".to_string(),
149
            Some(node.range()),
150
        ));
151
66
    }
152

            
153
66
    if variables.is_empty() {
154
        return Err(FatalParseError::internal_error(
155
            "Quantifier and aggregate expressions require variables".to_string(),
156
            Some(node.range()),
157
        ));
158
66
    }
159

            
160
    // Get the operator type
161
66
    let operator_node = field!(node, "operator");
162
66
    let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
163

            
164
66
    let (ac_operator_kind, wrapper) = match operator_str {
165
66
        "forAll" => (ACOperatorKind::And, "And"),
166
        "exists" => (ACOperatorKind::Or, "Or"),
167
        "sum" => (ACOperatorKind::Sum, "Sum"),
168
        "min" => (ACOperatorKind::Sum, "Min"), // AC operator doesn't matter for non-boolean aggregates
169
        "max" => (ACOperatorKind::Sum, "Max"),
170
        _ => {
171
            return Err(FatalParseError::internal_error(
172
                format!("Unknown operator: {}", operator_str),
173
                Some(operator_node.range()),
174
            ));
175
        }
176
    };
177

            
178
    // Add variables as generators
179
66
    if let Some(dom) = domain {
180
66
        for var_name in variables {
181
66
            let decl = DeclarationPtr::new_find(var_name, dom.clone());
182
66
            builder = builder.generator(decl);
183
66
        }
184
    } else if let Some(_coll_node) = collection_node {
185
        // TODO: support collection domains
186
        return Err(FatalParseError::NotImplemented(
187
            "Collection domains in quantifier and aggregate expressions".to_string(),
188
        ));
189
    }
190

            
191
    // Parse the expression (after variables are in the symbol table)
192
66
    let expression_node = field!(node, "expression");
193

            
194
    // Parse with a new context using the return expression symbol table
195
66
    let mut expr_ctx = ctx.with_new_symbols(Some(builder.return_expr_symboltable()));
196
66
    let Some(expression) = parse_expression(&mut expr_ctx, expression_node)? else {
197
        return Ok(None);
198
    };
199

            
200
    // Build the comprehension
201
66
    let comprehension = builder.with_return_value(expression, Some(ac_operator_kind));
202
66
    let wrapped_comprehension = Expression::Comprehension(Metadata::new(), Moo::new(comprehension));
203

            
204
    // Wrap in the appropriate expression type
205
66
    match wrapper {
206
66
        "And" => Ok(Some(Expression::And(
207
66
            Metadata::new(),
208
66
            Moo::new(wrapped_comprehension),
209
66
        ))),
210
        "Or" => Ok(Some(Expression::Or(
211
            Metadata::new(),
212
            Moo::new(wrapped_comprehension),
213
        ))),
214
        "Sum" => Ok(Some(Expression::Sum(
215
            Metadata::new(),
216
            Moo::new(wrapped_comprehension),
217
        ))),
218
        "Min" => Ok(Some(Expression::Min(
219
            Metadata::new(),
220
            Moo::new(wrapped_comprehension),
221
        ))),
222
        "Max" => Ok(Some(Expression::Max(
223
            Metadata::new(),
224
            Moo::new(wrapped_comprehension),
225
        ))),
226
        _ => unreachable!(),
227
    }
228
66
}