1use super::util::named_children;
2use crate::errors::{FatalParseError, RecoverableParseError};
3use conjure_cp_core::ast::{
4 DeclarationPtr, Domain, DomainPtr, IntVal, Name, Range, RecordEntry, Reference, SetAttr,
5 SymbolTablePtr,
6};
7use core::panic;
8use std::str::FromStr;
9use tree_sitter::Node;
10
11pub fn parse_domain(
13 domain: Node,
14 source_code: &str,
15 symbols: Option<SymbolTablePtr>,
16 errors: &mut Vec<RecoverableParseError>,
17) -> Result<DomainPtr, FatalParseError> {
18 match domain.kind() {
19 "domain" => parse_domain(
20 domain.child(0).expect("No domain found"),
21 source_code,
22 symbols,
23 errors,
24 ),
25 "bool_domain" => Ok(Domain::bool()),
26 "int_domain" => Ok(parse_int_domain(domain, source_code, &symbols, errors)),
27 "identifier" => {
28 let decl = get_declaration_ptr_from_identifier(domain, source_code, &symbols, errors)?;
29 let dom = Domain::reference(decl).ok_or(FatalParseError::syntax_error(
30 format!(
31 "'{}' is not a valid domain declaration",
32 &source_code[domain.start_byte()..domain.end_byte()]
33 ),
34 Some(domain.range()),
35 ))?;
36 Ok(dom)
37 }
38 "tuple_domain" => parse_tuple_domain(domain, source_code, symbols, errors),
39 "matrix_domain" => parse_matrix_domain(domain, source_code, symbols, errors),
40 "record_domain" => parse_record_domain(domain, source_code, symbols, errors),
41 "set_domain" => parse_set_domain(domain, source_code, symbols, errors),
42 _ => panic!("{} is not a supported domain type", domain.kind()),
43 }
44}
45
46fn get_declaration_ptr_from_identifier(
47 identifier: Node,
48 source_code: &str,
49 symbols_ptr: &Option<SymbolTablePtr>,
50 _errors: &mut Vec<RecoverableParseError>,
51) -> Result<DeclarationPtr, FatalParseError> {
52 let name = Name::user(&source_code[identifier.start_byte()..identifier.end_byte()]);
53 let decl = symbols_ptr
54 .as_ref()
55 .ok_or(FatalParseError::syntax_error(
56 "context needed to resolve identifier".to_string(),
57 Some(identifier.range()),
58 ))?
59 .read()
60 .lookup(&name)
61 .ok_or(FatalParseError::syntax_error(
62 format!("'{name}' is not defined"),
63 Some(identifier.range()),
64 ))?;
65 Ok(decl)
66}
67
68fn parse_int_domain(
70 int_domain: Node,
71 source_code: &str,
72 symbols_ptr: &Option<SymbolTablePtr>,
73 errors: &mut Vec<RecoverableParseError>,
74) -> DomainPtr {
75 if int_domain.child_count() == 1 {
76 return Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)]);
77 }
78 let mut ranges: Vec<Range<i32>> = Vec::new();
79 let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
80 let range_list = int_domain
81 .child_by_field_name("ranges")
82 .expect("No range list found for int domain");
83 for domain_component in named_children(&range_list) {
84 match domain_component.kind() {
85 "atom" => {
86 let text = &source_code[domain_component.start_byte()..domain_component.end_byte()];
87 if let Ok(integer) = text.parse::<i32>() {
89 ranges.push(Range::Single(integer));
90 continue;
91 }
92 let decl = get_declaration_ptr_from_identifier(
94 domain_component,
95 source_code,
96 symbols_ptr,
97 errors,
98 );
99 if let Ok(decl) = decl {
100 ranges_unresolved.push(Range::Single(IntVal::Reference(Reference::new(decl))));
101 } else {
102 panic!("'{}' is not a valid integer", text);
103 }
104 }
105 "int_range" => {
106 let lower_bound: Option<Result<i32, DeclarationPtr>> =
107 match domain_component.child_by_field_name("lower") {
108 Some(lower_node) => {
109 let text = &source_code[lower_node.start_byte()..lower_node.end_byte()];
111 if let Ok(integer) = text.parse::<i32>() {
112 Some(Ok(integer))
113 } else {
114 let decl = get_declaration_ptr_from_identifier(
115 lower_node,
116 source_code,
117 symbols_ptr,
118 errors,
119 );
120 if let Ok(decl) = decl {
121 Some(Err(decl))
122 } else {
123 panic!("'{}' is not a valid integer", text);
124 }
125 }
126 }
127 None => None,
128 };
129 let upper_bound: Option<Result<i32, DeclarationPtr>> =
130 match domain_component.child_by_field_name("upper") {
131 Some(upper_node) => {
132 let text = &source_code[upper_node.start_byte()..upper_node.end_byte()];
134 if let Ok(integer) = text.parse::<i32>() {
135 Some(Ok(integer))
136 } else {
137 let decl = get_declaration_ptr_from_identifier(
138 upper_node,
139 source_code,
140 symbols_ptr,
141 errors,
142 );
143 if let Ok(decl) = decl {
144 Some(Err(decl))
145 } else {
146 panic!("'{}' is not a valid integer", text);
147 }
148 }
149 }
150 None => None,
151 };
152
153 match (lower_bound, upper_bound) {
154 (Some(Ok(lower)), Some(Ok(upper))) => ranges.push(Range::Bounded(lower, upper)),
155 (Some(Ok(lower)), Some(Err(decl))) => {
156 ranges_unresolved.push(Range::Bounded(
157 IntVal::Const(lower),
158 IntVal::Reference(Reference::new(decl)),
159 ));
160 }
161 (Some(Err(decl)), Some(Ok(upper))) => {
162 ranges_unresolved.push(Range::Bounded(
163 IntVal::Reference(Reference::new(decl)),
164 IntVal::Const(upper),
165 ));
166 }
167 (Some(Err(decl_lower)), Some(Err(decl_upper))) => {
168 ranges_unresolved.push(Range::Bounded(
169 IntVal::Reference(Reference::new(decl_lower)),
170 IntVal::Reference(Reference::new(decl_upper)),
171 ));
172 }
173 (Some(Ok(lower)), None) => {
174 ranges.push(Range::UnboundedR(lower));
175 }
176 (Some(Err(decl)), None) => {
177 ranges_unresolved
178 .push(Range::UnboundedR(IntVal::Reference(Reference::new(decl))));
179 }
180 (None, Some(Ok(upper))) => {
181 ranges.push(Range::UnboundedL(upper));
182 }
183 (None, Some(Err(decl))) => {
184 ranges_unresolved
185 .push(Range::UnboundedL(IntVal::Reference(Reference::new(decl))));
186 }
187 (None, None) => {
188 ranges.push(Range::Unbounded);
189 }
190 }
191 }
192 _ => panic!("unsupported int range type"),
193 }
194 }
195
196 if !ranges_unresolved.is_empty() {
197 for range in ranges {
198 match range {
199 Range::Single(i) => ranges_unresolved.push(Range::Single(IntVal::Const(i))),
200 Range::Bounded(l, u) => {
201 ranges_unresolved.push(Range::Bounded(IntVal::Const(l), IntVal::Const(u)))
202 }
203 Range::UnboundedL(l) => ranges_unresolved.push(Range::UnboundedL(IntVal::Const(l))),
204 Range::UnboundedR(u) => ranges_unresolved.push(Range::UnboundedR(IntVal::Const(u))),
205 Range::Unbounded => ranges_unresolved.push(Range::Unbounded),
206 }
207 }
208 return Domain::int(ranges_unresolved);
209 }
210
211 Domain::int(ranges)
212}
213
214fn parse_tuple_domain(
215 tuple_domain: Node,
216 source_code: &str,
217 symbols: Option<SymbolTablePtr>,
218 errors: &mut Vec<RecoverableParseError>,
219) -> Result<DomainPtr, FatalParseError> {
220 let mut domains: Vec<DomainPtr> = Vec::new();
221 for domain in named_children(&tuple_domain) {
222 domains.push(parse_domain(domain, source_code, symbols.clone(), errors)?);
223 }
224 Ok(Domain::tuple(domains))
225}
226
227fn parse_matrix_domain(
228 matrix_domain: Node,
229 source_code: &str,
230 symbols: Option<SymbolTablePtr>,
231 errors: &mut Vec<RecoverableParseError>,
232) -> Result<DomainPtr, FatalParseError> {
233 let mut domains: Vec<DomainPtr> = Vec::new();
234 let index_domain_list = matrix_domain
235 .child_by_field_name("index_domain_list")
236 .expect("No index domains found for matrix domain");
237 for domain in named_children(&index_domain_list) {
238 domains.push(parse_domain(domain, source_code, symbols.clone(), errors)?);
239 }
240 let value_domain = parse_domain(
241 matrix_domain
242 .child_by_field_name("value_domain")
243 .ok_or(FatalParseError::syntax_error(
244 "Expected a value domain".to_string(),
245 Some(matrix_domain.range()),
246 ))?,
247 source_code,
248 symbols,
249 errors,
250 )?;
251 Ok(Domain::matrix(value_domain, domains))
252}
253
254fn parse_record_domain(
255 record_domain: Node,
256 source_code: &str,
257 symbols: Option<SymbolTablePtr>,
258 errors: &mut Vec<RecoverableParseError>,
259) -> Result<DomainPtr, FatalParseError> {
260 let mut record_entries: Vec<RecordEntry> = Vec::new();
261 for record_entry in named_children(&record_domain) {
262 let name_node = record_entry
263 .child_by_field_name("name")
264 .expect("No name found for record entry");
265 let name = Name::user(&source_code[name_node.start_byte()..name_node.end_byte()]);
266 let domain_node = record_entry
267 .child_by_field_name("domain")
268 .expect("No domain found for record entry");
269 let domain = parse_domain(domain_node, source_code, symbols.clone(), errors)?;
270 record_entries.push(RecordEntry { name, domain });
271 }
272 Ok(Domain::record(record_entries))
273}
274
275pub fn parse_set_domain(
276 set_domain: Node,
277 source_code: &str,
278 symbols: Option<SymbolTablePtr>,
279 errors: &mut Vec<RecoverableParseError>,
280) -> Result<DomainPtr, FatalParseError> {
281 let mut set_attribute: Option<SetAttr> = None;
282 let mut value_domain: Option<DomainPtr> = None;
283
284 for child in named_children(&set_domain) {
285 match child.kind() {
286 "set_attributes" => {
287 let min_value_node = child.child_by_field_name("min_value");
289 let max_value_node = child.child_by_field_name("max_value");
290 let size_value_node = child.child_by_field_name("size_value");
291
292 if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
293 let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
295 let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
296
297 let min_val = i32::from_str(min_str).map_err(|_| {
298 FatalParseError::syntax_error(
299 format!("Invalid integer value for minSize: {}", min_str),
300 Some(min_node.range()),
301 )
302 })?;
303
304 let max_val = i32::from_str(max_str).map_err(|_| {
305 FatalParseError::syntax_error(
306 format!("Invalid integer value for maxSize: {}", max_str),
307 Some(max_node.range()),
308 )
309 })?;
310
311 set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
312 } else if let Some(size_node) = size_value_node {
313 let size_str = &source_code[size_node.start_byte()..size_node.end_byte()];
315 let size_val = i32::from_str(size_str).map_err(|_| {
316 FatalParseError::syntax_error(
317 format!("Invalid integer value for size: {}", size_str),
318 Some(size_node.range()),
319 )
320 })?;
321 set_attribute = Some(SetAttr::new_size(size_val));
322 } else if let Some(min_node) = min_value_node {
323 let min_str = &source_code[min_node.start_byte()..min_node.end_byte()];
325 let min_val = i32::from_str(min_str).map_err(|_| {
326 FatalParseError::syntax_error(
327 format!("Invalid integer value for minSize: {}", min_str),
328 Some(min_node.range()),
329 )
330 })?;
331 set_attribute = Some(SetAttr::new_min_size(min_val));
332 } else if let Some(max_node) = max_value_node {
333 let max_str = &source_code[max_node.start_byte()..max_node.end_byte()];
335 let max_val = i32::from_str(max_str).map_err(|_| {
336 FatalParseError::syntax_error(
337 format!("Invalid integer value for maxSize: {}", max_str),
338 Some(max_node.range()),
339 )
340 })?;
341 set_attribute = Some(SetAttr::new_max_size(max_val));
342 }
343 }
344 "domain" => {
345 value_domain = Some(parse_domain(child, source_code, symbols.clone(), errors)?);
346 }
347 _ => {
348 return Err(FatalParseError::syntax_error(
349 format!("Unrecognized set domain child kind: {}", child.kind()),
350 Some(child.range()),
351 ));
352 }
353 }
354 }
355
356 if let Some(domain) = value_domain {
357 Ok(Domain::set(set_attribute.unwrap_or_default(), domain))
358 } else {
359 Err(FatalParseError::syntax_error(
360 "Set domain must have a value domain".to_string(),
361 Some(set_domain.range()),
362 ))
363 }
364}