Skip to main content

conjure_cp_essence_parser/parser/
domain.rs

1use super::util::named_children;
2use crate::errors::{FatalParseError, RecoverableParseError};
3use conjure_cp_core::ast::{
4    DeclarationPtr, Domain, DomainPtr, IntVal, Name, Range, RecordEntry, Reference, SetAttr,
5    SymbolTablePtr,
6};
7use core::panic;
8use std::str::FromStr;
9use tree_sitter::Node;
10
11/// Parse an Essence variable domain into its Conjure AST representation.
12pub fn parse_domain(
13    domain: Node,
14    source_code: &str,
15    symbols: Option<SymbolTablePtr>,
16    errors: &mut Vec<RecoverableParseError>,
17) -> Result<DomainPtr, FatalParseError> {
18    match domain.kind() {
19        "domain" => parse_domain(
20            domain.child(0).expect("No domain found"),
21            source_code,
22            symbols,
23            errors,
24        ),
25        "bool_domain" => Ok(Domain::bool()),
26        "int_domain" => Ok(parse_int_domain(domain, source_code, &symbols, errors)),
27        "identifier" => {
28            let decl = get_declaration_ptr_from_identifier(domain, source_code, &symbols, errors)?;
29            let dom = Domain::reference(decl).ok_or(FatalParseError::syntax_error(
30                format!(
31                    "'{}' is not a valid domain declaration",
32                    &source_code[domain.start_byte()..domain.end_byte()]
33                ),
34                Some(domain.range()),
35            ))?;
36            Ok(dom)
37        }
38        "tuple_domain" => parse_tuple_domain(domain, source_code, symbols, errors),
39        "matrix_domain" => parse_matrix_domain(domain, source_code, symbols, errors),
40        "record_domain" => parse_record_domain(domain, source_code, symbols, errors),
41        "set_domain" => parse_set_domain(domain, source_code, symbols, errors),
42        _ => panic!("{} is not a supported domain type", domain.kind()),
43    }
44}
45
46fn get_declaration_ptr_from_identifier(
47    identifier: Node,
48    source_code: &str,
49    symbols_ptr: &Option<SymbolTablePtr>,
50    _errors: &mut Vec<RecoverableParseError>,
51) -> Result<DeclarationPtr, FatalParseError> {
52    let name = Name::user(&source_code[identifier.start_byte()..identifier.end_byte()]);
53    let decl = symbols_ptr
54        .as_ref()
55        .ok_or(FatalParseError::syntax_error(
56            "context needed to resolve identifier".to_string(),
57            Some(identifier.range()),
58        ))?
59        .read()
60        .lookup(&name)
61        .ok_or(FatalParseError::syntax_error(
62            format!("'{name}' is not defined"),
63            Some(identifier.range()),
64        ))?;
65    Ok(decl)
66}
67
68/// Parse an integer domain. Can be a single integer or a range.
69fn parse_int_domain(
70    int_domain: Node,
71    source_code: &str,
72    symbols_ptr: &Option<SymbolTablePtr>,
73    errors: &mut Vec<RecoverableParseError>,
74) -> DomainPtr {
75    if int_domain.child_count() == 1 {
76        return Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)]);
77    }
78    let mut ranges: Vec<Range<i32>> = Vec::new();
79    let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
80    let range_list = int_domain
81        .child_by_field_name("ranges")
82        .expect("No range list found for int domain");
83    for domain_component in named_children(&range_list) {
84        match domain_component.kind() {
85            "atom" => {
86                let text = &source_code[domain_component.start_byte()..domain_component.end_byte()];
87                // Try parsing as a literal integer first
88                if let Ok(integer) = text.parse::<i32>() {
89                    ranges.push(Range::Single(integer));
90                    continue;
91                }
92                // Otherwise, treat as a reference
93                let decl = get_declaration_ptr_from_identifier(
94                    domain_component,
95                    source_code,
96                    symbols_ptr,
97                    errors,
98                );
99                if let Ok(decl) = decl {
100                    ranges_unresolved.push(Range::Single(IntVal::Reference(Reference::new(decl))));
101                } else {
102                    panic!("'{}' is not a valid integer", text);
103                }
104            }
105            "int_range" => {
106                let lower_bound: Option<Result<i32, DeclarationPtr>> =
107                    match domain_component.child_by_field_name("lower") {
108                        Some(lower_node) => {
109                            // Try parsing as a literal integer first
110                            let text = &source_code[lower_node.start_byte()..lower_node.end_byte()];
111                            if let Ok(integer) = text.parse::<i32>() {
112                                Some(Ok(integer))
113                            } else {
114                                let decl = get_declaration_ptr_from_identifier(
115                                    lower_node,
116                                    source_code,
117                                    symbols_ptr,
118                                    errors,
119                                );
120                                if let Ok(decl) = decl {
121                                    Some(Err(decl))
122                                } else {
123                                    panic!("'{}' is not a valid integer", text);
124                                }
125                            }
126                        }
127                        None => None,
128                    };
129                let upper_bound: Option<Result<i32, DeclarationPtr>> =
130                    match domain_component.child_by_field_name("upper") {
131                        Some(upper_node) => {
132                            // Try parsing as a literal integer first
133                            let text = &source_code[upper_node.start_byte()..upper_node.end_byte()];
134                            if let Ok(integer) = text.parse::<i32>() {
135                                Some(Ok(integer))
136                            } else {
137                                let decl = get_declaration_ptr_from_identifier(
138                                    upper_node,
139                                    source_code,
140                                    symbols_ptr,
141                                    errors,
142                                );
143                                if let Ok(decl) = decl {
144                                    Some(Err(decl))
145                                } else {
146                                    panic!("'{}' is not a valid integer", text);
147                                }
148                            }
149                        }
150                        None => None,
151                    };
152
153                match (lower_bound, upper_bound) {
154                    (Some(Ok(lower)), Some(Ok(upper))) => ranges.push(Range::Bounded(lower, upper)),
155                    (Some(Ok(lower)), Some(Err(decl))) => {
156                        ranges_unresolved.push(Range::Bounded(
157                            IntVal::Const(lower),
158                            IntVal::Reference(Reference::new(decl)),
159                        ));
160                    }
161                    (Some(Err(decl)), Some(Ok(upper))) => {
162                        ranges_unresolved.push(Range::Bounded(
163                            IntVal::Reference(Reference::new(decl)),
164                            IntVal::Const(upper),
165                        ));
166                    }
167                    (Some(Err(decl_lower)), Some(Err(decl_upper))) => {
168                        ranges_unresolved.push(Range::Bounded(
169                            IntVal::Reference(Reference::new(decl_lower)),
170                            IntVal::Reference(Reference::new(decl_upper)),
171                        ));
172                    }
173                    (Some(Ok(lower)), None) => {
174                        ranges.push(Range::UnboundedR(lower));
175                    }
176                    (Some(Err(decl)), None) => {
177                        ranges_unresolved
178                            .push(Range::UnboundedR(IntVal::Reference(Reference::new(decl))));
179                    }
180                    (None, Some(Ok(upper))) => {
181                        ranges.push(Range::UnboundedL(upper));
182                    }
183                    (None, Some(Err(decl))) => {
184                        ranges_unresolved
185                            .push(Range::UnboundedL(IntVal::Reference(Reference::new(decl))));
186                    }
187                    (None, None) => {
188                        ranges.push(Range::Unbounded);
189                    }
190                }
191            }
192            _ => panic!("unsupported int range type"),
193        }
194    }
195
196    if !ranges_unresolved.is_empty() {
197        for range in ranges {
198            match range {
199                Range::Single(i) => ranges_unresolved.push(Range::Single(IntVal::Const(i))),
200                Range::Bounded(l, u) => {
201                    ranges_unresolved.push(Range::Bounded(IntVal::Const(l), IntVal::Const(u)))
202                }
203                Range::UnboundedL(l) => ranges_unresolved.push(Range::UnboundedL(IntVal::Const(l))),
204                Range::UnboundedR(u) => ranges_unresolved.push(Range::UnboundedR(IntVal::Const(u))),
205                Range::Unbounded => ranges_unresolved.push(Range::Unbounded),
206            }
207        }
208        return Domain::int(ranges_unresolved);
209    }
210
211    Domain::int(ranges)
212}
213
214fn parse_tuple_domain(
215    tuple_domain: Node,
216    source_code: &str,
217    symbols: Option<SymbolTablePtr>,
218    errors: &mut Vec<RecoverableParseError>,
219) -> Result<DomainPtr, FatalParseError> {
220    let mut domains: Vec<DomainPtr> = Vec::new();
221    for domain in named_children(&tuple_domain) {
222        domains.push(parse_domain(domain, source_code, symbols.clone(), errors)?);
223    }
224    Ok(Domain::tuple(domains))
225}
226
227fn parse_matrix_domain(
228    matrix_domain: Node,
229    source_code: &str,
230    symbols: Option<SymbolTablePtr>,
231    errors: &mut Vec<RecoverableParseError>,
232) -> Result<DomainPtr, FatalParseError> {
233    let mut domains: Vec<DomainPtr> = Vec::new();
234    let index_domain_list = matrix_domain
235        .child_by_field_name("index_domain_list")
236        .expect("No index domains found for matrix domain");
237    for domain in named_children(&index_domain_list) {
238        domains.push(parse_domain(domain, source_code, symbols.clone(), errors)?);
239    }
240    let value_domain = parse_domain(
241        matrix_domain
242            .child_by_field_name("value_domain")
243            .ok_or(FatalParseError::syntax_error(
244                "Expected a value domain".to_string(),
245                Some(matrix_domain.range()),
246            ))?,
247        source_code,
248        symbols,
249        errors,
250    )?;
251    Ok(Domain::matrix(value_domain, domains))
252}
253
254fn parse_record_domain(
255    record_domain: Node,
256    source_code: &str,
257    symbols: Option<SymbolTablePtr>,
258    errors: &mut Vec<RecoverableParseError>,
259) -> Result<DomainPtr, FatalParseError> {
260    let mut record_entries: Vec<RecordEntry> = Vec::new();
261    for record_entry in named_children(&record_domain) {
262        let name_node = record_entry
263            .child_by_field_name("name")
264            .expect("No name found for record entry");
265        let name = Name::user(&source_code[name_node.start_byte()..name_node.end_byte()]);
266        let domain_node = record_entry
267            .child_by_field_name("domain")
268            .expect("No domain found for record entry");
269        let domain = parse_domain(domain_node, source_code, symbols.clone(), errors)?;
270        record_entries.push(RecordEntry { name, domain });
271    }
272    Ok(Domain::record(record_entries))
273}
274
275pub fn parse_set_domain(
276    set_domain: Node,
277    source_code: &str,
278    symbols: Option<SymbolTablePtr>,
279    errors: &mut Vec<RecoverableParseError>,
280) -> Result<DomainPtr, FatalParseError> {
281    let mut set_attribute: Option<SetAttr> = None;
282    let mut value_domain: Option<DomainPtr> = None;
283
284    for child in named_children(&set_domain) {
285        match child.kind() {
286            "set_attributes" => {
287                // Check if we have both minSize and maxSize (minMax case)
288                let min_value_node = child.child_by_field_name("min_value");
289                let max_value_node = child.child_by_field_name("max_value");
290                let size_value_node = child.child_by_field_name("size_value");
291
292                if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
293                    // MinMax case
294                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
295                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
296
297                    let min_val = i32::from_str(min_str).map_err(|_| {
298                        FatalParseError::syntax_error(
299                            format!("Invalid integer value for minSize: {}", min_str),
300                            Some(min_node.range()),
301                        )
302                    })?;
303
304                    let max_val = i32::from_str(max_str).map_err(|_| {
305                        FatalParseError::syntax_error(
306                            format!("Invalid integer value for maxSize: {}", max_str),
307                            Some(max_node.range()),
308                        )
309                    })?;
310
311                    set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
312                } else if let Some(size_node) = size_value_node {
313                    // Size case
314                    let size_str = &source_code[size_node.start_byte()..size_node.end_byte()];
315                    let size_val = i32::from_str(size_str).map_err(|_| {
316                        FatalParseError::syntax_error(
317                            format!("Invalid integer value for size: {}", size_str),
318                            Some(size_node.range()),
319                        )
320                    })?;
321                    set_attribute = Some(SetAttr::new_size(size_val));
322                } else if let Some(min_node) = min_value_node {
323                    // MinSize only case
324                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
325                    let min_val = i32::from_str(min_str).map_err(|_| {
326                        FatalParseError::syntax_error(
327                            format!("Invalid integer value for minSize: {}", min_str),
328                            Some(min_node.range()),
329                        )
330                    })?;
331                    set_attribute = Some(SetAttr::new_min_size(min_val));
332                } else if let Some(max_node) = max_value_node {
333                    // MaxSize only case
334                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
335                    let max_val = i32::from_str(max_str).map_err(|_| {
336                        FatalParseError::syntax_error(
337                            format!("Invalid integer value for maxSize: {}", max_str),
338                            Some(max_node.range()),
339                        )
340                    })?;
341                    set_attribute = Some(SetAttr::new_max_size(max_val));
342                }
343            }
344            "domain" => {
345                value_domain = Some(parse_domain(child, source_code, symbols.clone(), errors)?);
346            }
347            _ => {
348                return Err(FatalParseError::syntax_error(
349                    format!("Unrecognized set domain child kind: {}", child.kind()),
350                    Some(child.range()),
351                ));
352            }
353        }
354    }
355
356    if let Some(domain) = value_domain {
357        Ok(Domain::set(set_attribute.unwrap_or_default(), domain))
358    } else {
359        Err(FatalParseError::syntax_error(
360            "Set domain must have a value domain".to_string(),
361            Some(set_domain.range()),
362        ))
363    }
364}