1
use crate::expression::parse_expression;
2
use crate::parser::domain::parse_domain;
3
use crate::util::named_children;
4
use crate::{EssenceParseError, field};
5
use conjure_cp_core::ast::ac_operators::ACOperatorKind;
6
use conjure_cp_core::ast::comprehension::ComprehensionBuilder;
7
use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo, Name, SymbolTablePtr};
8
use std::vec;
9
use tree_sitter::Node;
10

            
11
pub fn parse_comprehension(
12
    node: &Node,
13
    source_code: &str,
14
    root: &Node,
15
    symbols_ptr: Option<SymbolTablePtr>,
16
) -> Result<Expression, EssenceParseError> {
17
    // Comprehensions require a symbol table passed in
18
    let symbols_ptr = symbols_ptr.ok_or_else(|| {
19
        EssenceParseError::syntax_error(
20
            "Comprehensions require a symbol table".to_string(),
21
            Some(node.range()),
22
        )
23
    })?;
24

            
25
    let mut builder = ComprehensionBuilder::new(symbols_ptr.clone());
26

            
27
    // We need to track the return expression node separately since it appears first in syntax
28
    // but we need to parse generators first (to get variables in scope)
29
    let mut return_expr_node: Option<Node> = None;
30

            
31
    // set return expression node and parse generators/conditions
32
    for child in named_children(node) {
33
        match child.kind() {
34
            "arithmetic_expr" | "bool_expr" | "comparison_expr" | "atom" => {
35
                // Store the return expression node to parse later
36
                return_expr_node = Some(child);
37
            }
38
            "generator" => {
39
                // Parse the generator variable
40
                let var_node = field!(child, "variable");
41
                let var_name_str = &source_code[var_node.start_byte()..var_node.end_byte()];
42
                let var_name = Name::user(var_name_str);
43

            
44
                // Parse the domain
45
                let domain_node = field!(child, "domain");
46
                let var_domain = parse_domain(domain_node, source_code, Some(symbols_ptr.clone()))?;
47

            
48
                // Add generator using the builder
49
                let decl = DeclarationPtr::new_find(var_name, var_domain);
50
                builder = builder.generator(decl);
51
            }
52
            "condition" => {
53
                // Parse the condition expression
54
                let expr_node = field!(child, "expression");
55
                let generator_symboltable = builder.generator_symboltable();
56

            
57
                let guard_expr =
58
                    parse_expression(expr_node, source_code, root, Some(generator_symboltable))?;
59

            
60
                // Add the condition as a guard
61
                builder = builder.guard(guard_expr);
62
            }
63
            _ => {
64
                // Skip other nodes (like punctuation)
65
            }
66
        }
67
    }
68

            
69
    // parse the return expression
70
    let return_expr_node = return_expr_node.ok_or_else(|| {
71
        EssenceParseError::syntax_error(
72
            "Comprehension missing return expression".to_string(),
73
            Some(node.range()),
74
        )
75
    })?;
76

            
77
    // Use the return expression symbol table which already has quantified variables (as Given) and parent as parent
78
    let return_expr = parse_expression(
79
        return_expr_node,
80
        source_code,
81
        root,
82
        Some(builder.return_expr_symboltable()),
83
    )?;
84

            
85
    // Build the comprehension with the return expression and default ACOperatorKind::And
86
    let comprehension = builder.with_return_value(return_expr, Some(ACOperatorKind::And));
87

            
88
    Ok(Expression::Comprehension(
89
        Metadata::new(),
90
        Moo::new(comprehension),
91
    ))
92
}
93

            
94
/// Parse comprehension-style expressions
95
/// - `forAll vars : domain . expr` → `And(Comprehension(...))`
96
/// - `sum vars : domain . expr` → `Sum(Comprehension(...))`
97
pub fn parse_quantifier_or_aggregate_expr(
98
    node: &Node,
99
    source_code: &str,
100
    root: &Node,
101
    symbols_ptr: Option<SymbolTablePtr>,
102
) -> Result<Expression, EssenceParseError> {
103
    // Quantifier and aggregate expressions require a symbol table
104
    let symbols_ptr = symbols_ptr.ok_or_else(|| {
105
        EssenceParseError::syntax_error(
106
            "Quantifier and aggregate expressions require a symbol table".to_string(),
107
            Some(node.range()),
108
        )
109
    })?;
110

            
111
    // Create the comprehension builder
112
    let mut builder = ComprehensionBuilder::new(symbols_ptr.clone());
113

            
114
    // First pass: collect domain/collection, variables
115
    let mut domain = None;
116
    let mut collection_node = None;
117
    let mut variables = vec![];
118

            
119
    for child in named_children(node) {
120
        match child.kind() {
121
            "identifier" => {
122
                let var_name_str = &source_code[child.start_byte()..child.end_byte()];
123
                let var_name = Name::user(var_name_str);
124
                variables.push(var_name);
125
            }
126
            "domain" => {
127
                domain = Some(parse_domain(child, source_code, Some(symbols_ptr.clone()))?);
128
            }
129
            "set_literal" | "matrix" | "tuple" | "record" => {
130
                // Store the collection node to parse later
131
                collection_node = Some(child);
132
            }
133
            _ => continue,
134
        }
135
    }
136

            
137
    // We need either a domain or a collection
138
    if domain.is_none() && collection_node.is_none() {
139
        return Err(EssenceParseError::syntax_error(
140
            "Quantifier and aggregate expressions require a domain or collection".to_string(),
141
            Some(node.range()),
142
        ));
143
    }
144

            
145
    if variables.is_empty() {
146
        return Err(EssenceParseError::syntax_error(
147
            "Quantifier and aggregate expressions require variables".to_string(),
148
            Some(node.range()),
149
        ));
150
    }
151

            
152
    // Get the operator type
153
    let operator_node = field!(node, "operator");
154
    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
155

            
156
    let (ac_operator_kind, wrapper) = match operator_str {
157
        "forAll" => (ACOperatorKind::And, "And"),
158
        "exists" => (ACOperatorKind::Or, "Or"),
159
        "sum" => (ACOperatorKind::Sum, "Sum"),
160
        "min" => (ACOperatorKind::Sum, "Min"), // AC operator doesn't matter for non-boolean aggregates
161
        "max" => (ACOperatorKind::Sum, "Max"),
162
        _ => {
163
            return Err(EssenceParseError::syntax_error(
164
                format!("Unknown operator: {}", operator_str),
165
                Some(operator_node.range()),
166
            ));
167
        }
168
    };
169

            
170
    // Add variables as generators
171
    if let Some(dom) = domain {
172
        for var_name in variables {
173
            let decl = DeclarationPtr::new_find(var_name, dom.clone());
174
            builder = builder.generator(decl);
175
        }
176
    } else if let Some(_coll_node) = collection_node {
177
        // TODO: support collection domains
178
        return Err(EssenceParseError::syntax_error(
179
            "Collection domains in quantifier and aggregate expressions not yet supported"
180
                .to_string(),
181
            Some(node.range()),
182
        ));
183
    }
184

            
185
    // Parse the expression (after variables are in the symbol table)
186
    let expression_node = node.child_by_field_name("expression").ok_or_else(|| {
187
        EssenceParseError::syntax_error(
188
            "Quantifier or aggregate expression missing return expression".to_string(),
189
            Some(node.range()),
190
        )
191
    })?;
192
    let expression = parse_expression(
193
        expression_node,
194
        source_code,
195
        root,
196
        Some(builder.return_expr_symboltable()),
197
    )?;
198

            
199
    // Build the comprehension
200
    let comprehension = builder.with_return_value(expression, Some(ac_operator_kind));
201
    let wrapped_comprehension = Expression::Comprehension(Metadata::new(), Moo::new(comprehension));
202

            
203
    // Wrap in the appropriate expression type
204
    match wrapper {
205
        "And" => Ok(Expression::And(
206
            Metadata::new(),
207
            Moo::new(wrapped_comprehension),
208
        )),
209
        "Or" => Ok(Expression::Or(
210
            Metadata::new(),
211
            Moo::new(wrapped_comprehension),
212
        )),
213
        "Sum" => Ok(Expression::Sum(
214
            Metadata::new(),
215
            Moo::new(wrapped_comprehension),
216
        )),
217
        "Min" => Ok(Expression::Min(
218
            Metadata::new(),
219
            Moo::new(wrapped_comprehension),
220
        )),
221
        "Max" => Ok(Expression::Max(
222
            Metadata::new(),
223
            Moo::new(wrapped_comprehension),
224
        )),
225
        _ => unreachable!(),
226
    }
227
}