1
use super::util::named_children;
2
use crate::errors::{FatalParseError, RecoverableParseError};
3
use conjure_cp_core::ast::{
4
    DeclarationPtr, Domain, DomainPtr, IntVal, Name, Range, RecordEntry, Reference, SetAttr,
5
    SymbolTablePtr,
6
};
7
use core::panic;
8
use std::str::FromStr;
9
use tree_sitter::Node;
10

            
11
/// Parse an Essence variable domain into its Conjure AST representation.
12
886
pub fn parse_domain(
13
886
    domain: Node,
14
1784
    source_code: &str,
15
1784
    symbols: Option<SymbolTablePtr>,
16
1784
    errors: &mut Vec<RecoverableParseError>,
17
1784
) -> Result<DomainPtr, FatalParseError> {
18
1784
    match domain.kind() {
19
1784
        "domain" => parse_domain(
20
892
            domain.child(0).expect("No domain found"),
21
589
            source_code,
22
589
            symbols,
23
589
            errors,
24
148
        ),
25
593
        "bool_domain" => Ok(Domain::bool()),
26
440
        "int_domain" => Ok(parse_int_domain(domain, source_code, &symbols, errors)),
27
148
        "identifier" => {
28
148
            let decl = get_declaration_ptr_from_identifier(domain, source_code, &symbols, errors)?;
29
            let dom = Domain::reference(decl).ok_or(FatalParseError::syntax_error(
30
303
                format!(
31
292
                    "'{}' is not a valid domain declaration",
32
292
                    &source_code[domain.start_byte()..domain.end_byte()]
33
292
                ),
34
292
                Some(domain.range()),
35
292
            ))?;
36
292
            Ok(dom)
37
292
        }
38
292
        "tuple_domain" => parse_tuple_domain(domain, source_code, symbols, errors),
39
        "matrix_domain" => parse_matrix_domain(domain, source_code, symbols, errors),
40
11
        "record_domain" => parse_record_domain(domain, source_code, symbols, errors),
41
11
        "set_domain" => parse_set_domain(domain, source_code, symbols, errors),
42
        _ => panic!("{} is not a supported domain type", domain.kind()),
43
    }
44
897
}
45

            
46
88
fn get_declaration_ptr_from_identifier(
47
88
    identifier: Node,
48
88
    source_code: &str,
49
88
    symbols_ptr: &Option<SymbolTablePtr>,
50
88
    _errors: &mut Vec<RecoverableParseError>,
51
88
) -> Result<DeclarationPtr, FatalParseError> {
52
88
    let name = Name::user(&source_code[identifier.start_byte()..identifier.end_byte()]);
53
88
    let decl = symbols_ptr
54
99
        .as_ref()
55
99
        .ok_or(FatalParseError::syntax_error(
56
99
            "context needed to resolve identifier".to_string(),
57
99
            Some(identifier.range()),
58
11
        ))?
59
99
        .read()
60
99
        .lookup(&name)
61
99
        .ok_or(FatalParseError::syntax_error(
62
99
            format!("'{name}' is not defined"),
63
88
            Some(identifier.range()),
64
        ))?;
65
88
    Ok(decl)
66
88
}
67

            
68
/// Parse an integer domain. Can be a single integer or a range.
69
292
fn parse_int_domain(
70
292
    int_domain: Node,
71
292
    source_code: &str,
72
292
    symbols_ptr: &Option<SymbolTablePtr>,
73
1190
    errors: &mut Vec<RecoverableParseError>,
74
292
) -> DomainPtr {
75
391
    if int_domain.child_count() == 1 {
76
99
        return Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)]);
77
391
    }
78
391
    let mut ranges: Vec<Range<i32>> = Vec::new();
79
391
    let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
80
391
    let range_list = int_domain
81
391
        .child_by_field_name("ranges")
82
391
        .expect("No range list found for int domain");
