1
use super::util::named_children;
2
use crate::EssenceParseError;
3
use conjure_cp_core::ast::{
4
    DeclarationPtr, Domain, DomainPtr, IntVal, Name, Range, RecordEntry, Reference, SetAttr,
5
    SymbolTable,
6
};
7
use core::panic;
8
use std::cell::RefCell;
9
use std::rc::Rc;
10
use std::str::FromStr;
11
use tree_sitter::Node;
12

            
13
/// Parse an Essence variable domain into its Conjure AST representation.
14
pub fn parse_domain(
15
    domain: Node,
16
    source_code: &str,
17
    symbols: Option<Rc<RefCell<SymbolTable>>>,
18
) -> Result<DomainPtr, EssenceParseError> {
19
    match domain.kind() {
20
        "domain" => parse_domain(
21
            domain.child(0).expect("No domain found"),
22
            source_code,
23
            symbols,
24
        ),
25
        "bool_domain" => Ok(Domain::bool()),
26
        "int_domain" => Ok(parse_int_domain(domain, source_code, &symbols)),
27
        "identifier" => {
28
            let decl = get_declaration_ptr_from_identifier(domain, source_code, &symbols)?;
29
            let dom = Domain::reference(decl).ok_or(EssenceParseError::syntax_error(
30
                format!(
31
                    "'{}' is not a valid domain declaration",
32
                    &source_code[domain.start_byte()..domain.end_byte()]
33
                ),
34
                Some(domain.range()),
35
            ))?;
36
            Ok(dom)
37
        }
38
        "tuple_domain" => parse_tuple_domain(domain, source_code, symbols),
39
        "matrix_domain" => parse_matrix_domain(domain, source_code, symbols),
40
        "record_domain" => parse_record_domain(domain, source_code, symbols),
41
        "set_domain" => parse_set_domain(domain, source_code, symbols),
42
        _ => panic!("{} is not a supported domain type", domain.kind()),
43
    }
44
}
45

            
46
fn get_declaration_ptr_from_identifier(
47
    identifier: Node,
48
    source_code: &str,
49
    symbols_ptr: &Option<Rc<RefCell<SymbolTable>>>,
50
) -> Result<DeclarationPtr, EssenceParseError> {
51
    let name = Name::user(&source_code[identifier.start_byte()..identifier.end_byte()]);
52
    let decl = symbols_ptr
53
        .as_ref()
54
        .ok_or(EssenceParseError::syntax_error(
55
            "context needed to resolve identifier".to_string(),
56
            Some(identifier.range()),
57
        ))?
58
        .borrow()
59
        .lookup(&name)
60
        .ok_or(EssenceParseError::syntax_error(
61
            format!("'{name}' is not defined"),
62
            Some(identifier.range()),
63
        ))?;
64
    Ok(decl)
65
}
66

            
67
/// Parse an integer domain. Can be a single integer or a range.
68
fn parse_int_domain(
69
    int_domain: Node,
70
    source_code: &str,
71
    symbols_ptr: &Option<Rc<RefCell<SymbolTable>>>,
72
) -> DomainPtr {
73
    if int_domain.child_count() == 1 {
74
        return Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)]);
75
    }
76
    let mut ranges: Vec<Range<i32>> = Vec::new();
77
    let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
78
    let range_list = int_domain
79
        .child_by_field_name("ranges")
80
        .expect("No range list found for int domain");
