1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::{FatalParseError, RecoverableParseError};
4
use crate::expression::{parse_binary_expression, parse_expression, parse_pareto_expression};
5
use crate::parser::ParseContext;
6
use crate::parser::abstract_literal::parse_abstract;
7
use crate::parser::comprehension::parse_comprehension;
8
use crate::util::{TypecheckingContext, named_children};
9
use crate::{field, named_child};
10
use conjure_cp_core::ast::{
11
    Atom, DeclarationPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
12
};
13
use tree_sitter::Node;
14
use ustr::Ustr;
15

            
16
369522
pub fn parse_atom(
17
369522
    ctx: &mut ParseContext,
18
369522
    node: &Node,
19
369522
) -> Result<Option<Expression>, FatalParseError> {
20
369522
    match node.kind() {
21
369522
        "atom" | "sub_atom_expr" => {
22
182809
            let Some(inner) = named_child!(recover, ctx, node) else {
23
                return Ok(None);
24
            };
25
182809
            parse_atom(ctx, &inner)
26
        }
27
186713
        "metavar" => {
28
403
            let Some(ident) = field!(recover, ctx, node, "identifier") else {
29
                return Ok(None);
30
            };
31
403
            let name_str = &ctx.source_code[ident.start_byte()..ident.end_byte()];
32
403
            Ok(Some(Expression::Metavar(
33
403
                Metadata::new(),
34
403
                Ustr::from(name_str),
35
403
            )))
36
        }
37
186310
        "identifier" => {
38
111347
            let Some(var) = parse_variable(ctx, node)? else {
39
300
                return Ok(None);
40
            };
41
111047
            Ok(Some(Expression::Atomic(Metadata::new(), var)))
42
        }
43
74963
        "from_solution" => {
44
34632
            if ctx.root.kind() != "dominance_relation" {
45
                ctx.record_error(RecoverableParseError::new(
46
                    "fromSolution only allowed inside dominance relations".to_string(),
47
                    Some(node.range()),
48
                ));
49
                return Ok(None);
50
34632
            }
51

            
52
34632
            let Some(var_node) = field!(recover, ctx, node, "variable") else {
53
                return Ok(None);
54
            };
55
34632
            let Some(inner) = parse_variable(ctx, &var_node)? else {
56
                return Ok(None);
57
            };
58

            
59
34632
            Ok(Some(Expression::FromSolution(
60
34632
                Metadata::new(),
61
34632
                Moo::new(inner),
62
34632
            )))
63
        }
64
40331
        "pareto_expression" => parse_pareto_expression(ctx, node),
65
39236
        "constant" => {
66
27724
            let Some(lit) = parse_constant(ctx, node)? else {
67
91
                return Ok(None);
68
            };
69
27633
            Ok(Some(Expression::Atomic(
70
27633
                Metadata::new(),
71
27633
                Atom::Literal(lit),
72
27633
            )))
73
        }
74
11512
        "matrix" | "record" | "tuple" | "set_literal" => {
75
4979
            let Some(abs) = parse_abstract(ctx, node)? else {
76
78
                return Ok(None);
77
            };
78
4901
            Ok(Some(Expression::AbstractLiteral(Metadata::new(), abs)))
79
        }
80
6533
        "flatten" => parse_flatten(ctx, node),
81
6273
        "table" | "negative_table" => parse_table(ctx, node),
82
6169
        "index_or_slice" => parse_index_or_slice(ctx, node),
83
        // for now, assume is binary since powerset isn't implemented
84
        // TODO: add powerset support under "set_operation"
85
867
        "set_operation" => parse_binary_expression(ctx, node),
86
711
        "comprehension" => parse_comprehension(ctx, node),
87
        _ => {
88
            ctx.record_error(RecoverableParseError::new(
89
                format!("Expected atom, got: {}", node.kind()),
90
                Some(node.range()),
91
            ));
92
            Ok(None)
93
        }
94
    }
95
369522
}
96

            
97
260
fn parse_flatten(
98
260
    ctx: &mut ParseContext,
99
260
    node: &Node,
100
260
) -> Result<Option<Expression>, FatalParseError> {
101
    // add error and return early if we're in a set context, since flatten doesn't produce sets
102
260
    if ctx.typechecking_context == TypecheckingContext::Set {
103
        ctx.record_error(RecoverableParseError::new(
104
            format!(
105
                "Type error: {}\n\tExpected: set\n\tGot: flatten",
106
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
107
            ),
108
            Some(node.range()),
109
        ));
110
        return Ok(None);
111
260
    }
112

            
113
260
    let Some(expr_node) = field!(recover, ctx, node, "expression") else {
114
        return Ok(None);
115
    };
116
260
    let Some(expr) = parse_atom(ctx, &expr_node)? else {
117
        return Ok(None);
118
    };
119

            
120
260
    if let Some(depth_node) = node.child_by_field_name("depth") {
121
        let Some(depth) = parse_int(ctx, &depth_node) else {
122
            return Ok(None);
123
        };
124
        let depth_expression =
125
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
126
        Ok(Some(Expression::Flatten(
127
            Metadata::new(),
128
            Some(Moo::new(depth_expression)),
129
            Moo::new(expr),
130
        )))
131
    } else {
132
260
        Ok(Some(Expression::Flatten(
133
260
            Metadata::new(),
134
260
            None,
135
260
            Moo::new(expr),
136
260
        )))
137
    }
138
260
}
139

            
140
104
fn parse_table(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
141
    // add error and return early if we're in a set context, since tables aren't allowed there
142
104
    if ctx.typechecking_context == TypecheckingContext::Set {
143
        ctx.record_error(RecoverableParseError::new(
144
            format!(
145
                "Type error: {}\n\tExpected: set\n\tGot: table",
146
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
147
            ),
148
            Some(node.range()),
149
        ));
150
        return Ok(None);
151
104
    }
152

            
153
    // the variables and rows can contain arbitrary expressions, so we temporarily set the context to Unknown to avoid typechecking errors
154
104
    let saved_context = ctx.typechecking_context;
155
104
    ctx.typechecking_context = TypecheckingContext::Unknown;
156

            
157
104
    let Some(variables_node) = field!(recover, ctx, node, "variables") else {
158
        return Ok(None);
159
    };
160
104
    let Some(variables) = parse_atom(ctx, &variables_node)? else {
161
        return Ok(None);
162
    };
163

            
164
104
    let Some(rows_node) = field!(recover, ctx, node, "rows") else {
165
        return Ok(None);
166
    };
167
104
    let Some(rows) = parse_atom(ctx, &rows_node)? else {
168
        return Ok(None);
169
    };
170

            
171
104
    ctx.typechecking_context = saved_context;
172

            
173
104
    match node.kind() {
174
104
        "table" => Ok(Some(Expression::Table(
175
78
            Metadata::new(),
176
78
            Moo::new(variables),
177
78
            Moo::new(rows),
178
78
        ))),
179
26
        "negative_table" => Ok(Some(Expression::NegativeTable(
180
26
            Metadata::new(),
181
26
            Moo::new(variables),
182
26
            Moo::new(rows),
183
26
        ))),
184
        _ => {
185
            ctx.record_error(RecoverableParseError::new(
186
                format!(
187
                    "Expected 'table' or 'negative_table', got: '{}'",
188
                    node.kind()
189
                ),
190
                Some(node.range()),
191
            ));
192
            Ok(None)
193
        }
194
    }
195
104
}
196

            
197
5302
fn parse_index_or_slice(
198
5302
    ctx: &mut ParseContext,
199
5302
    node: &Node,
200
5302
) -> Result<Option<Expression>, FatalParseError> {
201
    // add error and return early if we're in a set context, since indexing/slicing doesn't produce sets
202
5302
    if ctx.typechecking_context == TypecheckingContext::Set {
203
        ctx.record_error(RecoverableParseError::new(
204
            format!(
205
                "Type error: {}\n\tExpected: set\n\tGot: index or slice",
206
                ctx.source_code[node.start_byte()..node.end_byte()].trim()
207
            ),
208
            Some(node.range()),
209
        ));
210
        return Ok(None);
211
5302
    }
212

            
213
    // Save current context and temporarily set to Unknown for the collection
214
5302
    let saved_context = ctx.typechecking_context;
215
5302
    ctx.typechecking_context = TypecheckingContext::Unknown;
216
5302
    let Some(collection_node) = field!(recover, ctx, node, "collection") else {
217
        return Ok(None);
218
    };
219
5302
    let Some(collection) = parse_atom(ctx, &collection_node)? else {
220
13
        return Ok(None);
221
    };
222
5289
    ctx.typechecking_context = saved_context;
223
5289
    let mut indices = Vec::new();
224
5289
    let Some(indices_node) = field!(recover, ctx, node, "indices") else {
225
        return Ok(None);
226
    };
227
7579
    for idx_node in named_children(&indices_node) {
228
7579
        indices.push(parse_index(ctx, &idx_node)?);
229
    }
230

            
231
7267
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
232
    // TODO: We could check whether the slice/index is safe here
233
5289
    if has_null_idx {
234
        // It's a slice
235
885
        Ok(Some(Expression::UnsafeSlice(
236
885
            Metadata::new(),
237
885
            Moo::new(collection),
238
885
            indices,
239
885
        )))
240
    } else {
241
        // It's an index
242
5861
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
243
4404
        Ok(Some(Expression::UnsafeIndex(
244
4404
            Metadata::new(),
245
4404
            Moo::new(collection),
246
4404
            idx_exprs,
247
4404
        )))
248
    }
249
5302
}
250

            
251
7579
fn parse_index(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
252
7579
    match node.kind() {
253
7579
        "arithmetic_expr" | "atom" => {
254
6694
            let saved_context = ctx.typechecking_context;
255
6694
            ctx.typechecking_context = TypecheckingContext::Unknown;
256

            
257
            // TODO: add collection-aware index typechecking.
258
            // For tuple/matrix/set-like indexing, indices should be arithmetic.
259
            // For record field access, index atoms should resolve to valid field names.
260
            // This requires checking index expression together with the indexed collection type.
261

            
262
6694
            let Some(expr) = parse_expression(ctx, *node)? else {
263
                return Ok(None);
264
            };
265

            
266
6694
            ctx.typechecking_context = saved_context;
267
6694
            Ok(Some(expr))
268
        }
269
885
        "null_index" => Ok(None),
270
        _ => {
271
            ctx.record_error(RecoverableParseError::new(
272
                format!("Expected an index, got: '{}'", node.kind()),
273
                Some(node.range()),
274
            ));
275
            Ok(None)
276
        }
277
    }
278
7579
}
279

            
280
145979
fn parse_variable(ctx: &mut ParseContext, node: &Node) -> Result<Option<Atom>, FatalParseError> {
281
145979
    let raw_name = &ctx.source_code[node.start_byte()..node.end_byte()];
282

            
283
145979
    let name = Name::user(raw_name.trim());
284
145979
    if let Some(symbols) = &ctx.symbols {
285
145979
        let lookup_result = {
286
145979
            let symbols_read = symbols.read();
287
145979
            symbols_read.lookup(&name)
288
        };
289

            
290
145979
        if let Some(decl) = lookup_result {
291
145835
            let hover = HoverInfo {
292
145835
                description: format!("Variable: {name}"),
293
145835
                kind: Some(SymbolKind::Decimal),
294
145835
                ty: decl.domain().map(|d| d.to_string()),
295
145835
                decl_span: ctx.lookup_decl_span(&name),
296
            };
297
145835
            span_with_hover(node, ctx.source_code, ctx.source_map, hover);
298

            
299
            // Type check the variable against the expected context
300
145835
            if let Some(error_msg) = typecheck_variable(&decl, ctx.typechecking_context, raw_name) {
301
156
                ctx.record_error(RecoverableParseError::new(error_msg, Some(node.range())));
302
156
                return Ok(None);
303
145679
            }
304

            
305
145679
            Ok(Some(Atom::Reference(conjure_cp_core::ast::Reference::new(
306
145679
                decl,
307
145679
            ))))
308
        } else {
309
144
            ctx.record_error(RecoverableParseError::new(
310
144
                format!("The identifier '{}' is not defined", raw_name),
311
144
                Some(node.range()),
312
            ));
313
144
            Ok(None)
314
        }
315
    } else {
316
        ctx.record_error(RecoverableParseError::new(
317
            format!("Symbol table missing when parsing variable '{raw_name}'"),
318
            Some(node.range()),
319
        ));
320
        Ok(None)
321
    }
322
145979
}
323

            
324
/// Type check a variable declaration against the expected expression context.
325
/// Returns an error message if the variable type doesn't match the context.
326
145835
fn typecheck_variable(
327
145835
    decl: &DeclarationPtr,
328
145835
    context: TypecheckingContext,
329
145835
    raw_name: &str,
330
145835
) -> Option<String> {
331
    // Only type check when context is known
332
145835
    if context == TypecheckingContext::Unknown {
333
18494
        return None;
334
127341
    }
335

            
336
    // Get the variable's domain and resolve it
337
127341
    let domain = decl.domain()?;
338
127341
    let ground_domain = domain.resolve()?;
339

            
340
    // Determine what type is expected
341
126935
    let expected = match context {
342
99125
        TypecheckingContext::Boolean => "bool",
343
27433
        TypecheckingContext::Arithmetic => "int",
344
377
        TypecheckingContext::Set => "set",
345
        TypecheckingContext::Unknown => return None, // shouldn't reach here
346
    };
347

            
348
    // Determine what type we actually have
349
126935
    let actual = match ground_domain.as_ref() {
350
99073
        GroundDomain::Bool => "bool",
351
27485
        GroundDomain::Int(_) => "int",
352
26
        GroundDomain::Matrix(_, _) => "matrix",
353
351
        GroundDomain::Set(_, _) => "set",
354
        GroundDomain::MSet(_, _) => "mset",
355
        GroundDomain::Tuple(_) => "tuple",
356
        GroundDomain::Record(_) => "record",
357
        GroundDomain::Function(_, _, _) => "function",
358
        GroundDomain::Empty(_) => "empty",
359
    };
360

            
361
    // If types match, no error
362
126935
    if expected == actual {
363
126779
        return None;
364
156
    }
365

            
366
    // Otherwise, report the type mismatch
367
156
    Some(format!(
368
156
        "Type error: {}\n\tExpected: {}\n\tGot: {}",
369
156
        raw_name, expected, actual
370
156
    ))
371
145835
}
372

            
373
27724
fn parse_constant(ctx: &mut ParseContext, node: &Node) -> Result<Option<Literal>, FatalParseError> {
374
27724
    let Some(inner) = named_child!(recover, ctx, node) else {
375
        return Ok(None);
376
    };
377
27724
    let raw_value = &ctx.source_code[inner.start_byte()..inner.end_byte()];
378
27724
    let lit = match inner.kind() {
379
27724
        "integer" => {
380
27008
            let Some(value) = parse_int(ctx, &inner) else {
381
26
                return Ok(None);
382
            };
383
26982
            Literal::Int(value)
384
        }
385
716
        "TRUE" => {
386
422
            let hover = HoverInfo {
387
422
                description: format!("Boolean constant: {raw_value}"),
388
422
                kind: None,
389
422
                ty: None,
390
422
                decl_span: None,
391
422
            };
392
422
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
393
422
            Literal::Bool(true)
394
        }
395
294
        "FALSE" => {
396
294
            let hover = HoverInfo {
397
294
                description: format!("Boolean constant: {raw_value}"),
398
294
                kind: None,
399
294
                ty: None,
400
294
                decl_span: None,
401
294
            };
402
294
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
403
294
            Literal::Bool(false)
404
        }
405
        _ => {
406
            ctx.record_error(RecoverableParseError::new(
407
                format!(
408
                    "'{}' (kind: '{}') is not a valid constant",
409
                    raw_value,
410
                    inner.kind()
411
                ),
412
                Some(inner.range()),
413
            ));
414
            return Ok(None);
415
        }
416
    };
417

            
418
    // Type check the constant against the expected context
419
27698
    if ctx.typechecking_context != TypecheckingContext::Unknown {
420
11106
        let expected = match ctx.typechecking_context {
421
227
            TypecheckingContext::Boolean => "bool",
422
10866
            TypecheckingContext::Arithmetic => "int",
423
13
            TypecheckingContext::Set => "set",
424
            TypecheckingContext::Unknown => "",
425
        };
426

            
427
11106
        let actual = match &lit {
428
214
            Literal::Bool(_) => "bool",
429
10892
            Literal::Int(_) => "int",
430
            Literal::AbstractLiteral(_) => return Ok(None), // Abstract literals aren't type-checked here
431
        };
432

            
433
11106
        if expected != actual {
434
65
            ctx.record_error(RecoverableParseError::new(
435
65
                format!(
436
                    "Type error: {}\n\tExpected: {}\n\tGot: {}",
437
                    raw_value, expected, actual
438
                ),
439
65
                Some(node.range()),
440
            ));
441
65
            return Ok(None);
442
11041
        }
443
16592
    }
444
27633
    Ok(Some(lit))
445
27724
}
446

            
447
27086
pub(crate) fn parse_int(ctx: &mut ParseContext, node: &Node) -> Option<i32> {
448
27086
    let raw_value = &ctx.source_code[node.start_byte()..node.end_byte()];
449
27086
    if let Ok(v) = raw_value.parse::<i32>() {
450
27060
        Some(v)
451
    } else {
452
26
        ctx.record_error(RecoverableParseError::new(
453
26
            "Expected an integer here".to_string(),
454
26
            Some(node.range()),
455
        ));
456
26
        None
457
    }
458
27086
}