1
use super::util::named_children;
2
use crate::EssenceParseError;
3
use conjure_cp_core::ast::{
4
    DeclarationPtr, Domain, DomainPtr, IntVal, Name, Range, RecordEntry, Reference, SetAttr,
5
    SymbolTablePtr,
6
};
7
use core::panic;
8
use std::str::FromStr;
9
use tree_sitter::Node;
10

            
11
/// Parse an Essence variable domain into its Conjure AST representation.
12
630
pub fn parse_domain(
13
630
    domain: Node,
14
630
    source_code: &str,
15
630
    symbols: Option<SymbolTablePtr>,
16
630
) -> Result<DomainPtr, EssenceParseError> {
17
630
    match domain.kind() {
18
630
        "domain" => parse_domain(
19
313
            domain.child(0).expect("No domain found"),
20
313
            source_code,
21
313
            symbols,
22
        ),
23
317
        "bool_domain" => Ok(Domain::bool()),
24
164
        "int_domain" => Ok(parse_int_domain(domain, source_code, &symbols)),
25
        "identifier" => {
26
            let decl = get_declaration_ptr_from_identifier(domain, source_code, &symbols)?;
27
            let dom = Domain::reference(decl).ok_or(EssenceParseError::syntax_error(
28
                format!(
29
                    "'{}' is not a valid domain declaration",
30
                    &source_code[domain.start_byte()..domain.end_byte()]
31
                ),
32
                Some(domain.range()),
33
            ))?;
34
            Ok(dom)
35
        }
36
        "tuple_domain" => parse_tuple_domain(domain, source_code, symbols),
37
        "matrix_domain" => parse_matrix_domain(domain, source_code, symbols),
38
        "record_domain" => parse_record_domain(domain, source_code, symbols),
39
        "set_domain" => parse_set_domain(domain, source_code, symbols),
40
        _ => panic!("{} is not a supported domain type", domain.kind()),
41
    }
42
630
}
43

            
44
fn get_declaration_ptr_from_identifier(
45
    identifier: Node,
46
    source_code: &str,
47
    symbols_ptr: &Option<SymbolTablePtr>,
48
) -> Result<DeclarationPtr, EssenceParseError> {
49
    let name = Name::user(&source_code[identifier.start_byte()..identifier.end_byte()]);
50
    let decl = symbols_ptr
51
        .as_ref()
52
        .ok_or(EssenceParseError::syntax_error(
53
            "context needed to resolve identifier".to_string(),
54
            Some(identifier.range()),
55
        ))?
56
        .read()
57
        .lookup(&name)
58
        .ok_or(EssenceParseError::syntax_error(
59
            format!("'{name}' is not defined"),
60
            Some(identifier.range()),
61
        ))?;
62
    Ok(decl)
63
}
64

            
65
/// Parse an integer domain. Can be a single integer or a range.
66
164
fn parse_int_domain(
67
164
    int_domain: Node,
68
164
    source_code: &str,
69
164
    symbols_ptr: &Option<SymbolTablePtr>,
70
164
) -> DomainPtr {
71
164
    if int_domain.child_count() == 1 {
72
        return Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)]);
73
164
    }
74
164
    let mut ranges: Vec<Range<i32>> = Vec::new();
75
164
    let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
76
164
    let range_list = int_domain
77
164
        .child_by_field_name("ranges")
78
164
        .expect("No range list found for int domain");
