1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::{FatalParseError, RecoverableParseError};
4
use crate::expression::{parse_binary_expression, parse_expression};
5
use crate::parser::ParseContext;
6
use crate::parser::abstract_literal::parse_abstract;
7
use crate::parser::comprehension::parse_comprehension;
8
use crate::util::{TypecheckingContext, named_children};
9
use crate::{field, named_child};
10
use conjure_cp_core::ast::{
11
    Atom, DeclarationPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
12
};
13
use tree_sitter::Node;
14
use ustr::Ustr;
15

            
16
61360
pub fn parse_atom(
17
61360
    ctx: &mut ParseContext,
18
61360
    node: &Node,
19
61360
) -> Result<Option<Expression>, FatalParseError> {
20
61360
    match node.kind() {
21
61360
        "atom" | "sub_atom_expr" => parse_atom(ctx, &named_child!(node)),
22
31518
        "metavar" => {
23
403
            let ident = field!(node, "identifier");
24
403
            let name_str = &ctx.source_code[ident.start_byte()..ident.end_byte()];
25
403
            Ok(Some(Expression::Metavar(
26
403
                Metadata::new(),
27
403
                Ustr::from(name_str),
28
403
            )))
29
        }
30
31115
        "identifier" => {
31
16171
            let Some(var) = parse_variable(ctx, node)? else {
32
155
                return Ok(None);
33
            };
34
16016
            Ok(Some(Expression::Atomic(Metadata::new(), var)))
35
        }
36
14944
        "from_solution" => {
37
            if ctx.root.kind() != "dominance_relation" {
38
                return Err(FatalParseError::internal_error(
39
                    "fromSolution only allowed inside dominance relations".to_string(),
40
                    Some(node.range()),
41
                ));
42
            }
43

            
44
            let Some(inner) = parse_variable(ctx, &field!(node, "variable"))? else {
45
                return Ok(None);
46
            };
47

            
48
            Ok(Some(Expression::FromSolution(
49
                Metadata::new(),
50
                Moo::new(inner),
51
            )))
52
        }
53
14944
        "constant" => {
54
10274
            let Some(lit) = parse_constant(ctx, node)? else {
55
22
                return Ok(None);
56
            };
57
10252
            Ok(Some(Expression::Atomic(
58
10252
                Metadata::new(),
59
10252
                Atom::Literal(lit),
60
10252
            )))
61
        }
62
4670
        "matrix" | "record" | "tuple" | "set_literal" => {
63
1940
            let Some(abs) = parse_abstract(ctx, node)? else {
64
22
                return Ok(None);
65
            };
66
1918
            Ok(Some(Expression::AbstractLiteral(Metadata::new(), abs)))
67
        }
68
2730
        "flatten" => parse_flatten(ctx, node),
69
2620
        "table" | "negative_table" => parse_table(ctx, node),
70
2576
        "index_or_slice" => parse_index_or_slice(ctx, node),
71
        // for now, assume is binary since powerset isn't implemented
72
        // TODO: add powerset support under "set_operation"
73
308
        "set_operation" => parse_binary_expression(ctx, node),
74
242
        "comprehension" => parse_comprehension(ctx, node),
75
        _ => Err(FatalParseError::internal_error(
76
            format!("Expected atom, got: {}", node.kind()),
77
            Some(node.range()),
78
        )),
79
    }
80
61360
}
81

            
82
110
fn parse_flatten(
83
110
    ctx: &mut ParseContext,
84
110
    node: &Node,
85
110
) -> Result<Option<Expression>, FatalParseError> {
86
110
    let expr_node = field!(node, "expression");
87
110
    let Some(expr) = parse_atom(ctx, &expr_node)? else {
88
        return Ok(None);
89
    };
90

            
91
110
    if node.child_by_field_name("depth").is_some() {
92
        let depth_node = field!(node, "depth");
93
        let depth = parse_int(ctx, &depth_node)?;
94
        let depth_expression =
95
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
96
        Ok(Some(Expression::Flatten(
97
            Metadata::new(),
98
            Some(Moo::new(depth_expression)),
99
            Moo::new(expr),
100
        )))
101
    } else {
102
110
        Ok(Some(Expression::Flatten(
103
110
            Metadata::new(),
104
110
            None,
105
110
            Moo::new(expr),
106
110
        )))
107
    }
108
110
}
109

            
110
44
fn parse_table(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
111
    // the variables and rows can contain arbitrary expressions, so we temporarily set the context to Unknown to avoid typechecking errors
112
44
    let saved_context = ctx.typechecking_context;
113
44
    ctx.typechecking_context = TypecheckingContext::Unknown;
114

            
115
44
    let variables_node = field!(node, "variables");
116
44
    let Some(variables) = parse_atom(ctx, &variables_node)? else {
117
        return Ok(None);
118
    };
119

            
120
44
    let rows_node = field!(node, "rows");
121
44
    let Some(rows) = parse_atom(ctx, &rows_node)? else {
122
        return Ok(None);
123
    };
124

            
125
44
    ctx.typechecking_context = saved_context;
126

            
127
44
    match node.kind() {
128
44
        "table" => Ok(Some(Expression::Table(
129
33
            Metadata::new(),
130
33
            Moo::new(variables),
131
33
            Moo::new(rows),
132
33
        ))),
133
11
        "negative_table" => Ok(Some(Expression::NegativeTable(
134
11
            Metadata::new(),
135
11
            Moo::new(variables),
136
11
            Moo::new(rows),
137
11
        ))),
138
        _ => Err(FatalParseError::internal_error(
139
            format!(
140
                "Expected 'table' or 'negative_table', got: '{}'",
141
                node.kind()
142
            ),
143
            Some(node.range()),
144
        )),
145
    }
146
44
}
147

            
148
2268
fn parse_index_or_slice(
149
2268
    ctx: &mut ParseContext,
150
2268
    node: &Node,
151
2268
) -> Result<Option<Expression>, FatalParseError> {
152
    // Save current context and temporarily set to Unknown for the collection
153
2268
    let saved_context = ctx.typechecking_context;
154
2268
    ctx.typechecking_context = TypecheckingContext::Unknown;
155
2268
    let Some(collection) = parse_atom(ctx, &field!(node, "collection"))? else {
156
11
        return Ok(None);
157
    };
158
2257
    ctx.typechecking_context = saved_context;
159
2257
    let mut indices = Vec::new();
160
3194
    for idx_node in named_children(&field!(node, "indices")) {
161
3194
        indices.push(parse_index(ctx, &idx_node)?);
162
    }
163

            
164
3062
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
165
    // TODO: We could check whether the slice/index is safe here
166
2257
    if has_null_idx {
167
        // It's a slice
168
397
        Ok(Some(Expression::UnsafeSlice(
169
397
            Metadata::new(),
170
397
            Moo::new(collection),
171
397
            indices,
172
397
        )))
173
    } else {
174
        // It's an index
175
2422
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
176
1860
        Ok(Some(Expression::UnsafeIndex(
177
1860
            Metadata::new(),
178
1860
            Moo::new(collection),
179
1860
            idx_exprs,
180
1860
        )))
181
    }
182
2268
}
183

            
184
3194
fn parse_index(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
185
3194
    match node.kind() {
186
3194
        "arithmetic_expr" | "atom" => {
187
2797
            let saved_context = ctx.typechecking_context;
188
2797
            ctx.typechecking_context = TypecheckingContext::Unknown;
189

            
190
            // TODO: add collection-aware index typechecking.
191
            // For tuple/matrix/set-like indexing, indices should be arithmetic.
192
            // For record field access, index atoms should resolve to valid field names.
193
            // This requires checking index expression together with the indexed collection type.
194

            
195
2797
            let Some(expr) = parse_expression(ctx, *node)? else {
196
                return Ok(None);
197
            };
198

            
199
2797
            ctx.typechecking_context = saved_context;
200
2797
            Ok(Some(expr))
201
        }
202
397
        "null_index" => Ok(None),
203
        _ => Err(FatalParseError::internal_error(
204
            format!("Expected an index, got: '{}'", node.kind()),
205
            Some(node.range()),
206
        )),
207
    }
208
3194
}
209

            
210
16171
fn parse_variable(ctx: &mut ParseContext, node: &Node) -> Result<Option<Atom>, FatalParseError> {
211
16171
    let raw_name = &ctx.source_code[node.start_byte()..node.end_byte()];
212
16171
    let name = Name::user(raw_name.trim());
213
16171
    if let Some(symbols) = &ctx.symbols {
214
16171
        let lookup_result = {
215
16171
            let symbols_read = symbols.read();
216
16171
            symbols_read.lookup(&name)
217
        };
218

            
219
16171
        if let Some(decl) = lookup_result {
220
16104
            let hover = HoverInfo {
221
16104
                description: format!("Variable: {name}"),
222
16104
                kind: Some(SymbolKind::Decimal),
223
16104
                ty: decl.domain().map(|d| d.to_string()),
224
16104
                decl_span: None,
225
            };
226
16104
            span_with_hover(node, ctx.source_code, ctx.source_map, hover);
227

            
228
            // Type check the variable against the expected context
229
16104
            if let Some(error_msg) = typecheck_variable(&decl, ctx.typechecking_context) {
230
88
                ctx.record_error(RecoverableParseError::new(error_msg, Some(node.range())));
231
88
                return Ok(None);
232
16016
            }
233

            
234
16016
            Ok(Some(Atom::Reference(conjure_cp_core::ast::Reference::new(
235
16016
                decl,
236
16016
            ))))
237
        } else {
238
67
            ctx.record_error(RecoverableParseError::new(
239
67
                format!("The identifier '{}' is not defined", raw_name),
240
67
                Some(node.range()),
241
            ));
242
67
            Ok(None)
243
        }
244
    } else {
245
        Err(FatalParseError::internal_error(
246
            format!("Symbol table missing when parsing variable '{raw_name}'"),
247
            Some(node.range()),
248
        ))
249
    }
250
16171
}
251

            
252
/// Type check a variable declaration against the expected expression context.
253
/// Returns an error message if the variable type doesn't match the context.
254
16104
fn typecheck_variable(decl: &DeclarationPtr, context: TypecheckingContext) -> Option<String> {
255
    // Only type check when context is known
256
16104
    if context == TypecheckingContext::Unknown {
257
6900
        return None;
258
9204
    }
259

            
260
    // Get the variable's domain and resolve it
261
9204
    let domain = decl.domain()?;
262
9204
    let ground_domain = domain.resolve()?;
263

            
264
    // Determine what type is expected
265
9138
    let expected = match context {
266
1661
        TypecheckingContext::Boolean => "bool",
267
7477
        TypecheckingContext::Arithmetic => "int",
268
        TypecheckingContext::Unknown => return None, // shouldn't reach here
269
    };
270

            
271
    // Determine what type we actually have
272
9138
    let actual = match ground_domain.as_ref() {
273
1639
        GroundDomain::Bool => "bool",
274
7488
        GroundDomain::Int(_) => "int",
275
11
        GroundDomain::Matrix(_, _) => "matrix",
276
        GroundDomain::Set(_, _) => "set",
277
        GroundDomain::MSet(_, _) => "mset",
278
        GroundDomain::Tuple(_) => "tuple",
279
        GroundDomain::Record(_) => "record",
280
        GroundDomain::Function(_, _, _) => "function",
281
        GroundDomain::Empty(_) => "empty",
282
    };
283

            
284
    // If types match, no error
285
9138
    if expected == actual {
286
9050
        return None;
287
88
    }
288

            
289
    // Otherwise, report the type mismatch
290
88
    Some(format!(
291
88
        "Type error:\n\tExpected: {}\n\tGot: {}",
292
88
        expected, actual
293
88
    ))
294
16104
}
295

            
296
10274
fn parse_constant(ctx: &mut ParseContext, node: &Node) -> Result<Option<Literal>, FatalParseError> {
297
10274
    let inner = named_child!(node);
298
10274
    let raw_value = &ctx.source_code[inner.start_byte()..inner.end_byte()];
299
10274
    let lit = match inner.kind() {
300
10274
        "integer" => {
301
9934
            let value = parse_int(ctx, &inner)?;
302
9934
            Literal::Int(value)
303
        }
304
340
        "TRUE" => {
305
189
            let hover = HoverInfo {
306
189
                description: format!("Boolean constant: {raw_value}"),
307
189
                kind: None,
308
189
                ty: None,
309
189
                decl_span: None,
310
189
            };
311
189
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
312
189
            Literal::Bool(true)
313
        }
314
151
        "FALSE" => {
315
151
            let hover = HoverInfo {
316
151
                description: format!("Boolean constant: {raw_value}"),
317
151
                kind: None,
318
151
                ty: None,
319
151
                decl_span: None,
320
151
            };
321
151
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
322
151
            Literal::Bool(false)
323
        }
324
        _ => {
325
            return Err(FatalParseError::internal_error(
326
                format!(
327
                    "'{}' (kind: '{}') is not a valid constant",
328
                    raw_value,
329
                    inner.kind()
330
                ),
331
                Some(inner.range()),
332
            ));
333
        }
334
    };
335

            
336
    // Type check the constant against the expected context
337
10274
    if ctx.typechecking_context != TypecheckingContext::Unknown {
338
4737
        let expected = match ctx.typechecking_context {
339
99
            TypecheckingContext::Boolean => "bool",
340
4638
            TypecheckingContext::Arithmetic => "int",
341
            TypecheckingContext::Unknown => "",
342
        };
343

            
344
4737
        let actual = match &lit {
345
99
            Literal::Bool(_) => "bool",
346
4638
            Literal::Int(_) => "int",
347
            Literal::AbstractLiteral(_) => return Ok(None), // Abstract literals aren't type-checked here
348
        };
349

            
350
4737
        if expected != actual {
351
22
            ctx.record_error(RecoverableParseError::new(
352
22
                format!("Type error:\n\tExpected: {}\n\tGot: {}", expected, actual),
353
22
                Some(node.range()),
354
            ));
355
22
            return Ok(None);
356
4715
        }
357
5537
    }
358
10252
    Ok(Some(lit))
359
10274
}
360

            
361
9967
pub(crate) fn parse_int(ctx: &ParseContext, node: &Node) -> Result<i32, FatalParseError> {
362
9967
    let raw_value = &ctx.source_code[node.start_byte()..node.end_byte()];
363
9967
    raw_value.parse::<i32>().map_err(|_e| {
364
        FatalParseError::internal_error("Expected an integer here".to_string(), Some(node.range()))
365
    })
366
9967
}