81
    for domain_component in named_children(&range_list) {
82
        match domain_component.kind() {
83
            "arithmetic_expr" => {
84
                let text = &source_code[domain_component.start_byte()..domain_component.end_byte()];
85
                // Try parsing as a literal integer first
86
                if let Ok(integer) = text.parse::<i32>() {
87
                    ranges.push(Range::Single(integer));
88
                    continue;
89
                }
90
                // Otherwise, treat as a reference
91
                let decl =
92
                    get_declaration_ptr_from_identifier(domain_component, source_code, symbols_ptr);
93
                if let Ok(decl) = decl {
94
                    ranges_unresolved.push(Range::Single(IntVal::Reference(Reference::new(decl))));
95
                } else {
96
                    panic!("'{}' is not a valid integer", text);
97
                }
98
            }
99
            "int_range" => {
100
                let lower_bound: Option<Result<i32, DeclarationPtr>> =
101
                    match domain_component.child_by_field_name("lower") {
102
                        Some(lower_node) => {
103
                            // Try parsing as a literal integer first
104
                            let text = &source_code[lower_node.start_byte()..lower_node.end_byte()];
105
                            if let Ok(integer) = text.parse::<i32>() {
106
                                Some(Ok(integer))
107
                            } else {
108
                                let decl = get_declaration_ptr_from_identifier(
109
                                    lower_node,
110
                                    source_code,
111
                                    symbols_ptr,
112
                                );
113
                                if let Ok(decl) = decl {
114
                                    Some(Err(decl))
115
                                } else {
116
                                    panic!("'{}' is not a valid integer", text);
117
                                }
118
                            }
119
                        }
120
                        None => None,
121
                    };
122
                let upper_bound: Option<Result<i32, DeclarationPtr>> =
123
                    match domain_component.child_by_field_name("upper") {
124
                        Some(upper_node) => {
125
                            // Try parsing as a literal integer first
126
                            let text = &source_code[upper_node.start_byte()..upper_node.end_byte()];
127
                            if let Ok(integer) = text.parse::<i32>() {
128
                                Some(Ok(integer))
129
                            } else {
130
                                let decl = get_declaration_ptr_from_identifier(
131
                                    upper_node,
132
                                    source_code,
133
                                    symbols_ptr,
134
                                );
135
                                if let Ok(decl) = decl {
136
                                    Some(Err(decl))
137
                                } else {
138
                                    panic!("'{}' is not a valid integer", text);
139
                                }
140
                            }
141
                        }
142
                        None => None,
143
                    };
144

            
145
                match (lower_bound, upper_bound) {
146
                    (Some(Ok(lower)), Some(Ok(upper))) => ranges.push(Range::Bounded(lower, upper)),
147
                    (Some(Ok(lower)), Some(Err(decl))) => {
148
                        ranges_unresolved.push(Range::Bounded(
149
                            IntVal::Const(lower),
150
                            IntVal::Reference(Reference::new(decl)),
151
                        ));
152
                    }
153
                    (Some(Err(decl)), Some(Ok(upper))) => {
154
                        ranges_unresolved.push(Range::Bounded(
155
                            IntVal::Reference(Reference::new(decl)),
156
                            IntVal::Const(upper),
157
                        ));
158
                    }
159
                    (Some(Err(decl_lower)), Some(Err(decl_upper))) => {
160
                        ranges_unresolved.push(Range::Bounded(
161
                            IntVal::Reference(Reference::new(decl_lower)),
162
                            IntVal::Reference(Reference::new(decl_upper)),
163
                        ));
164
                    }
165
                    (Some(Ok(lower)), None) => {
166
                        ranges.push(Range::UnboundedR(lower));
167
                    }
168
                    (Some(Err(decl)), None) => {
169
                        ranges_unresolved
170
                            .push(Range::UnboundedR(IntVal::Reference(Reference::new(decl))));
171
                    }
172
                    (None, Some(Ok(upper))) => {
173
                        ranges.push(Range::UnboundedL(upper));
174
                    }
175
                    (None, Some(Err(decl))) => {
176
                        ranges_unresolved
177
                            .push(Range::UnboundedL(IntVal::Reference(Reference::new(decl))));
178
                    }
179
                    (None, None) => {
180
                        ranges.push(Range::Unbounded);
181
                    }
182
                }
183
            }
184
            _ => panic!("unsupported int range type"),
185
        }
186
    }
187

            
188
    if !ranges_unresolved.is_empty() {
189
        for range in ranges {
190
            match range {
191
                Range::Single(i) => ranges_unresolved.push(Range::Single(IntVal::Const(i))),
192
                Range::Bounded(l, u) => {
193
                    ranges_unresolved.push(Range::Bounded(IntVal::Const(l), IntVal::Const(u)))
194
                }
195
                Range::UnboundedL(l) => ranges_unresolved.push(Range::UnboundedL(IntVal::Const(l))),
196
                Range::UnboundedR(u) => ranges_unresolved.push(Range::UnboundedR(IntVal::Const(u))),
197
                Range::Unbounded => ranges_unresolved.push(Range::Unbounded),
198
            }
199
        }
200
        return Domain::int(ranges_unresolved);
201
    }
202

            
203
    Domain::int(ranges)
