1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::{FatalParseError, RecoverableParseError};
4
use crate::expression::{parse_binary_expression, parse_expression};
5
use crate::parser::ParseContext;
6
use crate::parser::abstract_literal::parse_abstract;
7
use crate::parser::comprehension::parse_comprehension;
8
use crate::parser::dominance::parse_pareto_expression;
9
use crate::util::{TypecheckingContext, named_children};
10
use crate::{field, named_child};
11
use conjure_cp_core::ast::{
12
    Atom, DeclarationKind, DeclarationPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
13
};
14
use tree_sitter::Node;
15
use ustr::Ustr;
16

            
17
371680
pub fn parse_atom(
18
371680
    ctx: &mut ParseContext,
19
371680
    node: &Node,
20
371680
) -> Result<Option<Expression>, FatalParseError> {
21
371680
    match node.kind() {
22
371680
        "atom" | "sub_atom_expr" => {
23
183888
            let Some(inner) = named_child!(recover, ctx, node) else {
24
                return Ok(None);
25
            };
26
183888
            parse_atom(ctx, &inner)
27
        }
28
187792
        "metavar" => {
29
403
            let Some(ident) = field!(recover, ctx, node, "identifier") else {
30
                return Ok(None);
31
            };
32
403
            let name_str = &ctx.source_code[ident.start_byte()..ident.end_byte()];
33
403
            Ok(Some(Expression::Metavar(
34
403
                Metadata::new(),
35
403
                Ustr::from(name_str),
36
403
            )))
37
        }
38
187389
        "identifier" => {
39
112088
            let Some(var) = parse_variable(ctx, node)? else {
40
352
                return Ok(None);
41
            };
42
111736
            Ok(Some(Expression::Atomic(Metadata::new(), var)))
43
        }
44
75301
        "from_solution" => {
45
34632
            if ctx.root.kind() != "dominance_relation" {
46
                ctx.record_error(RecoverableParseError::new(
47
                    "fromSolution only allowed inside dominance relations".to_string(),
48
                    Some(node.range()),
49
                ));
50
                return Ok(None);
51
34632
            }
52

            
53
34632
            let Some(var_node) = field!(recover, ctx, node, "variable") else {
54
                return Ok(None);
55
            };
56
34632
            let Some(inner) = parse_variable(ctx, &var_node)? else {
57
                return Ok(None);
58
            };
59

            
60
34632
            Ok(Some(Expression::FromSolution(
61
34632
                Metadata::new(),
62
34632
                Moo::new(inner),
63
34632
            )))
64
        }
65
40669
        "pareto_expression" => parse_pareto_expression(ctx, node),
66
39574
        "constant" => {
67
28023
            let Some(lit) = parse_constant(ctx, node)? else {
68
104
                return Ok(None);
69
            };
70
27919
            Ok(Some(Expression::Atomic(
71
27919
                Metadata::new(),
72
27919
                Atom::Literal(lit),
73
27919
            )))
74
        }
75
11551
        "matrix" | "record" | "tuple" | "set_literal" => {
76
5018
            let Some(abs) = parse_abstract(ctx, node)? else {
77
130
                return Ok(None);
78
            };
79
4888
            Ok(Some(Expression::AbstractLiteral(Metadata::new(), abs)))
80
        }
81
6533
        "flatten" => parse_flatten(ctx, node),
82
6273
        "table" | "negative_table" => parse_table(ctx, node),
83
6169
        "index_or_slice" => parse_index_or_slice(ctx, node),
84
        // for now, assume is binary since powerset isn't implemented
85
        // TODO: add powerset support under "set_operation"
86
867
        "set_operation" => parse_binary_expression(ctx, node),
87
711
        "comprehension" => parse_comprehension(ctx, node),
88
        _ => {
89
            ctx.record_error(RecoverableParseError::new(
90
                format!("Expected atom, got: {}", node.kind()),
91
                Some(node.range()),
92
            ));
93
            Ok(None)
94
        }
95
    }
96
371680
}
97

            
98
260
fn parse_flatten(
99
260
    ctx: &mut ParseContext,
100
260
    node: &Node,
101
260
) -> Result<Option<Expression>, FatalParseError> {
102
    // add error and return early if we're in a set context, since flatten doesn't produce sets
103
260
    if ctx.typechecking_context == TypecheckingContext::Set {
104
        ctx.record_error(RecoverableParseError::new(
105
            format!(
106
                "Type error: {}\n\tExpected: set\n\tGot: flatten",
107
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
108
            ),
109
            Some(node.range()),
110
        ));
111
        return Ok(None);
112
260
    }
113

            
114
260
    let Some(expr_node) = field!(recover, ctx, node, "expression") else {
115
        return Ok(None);
116
    };
117
    // TODO: verify the atom is a matrix
118
260
    let Some(expr) = parse_atom(ctx, &expr_node)? else {
119
        return Ok(None);
120
    };
121

            
122
260
    if let Some(depth_node) = node.child_by_field_name("depth") {
123
        let Some(depth) = parse_int(ctx, &depth_node) else {
124
            return Ok(None);
125
        };
126
        let depth_expression =
127
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
128
        Ok(Some(Expression::Flatten(
129
            Metadata::new(),
130
            Some(Moo::new(depth_expression)),
131
            Moo::new(expr),
132
        )))
133
    } else {
134
260
        Ok(Some(Expression::Flatten(
135
260
            Metadata::new(),
136
260
            None,
137
260
            Moo::new(expr),
138
260
        )))
139
    }
140
260
}
141

            
142
104
fn parse_table(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
143
    // add error and return early if we're in a set context, since tables aren't allowed there
144
104
    if ctx.typechecking_context == TypecheckingContext::Set {
145
        ctx.record_error(RecoverableParseError::new(
146
            format!(
147
                "Type error: {}\n\tExpected: set\n\tGot: table",
148
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
149
            ),
150
            Some(node.range()),
151
        ));
152
        return Ok(None);
153
104
    }
154

            
155
    // the variables and rows can contain arbitrary expressions, so we temporarily set the context to Unknown to avoid typechecking errors
156
104
    let saved_context = ctx.typechecking_context;
157
104
    ctx.typechecking_context = TypecheckingContext::Unknown;
158

            
159
104
    let Some(variables_node) = field!(recover, ctx, node, "variables") else {
160
        return Ok(None);
161
    };
162
104
    let Some(variables) = parse_atom(ctx, &variables_node)? else {
163
        return Ok(None);
164
    };
165

            
166
104
    let Some(rows_node) = field!(recover, ctx, node, "rows") else {
167
        return Ok(None);
168
    };
169
104
    let Some(rows) = parse_atom(ctx, &rows_node)? else {
170
        return Ok(None);
171
    };
172

            
173
104
    ctx.typechecking_context = saved_context;
174

            
175
104
    match node.kind() {
176
104
        "table" => Ok(Some(Expression::Table(
177
78
            Metadata::new(),
178
78
            Moo::new(variables),
179
78
            Moo::new(rows),
180
78
        ))),
181
26
        "negative_table" => Ok(Some(Expression::NegativeTable(
182
26
            Metadata::new(),
183
26
            Moo::new(variables),
184
26
            Moo::new(rows),
185
26
        ))),
186
        _ => {
187
            ctx.record_error(RecoverableParseError::new(
188
                format!(
189
                    "Expected 'table' or 'negative_table', got: '{}'",
190
                    node.kind()
191
                ),
192
                Some(node.range()),
193
            ));
194
            Ok(None)
195
        }
196
    }
197
104
}
198

            
199
5302
fn parse_index_or_slice(
200
5302
    ctx: &mut ParseContext,
201
5302
    node: &Node,
202
5302
) -> Result<Option<Expression>, FatalParseError> {
203
    // add error and return early if we're in a set context, since indexing/slicing doesn't produce sets
204
5302
    if ctx.typechecking_context == TypecheckingContext::Set {
205
        ctx.record_error(RecoverableParseError::new(
206
            format!(
207
                "Type error: {}\n\tExpected: set\n\tGot: index or slice",
208
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
209
            ),
210
            Some(node.range()),
211
        ));
212
        return Ok(None);
213
5302
    }
214

            
215
    // Save current context and temporarily set to Unknown for the collection
216
5302
    let saved_context = ctx.typechecking_context;
217
5302
    ctx.typechecking_context = TypecheckingContext::Unknown;
218
5302
    let Some(collection_node) = field!(recover, ctx, node, "collection") else {
219
        return Ok(None);
220
    };
221
5302
    let Some(collection) = parse_atom(ctx, &collection_node)? else {
222
13
        return Ok(None);
223
    };
224
5289
    ctx.typechecking_context = saved_context;
225
5289
    let mut indices = Vec::new();
226
5289
    let Some(indices_node) = field!(recover, ctx, node, "indices") else {
227
        return Ok(None);
228
    };
229
7579
    for idx_node in named_children(&indices_node) {
230
7579
        indices.push(parse_index(ctx, &idx_node)?);
231
    }
232

            
233
7267
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
234
    // TODO: We could check whether the slice/index is safe here
235
5289
    if has_null_idx {
236
        // It's a slice
237
885
        Ok(Some(Expression::UnsafeSlice(
238
885
            Metadata::new(),
239
885
            Moo::new(collection),
240
885
            indices,
241
885
        )))
242
    } else {
243
        // It's an index
244
5861
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
245
4404
        Ok(Some(Expression::UnsafeIndex(
246
4404
            Metadata::new(),
247
4404
            Moo::new(collection),
248
4404
            idx_exprs,
249
4404
        )))
250
    }
251
5302
}
252

            
253
7579
fn parse_index(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
254
7579
    match node.kind() {
255
7579
        "arithmetic_expr" | "atom" => {
256
6694
            let saved_context = ctx.typechecking_context;
257
6694
            ctx.typechecking_context = TypecheckingContext::Unknown;
258

            
259
            // TODO: add collection-aware index typechecking.
260
            // For tuple/matrix/set-like indexing, indices should be arithmetic.
261
            // For record field access, index atoms should resolve to valid field names.
262
            // This requires checking index expression together with the indexed collection type.
263

            
264
6694
            let Some(expr) = parse_expression(ctx, *node)? else {
265
                return Ok(None);
266
            };
267

            
268
6694
            ctx.typechecking_context = saved_context;
269
6694
            Ok(Some(expr))
270
        }
271
885
        "null_index" => Ok(None),
272
        _ => {
273
            ctx.record_error(RecoverableParseError::new(
274
                format!("Expected an index, got: '{}'", node.kind()),
275
                Some(node.range()),
276
            ));
277
            Ok(None)
278
        }
279
    }
280
7579
}
281

            
282
146720
fn parse_variable(ctx: &mut ParseContext, node: &Node) -> Result<Option<Atom>, FatalParseError> {
283
146720
    let raw_name = &ctx.source_code[node.start_byte()..node.end_byte()];
284

            
285
146720
    let name = Name::user(raw_name.trim());
286
146720
    if let Some(symbols) = &ctx.symbols {
287
146720
        let lookup_result = {
288
146720
            let symbols_read = symbols.read();
289
146720
            symbols_read.lookup(&name)
290
        };
291

            
292
146720
        if let Some(decl) = lookup_result {
293
146589
            let symbol_kind = match &decl.kind().clone() as &DeclarationKind {
294
137876
                DeclarationKind::Find(_) => SymbolKind::FindVar,
295
290
                DeclarationKind::Given(_) => SymbolKind::GivenVar,
296
2082
                DeclarationKind::ValueLetting(_, _) => SymbolKind::LettingVar,
297
                DeclarationKind::TemporaryValueLetting(_) => SymbolKind::LettingVar,
298
26
                DeclarationKind::DomainLetting(_) => SymbolKind::LettingVar,
299
6263
                DeclarationKind::Quantified(..) => SymbolKind::FindVar,
300
                DeclarationKind::QuantifiedExpr(..) => SymbolKind::FindVar,
301
52
                DeclarationKind::Field(_) => SymbolKind::Decimal,
302
                &_ => todo!(),
303
            };
304

            
305
146589
            let hover = HoverInfo {
306
146589
                description: format!("Variable: {name}"),
307
146589
                doc_key: None,
308
146589
                kind: Some(symbol_kind),
309
146589
                ty: decl.domain().map(|d| d.to_string()),
310
146589
                decl_span: ctx.lookup_decl_span(&name),
311
            };
312
146589
            span_with_hover(node, ctx.source_code, ctx.source_map, hover);
313

            
314
            // Type check the variable against the expected context
315
146589
            if let Some(error_msg) = typecheck_variable(&decl, ctx.typechecking_context, raw_name) {
316
221
                ctx.record_error(RecoverableParseError::new(error_msg, Some(node.range())));
317
221
                return Ok(None);
318
146368
            }
319

            
320
146368
            Ok(Some(Atom::Reference(conjure_cp_core::ast::Reference::new(
321
146368
                decl,
322
146368
            ))))
323
        } else {
324
131
            ctx.record_error(RecoverableParseError::new(
325
131
                format!("The identifier '{}' is not defined", raw_name),
326
131
                Some(node.range()),
327
            ));
328
131
            Ok(None)
329
        }
330
    } else {
331
        ctx.record_error(RecoverableParseError::new(
332
            format!("Symbol table missing when parsing variable '{raw_name}'"),
333
            Some(node.range()),
334
        ));
335
        Ok(None)
336
    }
337
146720
}
338

            
339
/// Type check a variable declaration against the expected expression context.
340
/// Returns an error message if the variable type doesn't match the context.
341
146589
fn typecheck_variable(
342
146589
    decl: &DeclarationPtr,
343
146589
    context: TypecheckingContext,
344
146589
    raw_name: &str,
345
146589
) -> Option<String> {
346
    // Only type check when context is known
347
146589
    if context == TypecheckingContext::Unknown {
348
16435
        return None;
349
130154
    }
350

            
351
    // Get the variable's domain and resolve it
352
130154
    let domain = decl.domain()?;
353
130154
    let ground_domain = domain.resolve()?;
354

            
355
    // Determine what type is expected
356
129748
    let expected = match context {
357
99398
        TypecheckingContext::Boolean => "bool",
358
29297
        TypecheckingContext::Arithmetic => "int",
359
585
        TypecheckingContext::Set => "set",
360
13
        TypecheckingContext::SetOrMatrix => "set or matrix",
361
        TypecheckingContext::MSet => "mset",
362
377
        TypecheckingContext::Matrix => "matrix",
363
52
        TypecheckingContext::Tuple => "tuple",
364
26
        TypecheckingContext::Record => "record",
365
        TypecheckingContext::Partition => "partition",
366
        TypecheckingContext::Sequence => "sequence",
367
        TypecheckingContext::Unknown => return None, // shouldn't reach here
368
    };
369

            
370
    // Determine what type we actually have
371
129748
    let actual = match ground_domain.as_ref() {
372
99359
        GroundDomain::Bool => "bool",
373
29336
        GroundDomain::Int(_) => "int",
374
429
        GroundDomain::Matrix(_, _) => "matrix",
375
533
        GroundDomain::Set(_, _) => "set",
376
        GroundDomain::MSet(_, _) => "mset",
377
52
        GroundDomain::Tuple(_) => "tuple",
378
39
        GroundDomain::Record(_) => "record",
379
        GroundDomain::Function(_, _, _) => "function",
380
        GroundDomain::Variant(_) => "variant",
381
        GroundDomain::Relation(_, _) => "relation",
382
        GroundDomain::Partition(_, _) => "partition",
383
        GroundDomain::Sequence(_, _) => "sequence",
384
        GroundDomain::Empty(_) => "empty",
385
    };
386

            
387
    // If types match, no error
388
129748
    if expected == actual
389
221
        || (context == TypecheckingContext::SetOrMatrix && matches!(actual, "set" | "matrix"))
390
    {
391
129527
        return None;
392
221
    }
393

            
394
    // Otherwise, report the type mismatch
395
221
    Some(format!(
396
221
        "Type error: {}\n\tExpected: {}\n\tGot: {}",
397
221
        raw_name, expected, actual
398
221
    ))
399
146589
}
400

            
401
28023
fn parse_constant(ctx: &mut ParseContext, node: &Node) -> Result<Option<Literal>, FatalParseError> {
402
28023
    let Some(inner) = named_child!(recover, ctx, node) else {
403
        return Ok(None);
404
    };
405
28023
    let raw_value = &ctx.source_code[inner.start_byte()..inner.end_byte()];
406
28023
    let lit = match inner.kind() {
407
28023
        "integer" => {
408
27320
            let Some(value) = parse_int(ctx, &inner) else {
409
26
                return Ok(None);
410
            };
411
27294
            Literal::Int(value)
412
        }
413
703
        "TRUE" => {
414
409
            let hover = HoverInfo {
415
409
                description: format!("Boolean constant: {raw_value}"),
416
409
                doc_key: None,
417
409
                kind: None,
418
409
                ty: None,
419
409
                decl_span: None,
420
409
            };
421
409
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
422
409
            Literal::Bool(true)
423
        }
424
294
        "FALSE" => {
425
294
            let hover = HoverInfo {
426
294
                description: format!("Boolean constant: {raw_value}"),
427
294
                doc_key: None,
428
294
                kind: None,
429
294
                ty: None,
430
294
                decl_span: None,
431
294
            };
432
294
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
433
294
            Literal::Bool(false)
434
        }
435
        _ => {
436
            ctx.record_error(RecoverableParseError::new(
437
                format!(
438
                    "'{}' (kind: '{}') is not a valid constant",
439
                    raw_value,
440
                    inner.kind()
441
                ),
442
                Some(inner.range()),
443
            ));
444
            return Ok(None);
445
        }
446
    };
447

            
448
    // Type check the constant against the expected context
449
27997
    if ctx.typechecking_context != TypecheckingContext::Unknown {
450
14709
        let expected = match ctx.typechecking_context {
451
293
            TypecheckingContext::Boolean => "bool",
452
14390
            TypecheckingContext::Arithmetic => "int",
453
13
            TypecheckingContext::Set => "set",
454
            TypecheckingContext::SetOrMatrix => "set or matrix",
455
            TypecheckingContext::MSet => "mset",
456
13
            TypecheckingContext::Matrix => "matrix",
457
            TypecheckingContext::Tuple => "tuple",
458
            TypecheckingContext::Record => "record",
459
            TypecheckingContext::Partition => "partition",
460
            TypecheckingContext::Sequence => "sequence",
461
            TypecheckingContext::Unknown => "",
462
        };
463

            
464
14709
        let actual = match &lit {
465
280
            Literal::Bool(_) => "bool",
466
14429
            Literal::Int(_) => "int",
467
            Literal::AbstractLiteral(_) => return Ok(None), // Abstract literals aren't type-checked here
468
        };
469

            
470
14709
        if expected != actual {
471
78
            ctx.record_error(RecoverableParseError::new(
472
78
                format!(
473
                    "Type error: {}\n\tExpected: {}\n\tGot: {}",
474
                    raw_value, expected, actual
475
                ),
476
78
                Some(node.range()),
477
            ));
478
78
            return Ok(None);
479
14631
        }
480
13288
    }
481
27919
    Ok(Some(lit))
482
28023
}
483

            
484
27398
pub(crate) fn parse_int(ctx: &mut ParseContext, node: &Node) -> Option<i32> {
485
27398
    let raw_value = &ctx.source_code[node.start_byte()..node.end_byte()];
486
27398
    if let Ok(v) = raw_value.parse::<i32>() {
487
27372
        Some(v)
488
    } else {
489
26
        ctx.record_error(RecoverableParseError::new(
490
26
            "Expected an integer here".to_string(),
491
26
            Some(node.range()),
492
        ));
493
26
        None
494
    }
495
27398
}