1
use crate::expression::parse_expression;
2
use crate::parser::domain::parse_domain;
3
use crate::util::named_children;
4
use crate::{EssenceParseError, field};
5
use conjure_cp_core::ast::ac_operators::ACOperatorKind;
6
use conjure_cp_core::ast::comprehension::ComprehensionBuilder;
7
use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo, Name, SymbolTable};
8
use std::cell::RefCell;
9
use std::rc::Rc;
10
use std::vec;
11
use tree_sitter::Node;
12

            
13
pub fn parse_comprehension(
14
    node: &Node,
15
    source_code: &str,
16
    root: &Node,
17
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
18
) -> Result<Expression, EssenceParseError> {
19
    // Comprehensions require a symbol table passed in
20
    let symbols_ptr = symbols_ptr.ok_or_else(|| {
21
        EssenceParseError::syntax_error(
22
            "Comprehensions require a symbol table".to_string(),
23
            Some(node.range()),
24
        )
25
    })?;
26

            
27
    let mut builder = ComprehensionBuilder::new(symbols_ptr.clone());
28

            
29
    // We need to track the return expression node separately since it appears first in syntax
30
    // but we need to parse generators first (to get variables in scope)
31
    let mut return_expr_node: Option<Node> = None;
32

            
33
    // set return expression node and parse generators/conditions
34
    for child in named_children(node) {
35
        match child.kind() {
36
            "arithmetic_expr" | "bool_expr" | "comparison_expr" => {
37
                // Store the return expression node to parse later
38
                return_expr_node = Some(child);
39
            }
40
            "generator" => {
41
                // Parse the generator variable
42
                let var_node = field!(child, "variable");
43
                let var_name_str = &source_code[var_node.start_byte()..var_node.end_byte()];
44
                let var_name = Name::user(var_name_str);
45

            
46
                // Parse the domain
47
                let domain_node = field!(child, "domain");
48
                let var_domain = parse_domain(domain_node, source_code, Some(symbols_ptr.clone()))?;
49

            
50
                // Add generator using the builder
51
                let decl = DeclarationPtr::new_var(var_name, var_domain);
52
                builder = builder.generator(decl);
53
            }
54
            "condition" => {
55
                // Parse the condition expression
56
                let expr_node = field!(child, "expression");
57
                let generator_symboltable = builder.generator_symboltable();
58

            
59
                let guard_expr =
60
                    parse_expression(expr_node, source_code, root, Some(generator_symboltable))?;
61

            
62
                // Add the condition as a guard
63
                builder = builder.guard(guard_expr);
64
            }
65
            _ => {
66
                // Skip other nodes (like punctuation)
67
            }
68
        }
69
    }
70

            
71
    // parse the return expression
72
    let return_expr_node = return_expr_node.ok_or_else(|| {
73
        EssenceParseError::syntax_error(
74
            "Comprehension missing return expression".to_string(),
75
            Some(node.range()),
76
        )
77
    })?;
78

            
79
    // Use the return expression symbol table which already has induction variables (as Given) and parent as parent
80
    let return_expr = parse_expression(
81
        return_expr_node,
82
        source_code,
83
        root,
84
        Some(builder.return_expr_symboltable()),
85
    )?;
86

            
87
    // Build the comprehension with the return expression and default ACOperatorKind::And
88
    let comprehension = builder.with_return_value(return_expr, Some(ACOperatorKind::And));
89

            
90
    Ok(Expression::Comprehension(
91
        Metadata::new(),
92
        Moo::new(comprehension),
93
    ))
94
}
95

            
96
/// Parse comprehension-style expressions
97
/// - `forAll vars : domain . expr` → `And(Comprehension(...))`
98
/// - `sum vars : domain . expr` → `Sum(Comprehension(...))`
99
pub fn parse_quantifier_or_aggregate_expr(
100
    node: &Node,
101
    source_code: &str,
102
    root: &Node,
103
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
104
) -> Result<Expression, EssenceParseError> {
105
    // Quantifier and aggregate expressions require a symbol table
106
    let symbols_ptr = symbols_ptr.ok_or_else(|| {
107
        EssenceParseError::syntax_error(
108
            "Quantifier and aggregate expressions require a symbol table".to_string(),
109
            Some(node.range()),
110
        )
111
    })?;
112

            
113
    // Create the comprehension builder
114
    let mut builder = ComprehensionBuilder::new(symbols_ptr.clone());
115

            
116
    // First pass: collect domain/collection, variables
117
    let mut domain = None;
118
    let mut collection_node = None;
119
    let mut variables = vec![];
120

            
121
    for child in named_children(node) {
122
        match child.kind() {
123
            "identifier" => {
124
                let var_name_str = &source_code[child.start_byte()..child.end_byte()];
125
                let var_name = Name::user(var_name_str);
126
                variables.push(var_name);
127
            }
128
            "domain" => {
129
                domain = Some(parse_domain(child, source_code, Some(symbols_ptr.clone()))?);
130
            }
131
            "set_literal" | "matrix" | "tuple" | "record" => {
132
                // Store the collection node to parse later
133
                collection_node = Some(child);
134
            }
135
            _ => continue,
136
        }
137
    }
138

            
139
    // We need either a domain or a collection
140
    if domain.is_none() && collection_node.is_none() {
141
        return Err(EssenceParseError::syntax_error(
142
            "Quantifier and aggregate expressions require a domain or collection".to_string(),
143
            Some(node.range()),
144
        ));
145
    }
146

            
147
    if variables.is_empty() {
148
        return Err(EssenceParseError::syntax_error(
149
            "Quantifier and aggregate expressions require variables".to_string(),
150
            Some(node.range()),
151
        ));
152
    }
153

            
154
    // Get the operator type
155
    let operator_node = field!(node, "operator");
156
    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
157

            
158
    let (ac_operator_kind, wrapper) = match operator_str {
159
        "forAll" => (ACOperatorKind::And, "And"),
160
        "exists" => (ACOperatorKind::Or, "Or"),
161
        "sum" => (ACOperatorKind::Sum, "Sum"),
162
        "min" => (ACOperatorKind::Sum, "Min"), // AC operator doesn't matter for non-boolean aggregates
163
        "max" => (ACOperatorKind::Sum, "Max"),
164
        _ => {
165
            return Err(EssenceParseError::syntax_error(
166
                format!("Unknown operator: {}", operator_str),
167
                Some(operator_node.range()),
168
            ));
169
        }
170
    };
171

            
172
    // Add variables as generators
173
    if let Some(dom) = domain {
174
        for var_name in variables {
175
            let decl = DeclarationPtr::new_var(var_name, dom.clone());
176
            builder = builder.generator(decl);
177
        }
178
    } else if let Some(_coll_node) = collection_node {
179
        // TODO: support collection domains
180
        return Err(EssenceParseError::syntax_error(
181
            "Collection domains in quantifier and aggregate expressions not yet supported"
182
                .to_string(),
183
            Some(node.range()),
184
        ));
185
    }
186

            
187
    // Parse the expression (after variables are in the symbol table)
188
    let expression_node = node.child_by_field_name("expression").ok_or_else(|| {
189
        EssenceParseError::syntax_error(
190
            "Quantifier or aggregate expression missing return expression".to_string(),
191
            Some(node.range()),
192
        )
193
    })?;
194
    let expression = parse_expression(
195
        expression_node,
196
        source_code,
197
        root,
198
        Some(builder.return_expr_symboltable()),
199
    )?;
200

            
201
    // Build the comprehension
202
    let comprehension = builder.with_return_value(expression, Some(ac_operator_kind));
203
    let wrapped_comprehension = Expression::Comprehension(Metadata::new(), Moo::new(comprehension));
204

            
205
    // Wrap in the appropriate expression type
206
    match wrapper {
207
        "And" => Ok(Expression::And(
208
            Metadata::new(),
209
            Moo::new(wrapped_comprehension),
210
        )),
211
        "Or" => Ok(Expression::Or(
212
            Metadata::new(),
213
            Moo::new(wrapped_comprehension),
214
        )),
215
        "Sum" => Ok(Expression::Sum(
216
            Metadata::new(),
217
            Moo::new(wrapped_comprehension),
218
        )),
219
        "Min" => Ok(Expression::Min(
220
            Metadata::new(),
221
            Moo::new(wrapped_comprehension),
222
        )),
223
        "Max" => Ok(Expression::Max(
224
            Metadata::new(),
225
            Moo::new(wrapped_comprehension),
226
        )),
227
        _ => unreachable!(),
228
    }
229
}