83
463
    for domain_component in named_children(&range_list) {
84
463
        match domain_component.kind() {
85
463
            "atom" => {
86
9
                let text = &source_code[domain_component.start_byte()..domain_component.end_byte()];
87
                // Try parsing as a literal integer first
88
108
                if let Ok(integer) = text.parse::<i32>() {
89
108
                    ranges.push(Range::Single(integer));
90
108
                    continue;
91
                }
92
                // Otherwise, treat as a reference
93
                let decl = get_declaration_ptr_from_identifier(
94
                    domain_component,
95
                    source_code,
96
                    symbols_ptr,
97
                    errors,
98
                );
99
99
                if let Ok(decl) = decl {
100
                    ranges_unresolved.push(Range::Single(IntVal::Reference(Reference::new(decl))));
101
                } else {
102
292
                    panic!("'{}' is not a valid integer", text);
103
292
                }
104
292
            }
105
647
            "int_range" => {
106
647
                let lower_bound: Option<Result<i32, DeclarationPtr>> =
107
355
                    match domain_component.child_by_field_name("lower") {
108
355
                        Some(lower_node) => {
109
                            // Try parsing as a literal integer first
110
355
                            let text = &source_code[lower_node.start_byte()..lower_node.end_byte()];
111
647
                            if let Ok(integer) = text.parse::<i32>() {
112
581
                                Some(Ok(integer))
113
292
                            } else {
114
66
                                let decl = get_declaration_ptr_from_identifier(
115
430
                                    lower_node,
116
430
                                    source_code,
117
430
                                    symbols_ptr,
118
75
                                    errors,
119
                                );
120
66
                                if let Ok(decl) = decl {
121
66
                                    Some(Err(decl))
122
9
                                } else {
123
                                    panic!("'{}' is not a valid integer", text);
124
9
                                }
125
9
                            }
126
                        }
127
355
                        None => None,
128
355
                    };
129
710
                let upper_bound: Option<Result<i32, DeclarationPtr>> =
130
710
                    match domain_component.child_by_field_name("upper") {
131
710
                        Some(upper_node) => {
132
                            // Try parsing as a literal integer first
133
355
                            let text = &source_code[upper_node.start_byte()..upper_node.end_byte()];
134
355
                            if let Ok(integer) = text.parse::<i32>() {
135
333
                                Some(Ok(integer))
136
                            } else {
137
377
                                let decl = get_declaration_ptr_from_identifier(
138
377
                                    upper_node,
139
377
                                    source_code,
140
377
                                    symbols_ptr,
141
22
                                    errors,
142
                                );
143
22
                                if let Ok(decl) = decl {
144
22
                                    Some(Err(decl))
145
                                } else {
146
                                    panic!("'{}' is not a valid integer", text);
147
355
                                }
148
355
                            }
149
                        }
150
355
                        None => None,
151
234
                    };
152

            
153
355
                match (lower_bound, upper_bound) {
154
267
                    (Some(Ok(lower)), Some(Ok(upper))) => ranges.push(Range::Bounded(lower, upper)),
155
22
                    (Some(Ok(lower)), Some(Err(decl))) => {
156
22
                        ranges_unresolved.push(Range::Bounded(
157
22
                            IntVal::Const(lower),
158
22
                            IntVal::Reference(Reference::new(decl)),
159
256
                        ));
160
143
                    }
161
187
                    (Some(Err(decl)), Some(Ok(upper))) => {
162
187
                        ranges_unresolved.push(Range::Bounded(
163
421
                            IntVal::Reference(Reference::new(decl)),
164
66
                            IntVal::Const(upper),
165
66
                        ));
166
66
                    }
167
                    (Some(Err(decl_lower)), Some(Err(decl_upper))) => {
168
                        ranges_unresolved.push(Range::Bounded(
169
                            IntVal::Reference(Reference::new(decl_lower)),
170
                            IntVal::Reference(Reference::new(decl_upper)),
171
                        ));
172
                    }
173
                    (Some(Ok(lower)), None) => {
174
                        ranges.push(Range::UnboundedR(lower));
175
                    }
176
                    (Some(Err(decl)), None) => {
177
                        ranges_unresolved
178
                            .push(Range::UnboundedR(IntVal::Reference(Reference::new(decl))));
179
                    }
180
                    (None, Some(Ok(upper))) => {
181
                        ranges.push(Range::UnboundedL(upper));
182
                    }
183
                    (None, Some(Err(decl))) => {
184
                        ranges_unresolved
185
                            .push(Range::UnboundedL(IntVal::Reference(Reference::new(decl))));
186
                    }
187
                    (None, None) => {
188
                        ranges.push(Range::Unbounded);
189
                    }
190
                }
191
            }
192
            _ => panic!("unsupported int range type"),
193
        }
194
    }