79
170
    for domain_component in named_children(&range_list) {
80
170
        match domain_component.kind() {
81
170
            "atom" => {
82
9
                let text = &source_code[domain_component.start_byte()..domain_component.end_byte()];
83
                // Try parsing as a literal integer first
84
9
                if let Ok(integer) = text.parse::<i32>() {
85
9
                    ranges.push(Range::Single(integer));
86
9
                    continue;
87
                }
88
                // Otherwise, treat as a reference
89
                let decl =
90
                    get_declaration_ptr_from_identifier(domain_component, source_code, symbols_ptr);
91
                if let Ok(decl) = decl {
92
                    ranges_unresolved.push(Range::Single(IntVal::Reference(Reference::new(decl))));
93
                } else {
94
                    panic!("'{}' is not a valid integer", text);
95
                }
96
            }
97
161
            "int_range" => {
98
161
                let lower_bound: Option<Result<i32, DeclarationPtr>> =
99
161
                    match domain_component.child_by_field_name("lower") {
100
161
                        Some(lower_node) => {
101
                            // Try parsing as a literal integer first
102
161
                            let text = &source_code[lower_node.start_byte()..lower_node.end_byte()];
103
161
                            if let Ok(integer) = text.parse::<i32>() {
104
161
                                Some(Ok(integer))
105
                            } else {
106
                                let decl = get_declaration_ptr_from_identifier(
107
                                    lower_node,
108
                                    source_code,
109
                                    symbols_ptr,
110
                                );
111
                                if let Ok(decl) = decl {
112
                                    Some(Err(decl))
113
                                } else {
114
                                    panic!("'{}' is not a valid integer", text);
115
                                }
116
                            }
117
                        }
118
                        None => None,
119
                    };
120
161
                let upper_bound: Option<Result<i32, DeclarationPtr>> =
121
161
                    match domain_component.child_by_field_name("upper") {
122
161
                        Some(upper_node) => {
123
                            // Try parsing as a literal integer first
124
161
                            let text = &source_code[upper_node.start_byte()..upper_node.end_byte()];
125
161
                            if let Ok(integer) = text.parse::<i32>() {
126
161
                                Some(Ok(integer))
127
                            } else {
128
                                let decl = get_declaration_ptr_from_identifier(
129
                                    upper_node,
130
                                    source_code,
131
                                    symbols_ptr,
132
                                );
133
                                if let Ok(decl) = decl {
134
                                    Some(Err(decl))
135
                                } else {
136
                                    panic!("'{}' is not a valid integer", text);
137
                                }
138
                            }
139
                        }
140
                        None => None,
141
                    };
142

            
143
161
                match (lower_bound, upper_bound) {
144
161
                    (Some(Ok(lower)), Some(Ok(upper))) => ranges.push(Range::Bounded(lower, upper)),
145
                    (Some(Ok(lower)), Some(Err(decl))) => {
146
                        ranges_unresolved.push(Range::Bounded(
147
                            IntVal::Const(lower),
148
                            IntVal::Reference(Reference::new(decl)),
149
                        ));
150
                    }
151
                    (Some(Err(decl)), Some(Ok(upper))) => {
152
                        ranges_unresolved.push(Range::Bounded(
153
                            IntVal::Reference(Reference::new(decl)),
154
                            IntVal::Const(upper),
155
                        ));
156
                    }
157
                    (Some(Err(decl_lower)), Some(Err(decl_upper))) => {
158
                        ranges_unresolved.push(Range::Bounded(
159
                            IntVal::Reference(Reference::new(decl_lower)),
160
                            IntVal::Reference(Reference::new(decl_upper)),
161
                        ));
162
                    }
163
                    (Some(Ok(lower)), None) => {
164
                        ranges.push(Range::UnboundedR(lower));
165
                    }
166
                    (Some(Err(decl)), None) => {
167
                        ranges_unresolved
168
                            .push(Range::UnboundedR(IntVal::Reference(Reference::new(decl))));
169
                    }
170
                    (None, Some(Ok(upper))) => {
171
                        ranges.push(Range::UnboundedL(upper));
172
                    }
173
                    (None, Some(Err(decl))) => {
174
                        ranges_unresolved
175
                            .push(Range::UnboundedL(IntVal::Reference(Reference::new(decl))));
176
                    }
177
                    (None, None) => {
178
                        ranges.push(Range::Unbounded);
179
                    }
180
                }
181
            }
182
            _ => panic!("unsupported int range type"),
183
        }
184
    }
185

            
186
164
    if !ranges_unresolved.is_empty() {
187
        for range in ranges {
188
            match range {
189
                Range::Single(i) => ranges_unresolved.push(Range::Single(IntVal::Const(i))),
190
                Range::Bounded(l, u) => {
191
                    ranges_unresolved.push(Range::Bounded(IntVal::Const(l), IntVal::Const(u)))
192
                }
193
                Range::UnboundedL(l) => ranges_unresolved.push(Range::UnboundedL(IntVal::Const(l))),
194
                Range::UnboundedR(u) => ranges_unresolved.push(Range::UnboundedR(IntVal::Const(u))),
195
                Range::Unbounded => ranges_unresolved.push(Range::Unbounded),
196
            }
197
        }
198
        return Domain::int(ranges_unresolved);
199
164
    }
200

            
201
164
    Domain::int(ranges)
