conjure_cp_essence_parser/parser/
domain.rs

1use super::util::named_children;
2use crate::EssenceParseError;
3use conjure_cp_core::ast::{
4    DeclarationPtr, Domain, DomainPtr, IntVal, Name, Range, RecordEntry, Reference, SetAttr,
5    SymbolTable,
6};
7use core::panic;
8use std::cell::RefCell;
9use std::rc::Rc;
10use std::str::FromStr;
11use tree_sitter::Node;
12
13/// Parse an Essence variable domain into its Conjure AST representation.
14pub fn parse_domain(
15    domain: Node,
16    source_code: &str,
17    symbols: Option<Rc<RefCell<SymbolTable>>>,
18) -> Result<DomainPtr, EssenceParseError> {
19    match domain.kind() {
20        "domain" => parse_domain(
21            domain.child(0).expect("No domain found"),
22            source_code,
23            symbols,
24        ),
25        "bool_domain" => Ok(Domain::bool()),
26        "int_domain" => Ok(parse_int_domain(domain, source_code, &symbols)),
27        "identifier" => {
28            let decl = get_declaration_ptr_from_identifier(domain, source_code, &symbols)?;
29            let dom = Domain::reference(decl).ok_or(EssenceParseError::syntax_error(
30                format!(
31                    "'{}' is not a valid domain declaration",
32                    &source_code[domain.start_byte()..domain.end_byte()]
33                ),
34                Some(domain.range()),
35            ))?;
36            Ok(dom)
37        }
38        "tuple_domain" => parse_tuple_domain(domain, source_code, symbols),
39        "matrix_domain" => parse_matrix_domain(domain, source_code, symbols),
40        "record_domain" => parse_record_domain(domain, source_code, symbols),
41        "set_domain" => parse_set_domain(domain, source_code, symbols),
42        _ => panic!("{} is not a supported domain type", domain.kind()),
43    }
44}
45
46fn get_declaration_ptr_from_identifier(
47    identifier: Node,
48    source_code: &str,
49    symbols_ptr: &Option<Rc<RefCell<SymbolTable>>>,
50) -> Result<DeclarationPtr, EssenceParseError> {
51    let name = Name::user(&source_code[identifier.start_byte()..identifier.end_byte()]);
52    let decl = symbols_ptr
53        .as_ref()
54        .ok_or(EssenceParseError::syntax_error(
55            "context needed to resolve identifier".to_string(),
56            Some(identifier.range()),
57        ))?
58        .borrow()
59        .lookup(&name)
60        .ok_or(EssenceParseError::syntax_error(
61            format!("'{name}' is not defined"),
62            Some(identifier.range()),
63        ))?;
64    Ok(decl)
65}
66
67/// Parse an integer domain. Can be a single integer or a range.
68fn parse_int_domain(
69    int_domain: Node,
70    source_code: &str,
71    symbols_ptr: &Option<Rc<RefCell<SymbolTable>>>,
72) -> DomainPtr {
73    if int_domain.child_count() == 1 {
74        return Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)]);
75    }
76    let mut ranges: Vec<Range<i32>> = Vec::new();
77    let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
78    let range_list = int_domain
79        .child_by_field_name("ranges")
80        .expect("No range list found for int domain");
81    for domain_component in named_children(&range_list) {
82        match domain_component.kind() {
83            "arithmetic_expr" => {
84                let text = &source_code[domain_component.start_byte()..domain_component.end_byte()];
85                // Try parsing as a literal integer first
86                if let Ok(integer) = text.parse::<i32>() {
87                    ranges.push(Range::Single(integer));
88                    continue;
89                }
90                // Otherwise, treat as a reference
91                let decl =
92                    get_declaration_ptr_from_identifier(domain_component, source_code, symbols_ptr);
93                if let Ok(decl) = decl {
94                    ranges_unresolved.push(Range::Single(IntVal::Reference(Reference::new(decl))));
95                } else {
96                    panic!("'{}' is not a valid integer", text);
97                }
98            }
99            "int_range" => {
100                let lower_bound: Option<Result<i32, DeclarationPtr>> =
101                    match domain_component.child_by_field_name("lower") {
102                        Some(lower_node) => {
103                            // Try parsing as a literal integer first
104                            let text = &source_code[lower_node.start_byte()..lower_node.end_byte()];
105                            if let Ok(integer) = text.parse::<i32>() {
106                                Some(Ok(integer))
107                            } else {
108                                let decl = get_declaration_ptr_from_identifier(
109                                    lower_node,
110                                    source_code,
111                                    symbols_ptr,
112                                );
113                                if let Ok(decl) = decl {
114                                    Some(Err(decl))
115                                } else {
116                                    panic!("'{}' is not a valid integer", text);
117                                }
118                            }
119                        }
120                        None => None,
121                    };
122                let upper_bound: Option<Result<i32, DeclarationPtr>> =
123                    match domain_component.child_by_field_name("upper") {
124                        Some(upper_node) => {
125                            // Try parsing as a literal integer first
126                            let text = &source_code[upper_node.start_byte()..upper_node.end_byte()];
127                            if let Ok(integer) = text.parse::<i32>() {
128                                Some(Ok(integer))
129                            } else {
130                                let decl = get_declaration_ptr_from_identifier(
131                                    upper_node,
132                                    source_code,
133                                    symbols_ptr,
134                                );
135                                if let Ok(decl) = decl {
136                                    Some(Err(decl))
137                                } else {
138                                    panic!("'{}' is not a valid integer", text);
139                                }
140                            }
141                        }
142                        None => None,
143                    };
144
145                match (lower_bound, upper_bound) {
146                    (Some(Ok(lower)), Some(Ok(upper))) => ranges.push(Range::Bounded(lower, upper)),
147                    (Some(Ok(lower)), Some(Err(decl))) => {
148                        ranges_unresolved.push(Range::Bounded(
149                            IntVal::Const(lower),
150                            IntVal::Reference(Reference::new(decl)),
151                        ));
152                    }
153                    (Some(Err(decl)), Some(Ok(upper))) => {
154                        ranges_unresolved.push(Range::Bounded(
155                            IntVal::Reference(Reference::new(decl)),
156                            IntVal::Const(upper),
157                        ));
158                    }
159                    (Some(Err(decl_lower)), Some(Err(decl_upper))) => {
160                        ranges_unresolved.push(Range::Bounded(
161                            IntVal::Reference(Reference::new(decl_lower)),
162                            IntVal::Reference(Reference::new(decl_upper)),
163                        ));
164                    }
165                    (Some(Ok(lower)), None) => {
166                        ranges.push(Range::UnboundedR(lower));
167                    }
168                    (Some(Err(decl)), None) => {
169                        ranges_unresolved
170                            .push(Range::UnboundedR(IntVal::Reference(Reference::new(decl))));
171                    }
172                    (None, Some(Ok(upper))) => {
173                        ranges.push(Range::UnboundedL(upper));
174                    }
175                    (None, Some(Err(decl))) => {
176                        ranges_unresolved
177                            .push(Range::UnboundedL(IntVal::Reference(Reference::new(decl))));
178                    }
179                    (None, None) => {
180                        ranges.push(Range::Unbounded);
181                    }
182                }
183            }
184            _ => panic!("unsupported int range type"),
185        }
186    }
187
188    if !ranges_unresolved.is_empty() {
189        for range in ranges {
190            match range {
191                Range::Single(i) => ranges_unresolved.push(Range::Single(IntVal::Const(i))),
192                Range::Bounded(l, u) => {
193                    ranges_unresolved.push(Range::Bounded(IntVal::Const(l), IntVal::Const(u)))
194                }
195                Range::UnboundedL(l) => ranges_unresolved.push(Range::UnboundedL(IntVal::Const(l))),
196                Range::UnboundedR(u) => ranges_unresolved.push(Range::UnboundedR(IntVal::Const(u))),
197                Range::Unbounded => ranges_unresolved.push(Range::Unbounded),
198            }
199        }
200        return Domain::int(ranges_unresolved);
201    }
202
203    Domain::int(ranges)
204}
205
206fn parse_tuple_domain(
207    tuple_domain: Node,
208    source_code: &str,
209    symbols: Option<Rc<RefCell<SymbolTable>>>,
210) -> Result<DomainPtr, EssenceParseError> {
211    let mut domains: Vec<DomainPtr> = Vec::new();
212    for domain in named_children(&tuple_domain) {
213        domains.push(parse_domain(domain, source_code, symbols.clone())?);
214    }
215    Ok(Domain::tuple(domains))
216}
217
218fn parse_matrix_domain(
219    matrix_domain: Node,
220    source_code: &str,
221    symbols: Option<Rc<RefCell<SymbolTable>>>,
222) -> Result<DomainPtr, EssenceParseError> {
223    let mut domains: Vec<DomainPtr> = Vec::new();
224    let index_domain_list = matrix_domain
225        .child_by_field_name("index_domain_list")
226        .expect("No index domains found for matrix domain");
227    for domain in named_children(&index_domain_list) {
228        domains.push(parse_domain(domain, source_code, symbols.clone())?);
229    }
230    let value_domain = parse_domain(
231        matrix_domain.child_by_field_name("value_domain").ok_or(
232            EssenceParseError::syntax_error(
233                "Expected a value domain".to_string(),
234                Some(matrix_domain.range()),
235            ),
236        )?,
237        source_code,
238        symbols,
239    )?;
240    Ok(Domain::matrix(value_domain, domains))
241}
242
243fn parse_record_domain(
244    record_domain: Node,
245    source_code: &str,
246    symbols: Option<Rc<RefCell<SymbolTable>>>,
247) -> Result<DomainPtr, EssenceParseError> {
248    let mut record_entries: Vec<RecordEntry> = Vec::new();
249    for record_entry in named_children(&record_domain) {
250        let name_node = record_entry
251            .child_by_field_name("name")
252            .expect("No name found for record entry");
253        let name = Name::user(&source_code[name_node.start_byte()..name_node.end_byte()]);
254        let domain_node = record_entry
255            .child_by_field_name("domain")
256            .expect("No domain found for record entry");
257        let domain = parse_domain(domain_node, source_code, symbols.clone())?;
258        record_entries.push(RecordEntry { name, domain });
259    }
260    Ok(Domain::record(record_entries))
261}
262
263pub fn parse_set_domain(
264    set_domain: Node,
265    source_code: &str,
266    symbols: Option<Rc<RefCell<SymbolTable>>>,
267) -> Result<DomainPtr, EssenceParseError> {
268    let mut set_attribute: Option<SetAttr> = None;
269    let mut value_domain: Option<DomainPtr> = None;
270
271    for child in named_children(&set_domain) {
272        match child.kind() {
273            "set_attributes" => {
274                // Check if we have both minSize and maxSize (minMax case)
275                let min_value_node = child.child_by_field_name("min_value");
276                let max_value_node = child.child_by_field_name("max_value");
277                let size_value_node = child.child_by_field_name("size_value");
278
279                if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
280                    // MinMax case
281                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
282                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
283
284                    let min_val = i32::from_str(min_str).map_err(|_| {
285                        EssenceParseError::syntax_error(
286                            format!("Invalid integer value for minSize: {}", min_str),
287                            Some(min_node.range()),
288                        )
289                    })?;
290
291                    let max_val = i32::from_str(max_str).map_err(|_| {
292                        EssenceParseError::syntax_error(
293                            format!("Invalid integer value for maxSize: {}", max_str),
294                            Some(max_node.range()),
295                        )
296                    })?;
297
298                    set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
299                } else if let Some(size_node) = size_value_node {
300                    // Size case
301                    let size_str = &source_code[size_node.start_byte()..size_node.end_byte()];
302                    let size_val = i32::from_str(size_str).map_err(|_| {
303                        EssenceParseError::syntax_error(
304                            format!("Invalid integer value for size: {}", size_str),
305                            Some(size_node.range()),
306                        )
307                    })?;
308                    set_attribute = Some(SetAttr::new_size(size_val));
309                } else if let Some(min_node) = min_value_node {
310                    // MinSize only case
311                    let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
312                    let min_val = i32::from_str(min_str).map_err(|_| {
313                        EssenceParseError::syntax_error(
314                            format!("Invalid integer value for minSize: {}", min_str),
315                            Some(min_node.range()),
316                        )
317                    })?;
318                    set_attribute = Some(SetAttr::new_min_size(min_val));
319                } else if let Some(max_node) = max_value_node {
320                    // MaxSize only case
321                    let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
322                    let max_val = i32::from_str(max_str).map_err(|_| {
323                        EssenceParseError::syntax_error(
324                            format!("Invalid integer value for maxSize: {}", max_str),
325                            Some(max_node.range()),
326                        )
327                    })?;
328                    set_attribute = Some(SetAttr::new_max_size(max_val));
329                }
330            }
331            "domain" => {
332                value_domain = Some(parse_domain(child, source_code, symbols.clone())?);
333            }
334            _ => {
335                return Err(EssenceParseError::syntax_error(
336                    format!("Unrecognized set domain child kind: {}", child.kind()),
337                    Some(child.range()),
338                ));
339            }
340        }
341    }
342
343    if let Some(domain) = value_domain {
344        Ok(Domain::set(set_attribute.unwrap_or_default(), domain))
345    } else {
346        Err(EssenceParseError::syntax_error(
347            "Set domain must have a value domain".to_string(),
348            Some(set_domain.range()),
349        ))
350    }
351}