1
use crate::errors::{FatalParseError, RecoverableParseError};
2
use crate::expression::parse_expression;
3
use crate::field;
4
use crate::parser::domain::parse_domain;
5
use crate::util::named_children;
6
use conjure_cp_core::ast::ac_operators::ACOperatorKind;
7
use conjure_cp_core::ast::comprehension::ComprehensionBuilder;
8
use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo, Name, SymbolTablePtr};
9
use std::vec;
10
use tree_sitter::Node;
11

            
12
pub fn parse_comprehension(
13
    node: &Node,
14
    source_code: &str,
15
    root: &Node,
16
    symbols_ptr: Option<SymbolTablePtr>,
17
    errors: &mut Vec<RecoverableParseError>,
18
) -> Result<Expression, FatalParseError> {
19
    // Comprehensions require a symbol table passed in
20
    let symbols_ptr = symbols_ptr.ok_or_else(|| {
21
        FatalParseError::syntax_error(
22
            "Comprehensions require a symbol table".to_string(),
23
            Some(node.range()),
24
        )
25
    })?;
26

            
27
    let mut builder = ComprehensionBuilder::new(symbols_ptr.clone());
28

            
29
    // We need to track the return expression node separately since it appears first in syntax
30
    // but we need to parse generators first (to get variables in scope)
31
    let mut return_expr_node: Option<Node> = None;
32

            
33
    // set return expression node and parse generators/conditions
34
    for child in named_children(node) {
35
        match child.kind() {
36
            "arithmetic_expr" | "bool_expr" | "comparison_expr" | "atom" => {
37
                // Store the return expression node to parse later
38
                return_expr_node = Some(child);
39
            }
40
            "generator" => {
41
                // Parse the generator variable
42
                let var_node = field!(child, "variable");
43
                let var_name_str = &source_code[var_node.start_byte()..var_node.end_byte()];
44
                let var_name = Name::user(var_name_str);
45

            
46
                // Parse the domain
47
                let domain_node = field!(child, "domain");
48
                let var_domain =
49
                    parse_domain(domain_node, source_code, Some(symbols_ptr.clone()), errors)?;
50

            
51
                // Add generator using the builder
52
                let decl = DeclarationPtr::new_find(var_name, var_domain);
53
                builder = builder.generator(decl);
54
            }
55
            "condition" => {
56
                // Parse the condition expression
57
                let expr_node = field!(child, "expression");
58
                let generator_symboltable = builder.generator_symboltable();
59

            
60
                let guard_expr = parse_expression(
61
                    expr_node,
62
                    source_code,
63
                    root,
64
                    Some(generator_symboltable),
65
                    errors,
66
                )?;
67

            
68
                // Add the condition as a guard
69
                builder = builder.guard(guard_expr);
70
            }
71
            _ => {
72
                // Skip other nodes (like punctuation)
73
            }
74
        }
75
    }
76

            
77
    // parse the return expression
78
    let return_expr_node = return_expr_node.ok_or_else(|| {
79
        FatalParseError::syntax_error(
80
            "Comprehension missing return expression".to_string(),
81
            Some(node.range()),
82
        )
83
    })?;
84

            
85
    // Use the return expression symbol table which already has quantified variables (as Given) and parent as parent
86
    let return_expr = parse_expression(
87
        return_expr_node,
88
        source_code,
89
        root,
90
        Some(builder.return_expr_symboltable()),
91
        errors,
92
    )?;
93

            
94
    // Build the comprehension with the return expression and default ACOperatorKind::And
95
    let comprehension = builder.with_return_value(return_expr, Some(ACOperatorKind::And));
96

            
97
    Ok(Expression::Comprehension(
98
        Metadata::new(),
99
        Moo::new(comprehension),
100
    ))
101
}
102

            
103
/// Parse comprehension-style expressions
104
/// - `forAll vars : domain . expr` → `And(Comprehension(...))`
105
/// - `sum vars : domain . expr` → `Sum(Comprehension(...))`
106
22
pub fn parse_quantifier_or_aggregate_expr(
107
    node: &Node,
108
22
    source_code: &str,
109
    root: &Node,
110
    symbols_ptr: Option<SymbolTablePtr>,
111
    errors: &mut Vec<RecoverableParseError>,
112
) -> Result<Expression, FatalParseError> {
113
    // Quantifier and aggregate expressions require a symbol table
114
    let symbols_ptr = symbols_ptr.ok_or_else(|| {
115
        FatalParseError::syntax_error(
116
22
            "Quantifier and aggregate expressions require a symbol table".to_string(),
117
            Some(node.range()),
118
        )
119
22
    })?;
120
22

            
121
    // Create the comprehension builder
122
    let mut builder = ComprehensionBuilder::new(symbols_ptr.clone());
123
66

            
124
    // First pass: collect domain/collection, variables
125
66
    let mut domain = None;
126
22
    let mut collection_node = None;
127
22
    let mut variables = vec![];
128
22

            
129
22
    for child in named_children(node) {
130
44
        match child.kind() {
131
            "identifier" => {
132
22
                let var_name_str = &source_code[child.start_byte()..child.end_byte()];
133
                let var_name = Name::user(var_name_str);
134
                variables.push(var_name);
135
22
            }
136
            "domain" => {
137
22
                domain = Some(parse_domain(
138
                    child,
139
                    source_code,
140
                    Some(symbols_ptr.clone()),
141
22
                    errors,
142
                )?);
143
            }
144
            "set_literal" | "matrix" | "tuple" | "record" => {
145
                // Store the collection node to parse later
146
22
                collection_node = Some(child);
147
            }
148
            _ => continue,
149
        }
150
    }
151
22

            
152
    // We need either a domain or a collection
153
22
    if domain.is_none() && collection_node.is_none() {
154
        return Err(FatalParseError::syntax_error(
155
            "Quantifier and aggregate expressions require a domain or collection".to_string(),
156
            Some(node.range()),
157
        ));
158
22
    }
159

            
160
    if variables.is_empty() {
161
22
        return Err(FatalParseError::syntax_error(
162
22
            "Quantifier and aggregate expressions require variables".to_string(),
163
            Some(node.range()),
164
22
        ));
165
22
    }
166

            
167
    // Get the operator type
168
    let operator_node = field!(node, "operator");
169
    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
170

            
171
    let (ac_operator_kind, wrapper) = match operator_str {
172
        "forAll" => (ACOperatorKind::And, "And"),
173
        "exists" => (ACOperatorKind::Or, "Or"),
174
        "sum" => (ACOperatorKind::Sum, "Sum"),
175
        "min" => (ACOperatorKind::Sum, "Min"), // AC operator doesn't matter for non-boolean aggregates
176
        "max" => (ACOperatorKind::Sum, "Max"),
177
        _ => {
178
            return Err(FatalParseError::syntax_error(
179
22
                format!("Unknown operator: {}", operator_str),
180
22
                Some(operator_node.range()),
181
22
            ));
182
22
        }
183
22
    };
184

            
185
    // Add variables as generators
186
    if let Some(dom) = domain {
187
        for var_name in variables {
188
            let decl = DeclarationPtr::new_find(var_name, dom.clone());
189
            builder = builder.generator(decl);
190
        }
191
    } else if let Some(_coll_node) = collection_node {
192
        // TODO: support collection domains
193
        return Err(FatalParseError::syntax_error(
194
            "Collection domains in quantifier and aggregate expressions not yet supported"
195
22
                .to_string(),
196
22
            Some(node.range()),
197
        ));
198
    }
199

            
200
    // Parse the expression (after variables are in the symbol table)
201
22
    let expression_node = node.child_by_field_name("expression").ok_or_else(|| {
202
22
        FatalParseError::syntax_error(
203
            "Quantifier or aggregate expression missing return expression".to_string(),
204
            Some(node.range()),
205
22
        )
206
22
    })?;
207
22
    let expression = parse_expression(
208
22
        expression_node,
209
22
        source_code,
210
        root,
211
        Some(builder.return_expr_symboltable()),
212
        errors,
213
    )?;
214

            
215
    // Build the comprehension
216
    let comprehension = builder.with_return_value(expression, Some(ac_operator_kind));
217
    let wrapped_comprehension = Expression::Comprehension(Metadata::new(), Moo::new(comprehension));
218

            
219
    // Wrap in the appropriate expression type
220
    match wrapper {
221
        "And" => Ok(Expression::And(
222
            Metadata::new(),
223
            Moo::new(wrapped_comprehension),
224
        )),
225
        "Or" => Ok(Expression::Or(
226
            Metadata::new(),
227
            Moo::new(wrapped_comprehension),
228
22
        )),
229
        "Sum" => Ok(Expression::Sum(
230
            Metadata::new(),
231
            Moo::new(wrapped_comprehension),
232
        )),
233
        "Min" => Ok(Expression::Min(
234
            Metadata::new(),
235
            Moo::new(wrapped_comprehension),
236
        )),
237
        "Max" => Ok(Expression::Max(
238
            Metadata::new(),
239
            Moo::new(wrapped_comprehension),
240
        )),
241
        _ => unreachable!(),
242
    }
243
}