Skip to main content

conjure_cp_essence_parser/parser/
domain.rs

1use super::atom::parse_int;
2use super::util::named_children;
3use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
4use crate::errors::FatalParseError;
5use crate::expression::parse_expression;
6use crate::parser::ParseContext;
7use crate::{child, field};
8use conjure_cp_core::ast::{
9    DeclarationPtr, Domain, DomainPtr, IntVal, Moo, Name, Range, RecordEntry, Reference, SetAttr,
10};
11use tree_sitter::Node;
12
13/// Parse an Essence variable domain into its Conjure AST representation.
14pub fn parse_domain(
15    ctx: &mut ParseContext,
16    domain: Node,
17) -> Result<Option<DomainPtr>, FatalParseError> {
18    match domain.kind() {
19        "domain" => parse_domain(ctx, child!(domain, 0, "domain")),
20        "bool_domain" => {
21            let hover = HoverInfo {
22                description: "Boolean domain".to_string(),
23                kind: Some(crate::diagnostics::diagnostics_api::SymbolKind::Domain),
24                ty: None,
25                decl_span: None,
26            };
27            span_with_hover(&domain, ctx.source_code, ctx.source_map, hover);
28            Ok(Some(Domain::bool()))
29        }
30        "int_domain" => {
31            let hover = HoverInfo {
32                description: "Integer domain".to_string(),
33                kind: Some(crate::diagnostics::diagnostics_api::SymbolKind::Domain),
34                ty: None,
35                decl_span: None,
36            };
37            span_with_hover(&domain, ctx.source_code, ctx.source_map, hover);
38            parse_int_domain(ctx, domain)
39        }
40        "identifier" => {
41            let Some(decl) = get_declaration_ptr_from_identifier(ctx, domain)? else {
42                return Ok(None);
43            };
44            let Some(dom) = Domain::reference(decl) else {
45                ctx.record_error(crate::errors::RecoverableParseError::new(
46                    format!(
47                        "The identifier '{}' is not a valid domain",
48                        &ctx.source_code[domain.start_byte()..domain.end_byte()]
49                    ),
50                    Some(domain.range()),
51                ));
52                return Ok(None);
53            };
54            let name = &ctx.source_code[domain.start_byte()..domain.end_byte()];
55            let hover = HoverInfo {
56                description: format!("Domain reference: {name}"),
57                kind: None,
58                ty: None,
59                decl_span: None,
60            };
61            span_with_hover(&domain, ctx.source_code, ctx.source_map, hover);
62            Ok(Some(dom))
63        }
64        "tuple_domain" => parse_tuple_domain(ctx, domain),
65        "matrix_domain" => parse_matrix_domain(ctx, domain),
66        "record_domain" => parse_record_domain(ctx, domain),
67        "set_domain" => parse_set_domain(ctx, domain),
68        _ => Err(FatalParseError::internal_error(
69            format!("{} is not a supported domain type", domain.kind()),
70            Some(domain.range()),
71        )),
72    }
73}
74
75fn get_declaration_ptr_from_identifier(
76    ctx: &mut ParseContext,
77    identifier: Node,
78) -> Result<Option<DeclarationPtr>, FatalParseError> {
79    let name = Name::user(&ctx.source_code[identifier.start_byte()..identifier.end_byte()]);
80    let decl = ctx
81        .symbols
82        .as_ref()
83        .ok_or(FatalParseError::internal_error(
84            "context needed to resolve identifier".to_string(),
85            Some(identifier.range()),
86        ))?
87        .read()
88        .lookup(&name);
89    match decl {
90        Some(decl) => Ok(Some(decl)),
91        None => {
92            ctx.record_error(crate::errors::RecoverableParseError::new(
93                format!("The identifier '{}' is not defined", name),
94                Some(identifier.range()),
95            ));
96            Ok(None)
97        }
98    }
99}
100
101/// Parse an integer domain. Can be a single integer or a range.
102fn parse_int_domain(
103    ctx: &mut ParseContext,
104    int_domain: Node,
105) -> Result<Option<DomainPtr>, FatalParseError> {
106    if int_domain.child_count() == 1 {
107        // for domains of just 'int' with no range
108        return Ok(Some(Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)])));
109    }
110
111    let range_list = field!(int_domain, "ranges");
112    let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
113    let mut all_resolved = true;
114
115    for domain_component in named_children(&range_list) {
116        match domain_component.kind() {
117            "atom" | "arithmetic_expr" => {
118                let Some(int_val) = parse_int_val(ctx, domain_component)? else {
119                    return Ok(None);
120                };
121
122                if !matches!(int_val, IntVal::Const(_)) {
123                    all_resolved = false;
124                }
125                ranges_unresolved.push(Range::Single(int_val));
126            }
127            "int_range" => {
128                let lower_bound = match domain_component.child_by_field_name("lower") {
129                    Some(node) => {
130                        match parse_int_val(ctx, node)? {
131                            Some(val) => Some(val),
132                            None => return Ok(None), // semantic error occurred
133                        }
134                    }
135                    None => None,
136                };
137                let upper_bound = match domain_component.child_by_field_name("upper") {
138                    Some(node) => {
139                        match parse_int_val(ctx, node)? {
140                            Some(val) => Some(val),
141                            None => return Ok(None), // semantic error occurred
142                        }
143                    }
144                    None => None,
145                };
146
147                match (lower_bound, upper_bound) {
148                    (Some(lower), Some(upper)) => {
149                        // Check if both bounds are constants and validate lower <= upper
150                        if let (IntVal::Const(l), IntVal::Const(u)) = (&lower, &upper) {
151                            if l > u {
152                                ctx.record_error(crate::errors::RecoverableParseError::new(
153                                    format!(
154                                        "Invalid integer range: lower bound {} is greater than upper bound {}",
155                                        l, u
156                                    ),
157                                    Some(domain_component.range()),
158                                ));
159                            }
160                        } else {
161                            all_resolved = false;
162                        }
163                        ranges_unresolved.push(Range::Bounded(lower, upper));
164                    }
165                    (Some(lower), None) => {
166                        if !matches!(lower, IntVal::Const(_)) {
167                            all_resolved = false;
168                        }
169                        ranges_unresolved.push(Range::UnboundedR(lower));
170                    }
171                    (None, Some(upper)) => {
172                        if !matches!(upper, IntVal::Const(_)) {
173                            all_resolved = false;
174                        }
175                        ranges_unresolved.push(Range::UnboundedL(upper));
176                    }
177                    _ => {
178                        return Err(FatalParseError::internal_error(
179                            "Invalid int range: must have at least a lower or upper bound"
180                                .to_string(),
181                            Some(domain_component.range()),
182                        ));
183                    }
184                }
185            }
186            _ => {
187                return Err(FatalParseError::internal_error(
188                    format!(
189                        "Unexpected int domain component: {}",
190                        domain_component.kind()
191                    ),
192                    Some(domain_component.range()),
193                ));
194            }
195        }
196    }
197
198    // If all values are resolved constants, convert IntVals to raw integers
199    if all_resolved {
200        let ranges: Vec<Range<i32>> = ranges_unresolved
201            .into_iter()
202            .map(|r| match r {
203                Range::Single(IntVal::Const(v)) => Range::Single(v),
204                Range::Bounded(IntVal::Const(l), IntVal::Const(u)) => Range::Bounded(l, u),
205                Range::UnboundedR(IntVal::Const(l)) => Range::UnboundedR(l),
206                Range::UnboundedL(IntVal::Const(u)) => Range::UnboundedL(u),
207                Range::Unbounded => Range::Unbounded,
208                _ => unreachable!("all_resolved should be true only if all are Const"),
209            })
210            .collect();
211        Ok(Some(Domain::int(ranges)))
212    } else {
213        // Otherwise, keep as an expression-based domain
214        Ok(Some(Domain::int(ranges_unresolved)))
215    }
216}
217
218// Helper function to parse a node into an IntVal
219// Handles constants, references, and arbitrary expressions
220fn parse_int_val(ctx: &mut ParseContext, node: Node) -> Result<Option<IntVal>, FatalParseError> {
221    // For atoms, try to parse as a constant integer first
222    if node.kind() == "atom" {
223        let text = &ctx.source_code[node.start_byte()..node.end_byte()];
224        if let Ok(integer) = text.parse::<i32>() {
225            return Ok(Some(IntVal::Const(integer)));
226        }
227        // Otherwise, check if it's an identifier reference
228        let Some(decl) = get_declaration_ptr_from_identifier(ctx, node)? else {
229            // If identifier isn't defined, its a semantic error
230            return Ok(None);
231        };
232        return Ok(Some(IntVal::Reference(Reference::new(decl))));
233    }
234
235    // For anything else, parse as an expression
236    let Some(expr) = parse_expression(ctx, node)? else {
237        return Ok(None);
238    };
239    Ok(Some(IntVal::Expr(Moo::new(expr))))
240}
241
242fn parse_tuple_domain(
243    ctx: &mut ParseContext,
244    tuple_domain: Node,
245) -> Result<Option<DomainPtr>, FatalParseError> {
246    let mut domains: Vec<DomainPtr> = Vec::new();
247    for domain in named_children(&tuple_domain) {
248        let Some(parsed_domain) = parse_domain(ctx, domain)? else {
249            return Ok(None);
250        };
251        domains.push(parsed_domain);
252    }
253    Ok(Some(Domain::tuple(domains)))
254}
255
256fn parse_matrix_domain(
257    ctx: &mut ParseContext,
258    matrix_domain: Node,
259) -> Result<Option<DomainPtr>, FatalParseError> {
260    let mut domains: Vec<DomainPtr> = Vec::new();
261    let index_domain_list = field!(matrix_domain, "index_domain_list");
262    for domain in named_children(&index_domain_list) {
263        let Some(parsed_domain) = parse_domain(ctx, domain)? else {
264            return Ok(None);
265        };
266        domains.push(parsed_domain);
267    }
268    let Some(value_domain) = parse_domain(ctx, field!(matrix_domain, "value_domain"))? else {
269        return Ok(None);
270    };
271    Ok(Some(Domain::matrix(value_domain, domains)))
272}
273
274fn parse_record_domain(
275    ctx: &mut ParseContext,
276    record_domain: Node,
277) -> Result<Option<DomainPtr>, FatalParseError> {
278    let mut record_entries: Vec<RecordEntry> = Vec::new();
279    for record_entry in named_children(&record_domain) {
280        let name_node = field!(record_entry, "name");
281        let name = Name::user(&ctx.source_code[name_node.start_byte()..name_node.end_byte()]);
282        let domain_node = field!(record_entry, "domain");
283        let Some(domain) = parse_domain(ctx, domain_node)? else {
284            return Ok(None);
285        };
286        record_entries.push(RecordEntry { name, domain });
287    }
288    Ok(Some(Domain::record(record_entries)))
289}
290
291pub fn parse_set_domain(
292    ctx: &mut ParseContext,
293    set_domain: Node,
294) -> Result<Option<DomainPtr>, FatalParseError> {
295    let mut set_attribute: Option<SetAttr> = None;
296    let mut value_domain: Option<DomainPtr> = None;
297
298    for child in named_children(&set_domain) {
299        match child.kind() {
300            "set_attributes" => {
301                // Check if we have both minSize and maxSize (minMax case)
302                let min_value_node = child.child_by_field_name("min_value");
303                let max_value_node = child.child_by_field_name("max_value");
304                let size_value_node = child.child_by_field_name("size_value");
305
306                if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
307                    // MinMax case
308                    let min_val = parse_int(ctx, &min_node)?;
309                    let max_val = parse_int(ctx, &max_node)?;
310
311                    set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
312                } else if let Some(size_node) = size_value_node {
313                    // Size case
314                    let size_val = parse_int(ctx, &size_node)?;
315                    set_attribute = Some(SetAttr::new_size(size_val));
316                } else if let Some(min_node) = min_value_node {
317                    // MinSize only case
318                    let min_val = parse_int(ctx, &min_node)?;
319                    set_attribute = Some(SetAttr::new_min_size(min_val));
320                } else if let Some(max_node) = max_value_node {
321                    // MaxSize only case
322                    let max_val = parse_int(ctx, &max_node)?;
323                    set_attribute = Some(SetAttr::new_max_size(max_val));
324                }
325            }
326            "domain" => {
327                let Some(parsed_domain) = parse_domain(ctx, child)? else {
328                    return Ok(None);
329                };
330                value_domain = Some(parsed_domain);
331            }
332            _ => {
333                return Err(FatalParseError::internal_error(
334                    format!("Unrecognized set domain child kind: {}", child.kind()),
335                    Some(child.range()),
336                ));
337            }
338        }
339    }
340
341    if let Some(domain) = value_domain {
342        Ok(Some(Domain::set(set_attribute.unwrap_or_default(), domain)))
343    } else {
344        Err(FatalParseError::internal_error(
345            "Set domain must have a value domain".to_string(),
346            Some(set_domain.range()),
347        ))
348    }
349}