1
use super::atom::parse_int;
2
use super::util::named_children;
3
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
4
use crate::errors::FatalParseError;
5
use crate::expression::parse_expression;
6
use crate::parser::ParseContext;
7
use crate::{child, field};
8
use conjure_cp_core::ast::{
9
    DeclarationPtr, Domain, DomainPtr, IntVal, Moo, Name, Range, RecordEntry, Reference, SetAttr,
10
};
11
use tree_sitter::Node;
12

            
13
/// Parse an Essence variable domain into its Conjure AST representation.
14
3450
pub fn parse_domain(
15
3450
    ctx: &mut ParseContext,
16
3450
    domain: Node,
17
3450
) -> Result<Option<DomainPtr>, FatalParseError> {
18
3450
    match domain.kind() {
19
3450
        "domain" => parse_domain(ctx, child!(domain, 0, "domain")),
20
1731
        "bool_domain" => {
21
558
            let hover = HoverInfo {
22
558
                description: "Boolean domain".to_string(),
23
558
                kind: Some(crate::diagnostics::diagnostics_api::SymbolKind::Domain),
24
558
                ty: None,
25
558
                decl_span: None,
26
558
            };
27
558
            span_with_hover(&domain, ctx.source_code, ctx.source_map, hover);
28
558
            Ok(Some(Domain::bool()))
29
        }
30
1173
        "int_domain" => {
31
1107
            let hover = HoverInfo {
32
1107
                description: "Integer domain".to_string(),
33
1107
                kind: Some(crate::diagnostics::diagnostics_api::SymbolKind::Domain),
34
1107
                ty: None,
35
1107
                decl_span: None,
36
1107
            };
37
1107
            span_with_hover(&domain, ctx.source_code, ctx.source_map, hover);
38
1107
            parse_int_domain(ctx, domain)
39
        }
40
66
        "identifier" => {
41
44
            let Some(decl) = get_declaration_ptr_from_identifier(ctx, domain)? else {
42
                return Ok(None);
43
            };
44
44
            let Some(dom) = Domain::reference(decl) else {
45
11
                ctx.record_error(crate::errors::RecoverableParseError::new(
46
11
                    format!(
47
                        "The identifier '{}' is not a valid domain",
48
11
                        &ctx.source_code[domain.start_byte()..domain.end_byte()]
49
                    ),
50
11
                    Some(domain.range()),
51
                ));
52
11
                return Ok(None);
53
            };
54
33
            let name = &ctx.source_code[domain.start_byte()..domain.end_byte()];
55
33
            let hover = HoverInfo {
56
33
                description: format!("Domain reference: {name}"),
57
33
                kind: None,
58
33
                ty: None,
59
33
                decl_span: None,
60
33
            };
61
33
            span_with_hover(&domain, ctx.source_code, ctx.source_map, hover);
62
33
            Ok(Some(dom))
63
        }
64
22
        "tuple_domain" => parse_tuple_domain(ctx, domain),
65
22
        "matrix_domain" => parse_matrix_domain(ctx, domain),
66
        "record_domain" => parse_record_domain(ctx, domain),
67
        "set_domain" => parse_set_domain(ctx, domain),
68
        _ => Err(FatalParseError::internal_error(
69
            format!("{} is not a supported domain type", domain.kind()),
70
            Some(domain.range()),
71
        )),
72
    }
73
3450
}
74

            
75
330
fn get_declaration_ptr_from_identifier(
76
330
    ctx: &mut ParseContext,
77
330
    identifier: Node,
78
330
) -> Result<Option<DeclarationPtr>, FatalParseError> {
79
330
    let name = Name::user(&ctx.source_code[identifier.start_byte()..identifier.end_byte()]);
80
330
    let decl = ctx
81
330
        .symbols
82
330
        .as_ref()
83
330
        .ok_or(FatalParseError::internal_error(
84
330
            "context needed to resolve identifier".to_string(),
85
330
            Some(identifier.range()),
86
        ))?
87
330
        .read()
88
330
        .lookup(&name);
89
330
    match decl {
90
308
        Some(decl) => Ok(Some(decl)),
91
        None => {
92
22
            ctx.record_error(crate::errors::RecoverableParseError::new(
93
22
                format!("The identifier '{}' is not defined", name),
94
22
                Some(identifier.range()),
95
            ));
96
22
            Ok(None)
97
        }
98
    }
99
330
}
100

            
101
/// Parse an integer domain. Can be a single integer or a range.
102
1107
fn parse_int_domain(
103
1107
    ctx: &mut ParseContext,
104
1107
    int_domain: Node,
105
1107
) -> Result<Option<DomainPtr>, FatalParseError> {
106
1107
    if int_domain.child_count() == 1 {
107
        // for domains of just 'int' with no range
108
44
        return Ok(Some(Domain::int(vec![Range::Bounded(i32::MIN, i32::MAX)])));
109
1063
    }
110

            
111
1063
    let range_list = field!(int_domain, "ranges");
112
1063
    let mut ranges_unresolved: Vec<Range<IntVal>> = Vec::new();
113
1063
    let mut all_resolved = true;
114

            
115
1290
    for domain_component in named_children(&range_list) {
116
1290
        match domain_component.kind() {
117
1290
            "atom" | "arithmetic_expr" => {
118
27
                let Some(int_val) = parse_int_val(ctx, domain_component)? else {
119
                    return Ok(None);
120
                };
121

            
122
27
                if !matches!(int_val, IntVal::Const(_)) {
123
                    all_resolved = false;
124
27
                }
125
27
                ranges_unresolved.push(Range::Single(int_val));
126
            }
127
1263
            "int_range" => {
128
1263
                let lower_bound = match domain_component.child_by_field_name("lower") {
129
1263
                    Some(node) => {
130
1263
                        match parse_int_val(ctx, node)? {
131
1263
                            Some(val) => Some(val),
132
                            None => return Ok(None), // semantic error occurred
133
                        }
134
                    }
135
                    None => None,
136
                };
137
1263
                let upper_bound = match domain_component.child_by_field_name("upper") {
138
1263
                    Some(node) => {
139
1263
                        match parse_int_val(ctx, node)? {
140
1241
                            Some(val) => Some(val),
141
22
                            None => return Ok(None), // semantic error occurred
142
                        }
143
                    }
144
                    None => None,
145
                };
146

            
147
1241
                match (lower_bound, upper_bound) {
148
1241
                    (Some(lower), Some(upper)) => {
149
                        // Check if both bounds are constants and validate lower <= upper
150
1241
                        if let (IntVal::Const(l), IntVal::Const(u)) = (&lower, &upper) {
151
878
                            if l > u {
152
33
                                ctx.record_error(crate::errors::RecoverableParseError::new(
153
33
                                    format!(
154
33
                                        "Invalid integer range: lower bound {} is greater than upper bound {}",
155
33
                                        l, u
156
33
                                    ),
157
33
                                    Some(domain_component.range()),
158
33
                                ));
159
845
                            }
160
363
                        } else {
161
363
                            all_resolved = false;
162
363
                        }
163
1241
                        ranges_unresolved.push(Range::Bounded(lower, upper));
164
                    }
165
                    (Some(lower), None) => {
166
                        if !matches!(lower, IntVal::Const(_)) {
167
                            all_resolved = false;
168
                        }
169
                        ranges_unresolved.push(Range::UnboundedR(lower));
170
                    }
171
                    (None, Some(upper)) => {
172
                        if !matches!(upper, IntVal::Const(_)) {
173
                            all_resolved = false;
174
                        }
175
                        ranges_unresolved.push(Range::UnboundedL(upper));
176
                    }
177
                    _ => {
178
                        return Err(FatalParseError::internal_error(
179
                            "Invalid int range: must have at least a lower or upper bound"
180
                                .to_string(),
181
                            Some(domain_component.range()),
182
                        ));
183
                    }
184
                }
185
            }
186
            _ => {
187
                return Err(FatalParseError::internal_error(
188
                    format!(
189
                        "Unexpected int domain component: {}",
190
                        domain_component.kind()
191
                    ),
192
                    Some(domain_component.range()),
193
                ));
194
            }
195
        }
196
    }
197

            
198
    // If all values are resolved constants, convert IntVals to raw integers
199
1041
    if all_resolved {
200
744
        let ranges: Vec<Range<i32>> = ranges_unresolved
201
744
            .into_iter()
202
753
            .map(|r| match r {
203
27
                Range::Single(IntVal::Const(v)) => Range::Single(v),
204
878
                Range::Bounded(IntVal::Const(l), IntVal::Const(u)) => Range::Bounded(l, u),
205
                Range::UnboundedR(IntVal::Const(l)) => Range::UnboundedR(l),
206
                Range::UnboundedL(IntVal::Const(u)) => Range::UnboundedL(u),
207
                Range::Unbounded => Range::Unbounded,
208
                _ => unreachable!("all_resolved should be true only if all are Const"),
209
905
            })
210
744
            .collect();
211
744
        Ok(Some(Domain::int(ranges)))
212
    } else {
213
        // Otherwise, keep as an expression-based domain
214
297
        Ok(Some(Domain::int(ranges_unresolved)))
215
    }
216
1107
}
217

            
218
// Helper function to parse a node into an IntVal
219
// Handles constants, references, and arbitrary expressions
220
2553
fn parse_int_val(ctx: &mut ParseContext, node: Node) -> Result<Option<IntVal>, FatalParseError> {
221
    // For atoms, try to parse as a constant integer first
222
2553
    if node.kind() == "atom" {
223
2421
        let text = &ctx.source_code[node.start_byte()..node.end_byte()];
224
2421
        if let Ok(integer) = text.parse::<i32>() {
225
2135
            return Ok(Some(IntVal::Const(integer)));
226
286
        }
227
        // Otherwise, check if it's an identifier reference
228
286
        let Some(decl) = get_declaration_ptr_from_identifier(ctx, node)? else {
229
            // If identifier isn't defined, its a semantic error
230
22
            return Ok(None);
231
        };
232
264
        return Ok(Some(IntVal::Reference(Reference::new(decl))));
233
132
    }
234

            
235
    // For anything else, parse as an expression
236
132
    let Some(expr) = parse_expression(ctx, node)? else {
237
        return Ok(None);
238
    };
239
132
    Ok(Some(IntVal::Expr(Moo::new(expr))))
240
2553
}
241

            
242
fn parse_tuple_domain(
243
    ctx: &mut ParseContext,
244
    tuple_domain: Node,
245
) -> Result<Option<DomainPtr>, FatalParseError> {
246
    let mut domains: Vec<DomainPtr> = Vec::new();
247
    for domain in named_children(&tuple_domain) {
248
        let Some(parsed_domain) = parse_domain(ctx, domain)? else {
249
            return Ok(None);
250
        };
251
        domains.push(parsed_domain);
252
    }
253
    Ok(Some(Domain::tuple(domains)))
254
}
255

            
256
22
fn parse_matrix_domain(
257
22
    ctx: &mut ParseContext,
258
22
    matrix_domain: Node,
259
22
) -> Result<Option<DomainPtr>, FatalParseError> {
260
22
    let mut domains: Vec<DomainPtr> = Vec::new();
261
22
    let index_domain_list = field!(matrix_domain, "index_domain_list");
262
22
    for domain in named_children(&index_domain_list) {
263
22
        let Some(parsed_domain) = parse_domain(ctx, domain)? else {
264
            return Ok(None);
265
        };
266
22
        domains.push(parsed_domain);
267
    }
268
22
    let Some(value_domain) = parse_domain(ctx, field!(matrix_domain, "value_domain"))? else {
269
        return Ok(None);
270
    };
271
22
    Ok(Some(Domain::matrix(value_domain, domains)))
272
22
}
273

            
274
fn parse_record_domain(
275
    ctx: &mut ParseContext,
276
    record_domain: Node,
277
) -> Result<Option<DomainPtr>, FatalParseError> {
278
    let mut record_entries: Vec<RecordEntry> = Vec::new();
279
    for record_entry in named_children(&record_domain) {
280
        let name_node = field!(record_entry, "name");
281
        let name = Name::user(&ctx.source_code[name_node.start_byte()..name_node.end_byte()]);
282
        let domain_node = field!(record_entry, "domain");
283
        let Some(domain) = parse_domain(ctx, domain_node)? else {
284
            return Ok(None);
285
        };
286
        record_entries.push(RecordEntry { name, domain });
287
    }
288
    Ok(Some(Domain::record(record_entries)))
289
}
290

            
291
pub fn parse_set_domain(
292
    ctx: &mut ParseContext,
293
    set_domain: Node,
294
) -> Result<Option<DomainPtr>, FatalParseError> {
295
    let mut set_attribute: Option<SetAttr> = None;
296
    let mut value_domain: Option<DomainPtr> = None;
297

            
298
    for child in named_children(&set_domain) {
299
        match child.kind() {
300
            "set_attributes" => {
301
                // Check if we have both minSize and maxSize (minMax case)
302
                let min_value_node = child.child_by_field_name("min_value");
303
                let max_value_node = child.child_by_field_name("max_value");
304
                let size_value_node = child.child_by_field_name("size_value");
305

            
306
                if let (Some(min_node), Some(max_node)) = (min_value_node, max_value_node) {
307
                    // MinMax case
308
                    let min_val = parse_int(ctx, &min_node)?;
309
                    let max_val = parse_int(ctx, &max_node)?;
310

            
311
                    set_attribute = Some(SetAttr::new_min_max_size(min_val, max_val));
312
                } else if let Some(size_node) = size_value_node {
313
                    // Size case
314
                    let size_val = parse_int(ctx, &size_node)?;
315
                    set_attribute = Some(SetAttr::new_size(size_val));
316
                } else if let Some(min_node) = min_value_node {
317
                    // MinSize only case
318
                    let min_val = parse_int(ctx, &min_node)?;
319
                    set_attribute = Some(SetAttr::new_min_size(min_val));
320
                } else if let Some(max_node) = max_value_node {
321
                    // MaxSize only case
322
                    let max_val = parse_int(ctx, &max_node)?;
323
                    set_attribute = Some(SetAttr::new_max_size(max_val));
324
                }
325
            }
326
            "domain" => {
327
                let Some(parsed_domain) = parse_domain(ctx, child)? else {
328
                    return Ok(None);
329
                };
330
                value_domain = Some(parsed_domain);
331
            }
332
            _ => {
333
                return Err(FatalParseError::internal_error(
334
                    format!("Unrecognized set domain child kind: {}", child.kind()),
335
                    Some(child.range()),
336
                ));
337
            }
338
        }
339
    }
340

            
341
    if let Some(domain) = value_domain {
342
        Ok(Some(Domain::set(set_attribute.unwrap_or_default(), domain)))
343
    } else {
344
        Err(FatalParseError::internal_error(
345
            "Set domain must have a value domain".to_string(),
346
            Some(set_domain.range()),
347
        ))
348
    }
349
}