202
164
}
203

            
204
fn parse_tuple_domain(
205
    tuple_domain: Node,
206
    source_code: &str,
207
    symbols: Option<SymbolTablePtr>,
208
) -> Result<DomainPtr, EssenceParseError> {
209
    let mut domains: Vec<DomainPtr> = Vec::new();
210
    for domain in named_children(&tuple_domain) {
211
        domains.push(parse_domain(domain, source_code, symbols.clone())?);
212
    }
213
    Ok(Domain::tuple(domains))
214
}
215

            
216
fn parse_matrix_domain(
217
    matrix_domain: Node,
218
    source_code: &str,
219
    symbols: Option<SymbolTablePtr>,
220
) -> Result<DomainPtr, EssenceParseError> {
221
    let mut domains: Vec<DomainPtr> = Vec::new();
222
    let index_domain_list = matrix_domain
223
        .child_by_field_name("index_domain_list")
224
        .expect("No index domains found for matrix domain");
225
    for domain in named_children(&index_domain_list) {
226
        domains.push(parse_domain(domain, source_code, symbols.clone())?);
227
    }
228
    let value_domain = parse_domain(
229
        matrix_domain.child_by_field_name("value_domain").ok_or(
230
            EssenceParseError::syntax_error(
231
                "Expected a value domain".to_string(),
232
                Some(matrix_domain.range()),
233
            ),
234
        )?,
235
        source_code,
236
        symbols,
237
    )?;
238
    Ok(Domain::matrix(value_domain, domains))
239
}
240

            
241
fn parse_record_domain(
242
    record_domain: Node,
243
    source_code: &str,
244
    symbols: Option<SymbolTablePtr>,
245
) -> Result<DomainPtr, EssenceParseError> {
246
    let mut record_entries: Vec<RecordEntry> = Vec::new();
247
    for record_entry in named_children(&record_domain) {
248
        let name_node = record_entry
249
            .child_by_field_name("name")
250
            .expect("No name found for record entry");
251
        let name = Name::user(&source_code[name_node.start_byte()..name_node.end_byte()]);
252
        let domain_node = record_entry
253
            .child_by_field_name("domain")
254
            .expect("No domain found for record entry");
255
        let domain = parse_domain(domain_node, source_code, symbols.clone())?;
256
        record_entries.push(RecordEntry { name, domain });
257
    }
258
    Ok(Domain::record(record_entries))
259
}
260

            
261
pub fn parse_set_domain(
262
    set_domain: Node,
263
    source_code: &str,
264
    symbols: Option<SymbolTablePtr>,
265
) -> Result<DomainPtr, EssenceParseError> {
266
    let mut set_attribute: Option<SetAttr> = None;
267
    let mut value_domain: Option<DomainPtr> = None;
268

            
269
    for child in named_children(&set_domain) {
270
        match child.kind() {
271
            "set_attributes" => {
272
                // Check if we have both minSize and maxSize (minMax case)
273
                let min_value_node = child.child_by_field_name("min_value");
274
                let max_value_node = child.child_by_field_name("max_value");
275
                let size_value_node = child.child_by_field_name("size_value");
276

            
277
                if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
278
                    // MinMax case
279
                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
280
                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
281

            
282
                    let min_val = i32::from_str(min_str).map_err(|_| {
283
                        EssenceParseError::syntax_error(
284
                            format!("Invalid integer value for minSize: {}", min_str),
285
                            Some(min_node.range()),
286
                        )
287
                    })?;
288

            
289
                    let max_val = i32::from_str(max_str).map_err(|_| {
290
                        EssenceParseError::syntax_error(
291
                            format!("Invalid integer value for maxSize: {}", max_str),
292
                            Some(max_node.range()),
293
                        )
294
                    })?;
295

            
296
                    set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
297
                } else if let Some(size_node) = size_value_node {
298
                    // Size case
299
                    let size_str = &source_code[size_node.start_byte()..size_node.end_byte()];
300
                    let size_val = i32::from_str(size_str).map_err(|_| {
301
                        EssenceParseError::syntax_error(
302
                            format!("Invalid integer value for size: {}", size_str),
303
                            Some(size_node.range()),
304
                        )
305
                    })?;
306
                    set_attribute = Some(SetAttr::new_size(size_val));
307
                } else if let Some(min_node) = min_value_node {
308
                    // MinSize only case
309
                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
310
                    let min_val = i32::from_str(min_str).map_err(|_| {
311
                        EssenceParseError::syntax_error(
312
                            format!("Invalid integer value for minSize: {}", min_str),
313
                            Some(min_node.range()),
314
                        )
315
                    })?;
316
                    set_attribute = Some(SetAttr::new_min_size(min_val));
317
                } else if let Some(max_node) = max_value_node {
318
                    // MaxSize only case
319
                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
320
                    let max_val = i32::from_str(max_str).map_err(|_| {
321
                        EssenceParseError::syntax_error(
322
                            format!("Invalid integer value for maxSize: {}", max_str),
323
                            Some(max_node.range()),
324
                        )
325
                    })?;
326
                    set_attribute = Some(SetAttr::new_max_size(max_val));
327
                }
328
            }
329
            "domain" => {
330
                value_domain = Some(parse_domain(child, source_code, symbols.clone())?);
331
            }
332
            _ => {
333
                return Err(EssenceParseError::syntax_error(
334
                    format!("Unrecognized set domain child kind: {}", child.kind()),
335
                    Some(child.range()),
336
                ));
337
            }
338
        }
339
    }
340

            
341
    if let Some(domain) = value_domain {
342
        Ok(Domain::set(set_attribute.unwrap_or_default(), domain))
343
    } else {
344
        Err(EssenceParseError::syntax_error(
345
            "Set domain must have a value domain".to_string(),
346
            Some(set_domain.range()),
347
        ))
348
    }
349
}