195

            
196
292
    if !ranges_unresolved.is_empty() {
197
66
        for range in ranges {
198
            match range {
199
292
                Range::Single(i) => ranges_unresolved.push(Range::Single(IntVal::Const(i))),
200
193
                Range::Bounded(l, u) => {
201
193
                    ranges_unresolved.push(Range::Bounded(IntVal::Const(l), IntVal::Const(u)))
202
196
                }
203
9
                Range::UnboundedL(l) => ranges_unresolved.push(Range::UnboundedL(IntVal::Const(l))),
204
234
                Range::UnboundedR(u) => ranges_unresolved.push(Range::UnboundedR(IntVal::Const(u))),
205
                Range::Unbounded => ranges_unresolved.push(Range::Unbounded),
206
            }
207
        }
208
66
        return Domain::int(ranges_unresolved);
209
469
    }
210
193

            
211
419
    Domain::int(ranges)
212
292
}
213

            
214
99
fn parse_tuple_domain(
215
    tuple_domain: Node,
216
292
    source_code: &str,
217
    symbols: Option<SymbolTablePtr>,
218
    errors: &mut Vec<RecoverableParseError>,
219
) -> Result<DomainPtr, FatalParseError> {
220
719
    let mut domains: Vec<DomainPtr> = Vec::new();
221
    for domain in named_children(&tuple_domain) {
222
719
        domains.push(parse_domain(domain, source_code, symbols.clone(), errors)?);
223
675
    }
224
675
    Ok(Domain::tuple(domains))
225
587
}
226
88

            
227
fn parse_matrix_domain(
228
88
    matrix_domain: Node,
229
    source_code: &str,
230
    symbols: Option<SymbolTablePtr>,
231
    errors: &mut Vec<RecoverableParseError>,
232
88
) -> Result<DomainPtr, FatalParseError> {
233
44
    let mut domains: Vec<DomainPtr> = Vec::new();
234
    let index_domain_list = matrix_domain
235
        .child_by_field_name("index_domain_list")
236
44
        .expect("No index domains found for matrix domain");
237
    for domain in named_children(&index_domain_list) {
238
        domains.push(parse_domain(domain, source_code, symbols.clone(), errors)?);
239
44
    }
240
719
    let value_domain = parse_domain(
241
        matrix_domain
242
            .child_by_field_name("value_domain")
243
            .ok_or(FatalParseError::syntax_error(
244
                "Expected a value domain".to_string(),
245
                Some(matrix_domain.range()),
246
            ))?,
247
        source_code,
248
        symbols,
249
        errors,
250
    )?;
251
    Ok(Domain::matrix(value_domain, domains))
252
}
253

            
254
fn parse_record_domain(
255
    record_domain: Node,
256
    source_code: &str,
257
    symbols: Option<SymbolTablePtr>,
258
    errors: &mut Vec<RecoverableParseError>,
259
) -> Result<DomainPtr, FatalParseError> {
260
    let mut record_entries: Vec<RecordEntry> = Vec::new();
261
    for record_entry in named_children(&record_domain) {
262
        let name_node = record_entry
263
            .child_by_field_name("name")
264
            .expect("No name found for record entry");
265
        let name = Name::user(&source_code[name_node.start_byte()..name_node.end_byte()]);
266
        let domain_node = record_entry
267
            .child_by_field_name("domain")
268
            .expect("No domain found for record entry");
269
        let domain = parse_domain(domain_node, source_code, symbols.clone(), errors)?;
270
        record_entries.push(RecordEntry { name, domain });
271
    }
272
    Ok(Domain::record(record_entries))
273
}
274

            
275
pub fn parse_set_domain(
276
    set_domain: Node,
277
    source_code: &str,
278
    symbols: Option<SymbolTablePtr>,
279
    errors: &mut Vec<RecoverableParseError>,
280
) -> Result<DomainPtr, FatalParseError> {
281
    let mut set_attribute: Option<SetAttr> = None;
282
    let mut value_domain: Option<DomainPtr> = None;
283

            
284
    for child in named_children(&set_domain) {
285
        match child.kind() {
286
            "set_attributes" => {
287
                // Check if we have both minSize and maxSize (minMax case)
288
                let min_value_node = child.child_by_field_name("min_value");
289
                let max_value_node = child.child_by_field_name("max_value");
290
                let size_value_node = child.child_by_field_name("size_value");
291

            
292
                if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
293
                    // MinMax case
294
                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
295
                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
296

            
297
                    let min_val = i32::from_str(min_str).map_err(|_| {
298
                        FatalParseError::syntax_error(
299
                            format!("Invalid integer value for minSize: {}", min_str),
300
                            Some(min_node.range()),
301
                        )
302
                    })?;
303

            
304
                    let max_val = i32::from_str(max_str).map_err(|_| {
305
                        FatalParseError::syntax_error(
306
                            format!("Invalid integer value for maxSize: {}", max_str),
307
                            Some(max_node.range()),
308
                        )
309
                    })?;
310

            
311
                    set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
312
                } else if let Some(size_node) = size_value_node {
313
                    // Size case
314
                    let size_str = &source_code[size_node.start_byte()..size_node.end_byte()];
315
                    let size_val = i32::from_str(size_str).map_err(|_| {
316
                        FatalParseError::syntax_error(
317
                            format!("Invalid integer value for size: {}", size_str),
318
                            Some(size_node.range()),
319
                        )
320
                    })?;
321
                    set_attribute = Some(SetAttr::new_size(size_val));
322
                } else if let Some(min_node) = min_value_node {
323
                    // MinSize only case
324
                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
325
                    let min_val = i32::from_str(min_str).map_err(|_| {
326
                        FatalParseError::syntax_error(
327
                            format!("Invalid integer value for minSize: {}", min_str),
328
                            Some(min_node.range()),
329
                        )
330
                    })?;
331
                    set_attribute = Some(SetAttr::new_min_size(min_val));
332
                } else if let Some(max_node) = max_value_node {
333
                    // MaxSize only case
334
                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
335
                    let max_val = i32::from_str(max_str).map_err(|_| {
336
                        FatalParseError::syntax_error(
337
                            format!("Invalid integer value for maxSize: {}", max_str),
338
                            Some(max_node.range()),
339
                        )
340
                    })?;
341
                    set_attribute = Some(SetAttr::new_max_size(max_val));
342
                }
343
            }
344
            "domain" => {
345
                value_domain = Some(parse_domain(child, source_code, symbols.clone(), errors)?);
346
            }
347
            _ => {
348
                return Err(FatalParseError::syntax_error(
349
                    format!("Unrecognized set domain child kind: {}", child.kind()),
350
                    Some(child.range()),
351
                ));
352
            }
353
        }
354
    }
355

            
356
    if let Some(domain) = value_domain {
357
        Ok(Domain::set(set_attribute.unwrap_or_default(), domain))
358
    } else {
359
        Err(FatalParseError::syntax_error(
360
            "Set domain must have a value domain".to_string(),
361
            Some(set_domain.range()),
362
        ))
363
    }
364
}