204
}
205

            
206
fn parse_tuple_domain(
207
    tuple_domain: Node,
208
    source_code: &str,
209
    symbols: Option<Rc<RefCell<SymbolTable>>>,
210
) -> Result<DomainPtr, EssenceParseError> {
211
    let mut domains: Vec<DomainPtr> = Vec::new();
212
    for domain in named_children(&tuple_domain) {
213
        domains.push(parse_domain(domain, source_code, symbols.clone())?);
214
    }
215
    Ok(Domain::tuple(domains))
216
}
217

            
218
fn parse_matrix_domain(
219
    matrix_domain: Node,
220
    source_code: &str,
221
    symbols: Option<Rc<RefCell<SymbolTable>>>,
222
) -> Result<DomainPtr, EssenceParseError> {
223
    let mut domains: Vec<DomainPtr> = Vec::new();
224
    let index_domain_list = matrix_domain
225
        .child_by_field_name("index_domain_list")
226
        .expect("No index domains found for matrix domain");
227
    for domain in named_children(&index_domain_list) {
228
        domains.push(parse_domain(domain, source_code, symbols.clone())?);
229
    }
230
    let value_domain = parse_domain(
231
        matrix_domain.child_by_field_name("value_domain").ok_or(
232
            EssenceParseError::syntax_error(
233
                "Expected a value domain".to_string(),
234
                Some(matrix_domain.range()),
235
            ),
236
        )?,
237
        source_code,
238
        symbols,
239
    )?;
240
    Ok(Domain::matrix(value_domain, domains))
241
}
242

            
243
fn parse_record_domain(
244
    record_domain: Node,
245
    source_code: &str,
246
    symbols: Option<Rc<RefCell<SymbolTable>>>,
247
) -> Result<DomainPtr, EssenceParseError> {
248
    let mut record_entries: Vec<RecordEntry> = Vec::new();
249
    for record_entry in named_children(&record_domain) {
250
        let name_node = record_entry
251
            .child_by_field_name("name")
252
            .expect("No name found for record entry");
253
        let name = Name::user(&source_code[name_node.start_byte()..name_node.end_byte()]);
254
        let domain_node = record_entry
255
            .child_by_field_name("domain")
256
            .expect("No domain found for record entry");
257
        let domain = parse_domain(domain_node, source_code, symbols.clone())?;
258
        record_entries.push(RecordEntry { name, domain });
259
    }
260
    Ok(Domain::record(record_entries))
261
}
262

            
263
pub fn parse_set_domain(
264
    set_domain: Node,
265
    source_code: &str,
266
    symbols: Option<Rc<RefCell<SymbolTable>>>,
267
) -> Result<DomainPtr, EssenceParseError> {
268
    let mut set_attribute: Option<SetAttr> = None;
269
    let mut value_domain: Option<DomainPtr> = None;
270

            
271
    for child in named_children(&set_domain) {
272
        match child.kind() {
273
            "set_attributes" => {
274
                // Check if we have both minSize and maxSize (minMax case)
275
                let min_value_node = child.child_by_field_name("min_value");
276
                let max_value_node = child.child_by_field_name("max_value");
277
                let size_value_node = child.child_by_field_name("size_value");
278

            
279
                if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
280
                    // MinMax case
281
                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
282
                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
283

            
284
                    let min_val = i32::from_str(min_str).map_err(|_| {
285
                        EssenceParseError::syntax_error(
286
                            format!("Invalid integer value for minSize: {}", min_str),
287
                            Some(min_node.range()),
288
                        )
289
                    })?;
290

            
291
                    let max_val = i32::from_str(max_str).map_err(|_| {
292
                        EssenceParseError::syntax_error(
293
                            format!("Invalid integer value for maxSize: {}", max_str),
294
                            Some(max_node.range()),
295
                        )
296
                    })?;
297

            
298
                    set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
299
                } else if let Some(size_node) = size_value_node {
300
                    // Size case
301
                    let size_str = &source_code[size_node.start_byte()..size_node.end_byte()];
302
                    let size_val = i32::from_str(size_str).map_err(|_| {
303
                        EssenceParseError::syntax_error(
304
                            format!("Invalid integer value for size: {}", size_str),
305
                            Some(size_node.range()),
306
                        )
307
                    })?;
308
                    set_attribute = Some(SetAttr::new_size(size_val));
309
                } else if let Some(min_node) = min_value_node {
310
                    // MinSize only case
311
                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
312
                    let min_val = i32::from_str(min_str).map_err(|_| {
313
                        EssenceParseError::syntax_error(
314
                            format!("Invalid integer value for minSize: {}", min_str),
315
                            Some(min_node.range()),
316
                        )
317
                    })?;
318
                    set_attribute = Some(SetAttr::new_min_size(min_val));
319
                } else if let Some(max_node) = max_value_node {
320
                    // MaxSize only case
321
                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
322
                    let max_val = i32::from_str(max_str).map_err(|_| {
323
                        EssenceParseError::syntax_error(
324
                            format!("Invalid integer value for maxSize: {}", max_str),
325
                            Some(max_node.range()),
326
                        )
327
                    })?;
328
                    set_attribute = Some(SetAttr::new_max_size(max_val));
329
                }
330
            }
331
            "domain" => {
332
                value_domain = Some(parse_domain(child, source_code, symbols.clone())?);
333
            }
334
            _ => {
335
                return Err(EssenceParseError::syntax_error(
336
                    format!("Unrecognized set domain child kind: {}", child.kind()),
337
                    Some(child.range()),
338
                ));
339
            }
340
        }
341
    }
342

            
343
    if let Some(domain) = value_domain {
344
        Ok(Domain::set(set_attribute.unwrap_or_default(), domain))
345
    } else {
346
        Err(EssenceParseError::syntax_error(
347
            "Set domain must have a value domain".to_string(),
348
            Some(set_domain.range()),
349
        ))
350
    }
351
}