1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::{FatalParseError, RecoverableParseError};
4
use crate::expression::{parse_binary_expression, parse_expression};
5
use crate::parser::ParseContext;
6
use crate::parser::abstract_literal::parse_abstract;
7
use crate::parser::comprehension::parse_comprehension;
8
use crate::util::{TypecheckingContext, named_children};
9
use crate::{field, named_child};
10
use conjure_cp_core::ast::{
11
    Atom, DeclarationPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
12
};
13
use tree_sitter::Node;
14
use ustr::Ustr;
15

            
16
30512
pub fn parse_atom(
17
30512
    ctx: &mut ParseContext,
18
30512
    node: &Node,
19
30512
) -> Result<Option<Expression>, FatalParseError> {
20
30512
    match node.kind() {
21
30512
        "atom" | "sub_atom_expr" => parse_atom(ctx, &named_child!(node)),
22
15418
        "metavar" => {
23
885
            let ident = field!(node, "identifier");
24
885
            let name_str = &ctx.source_code[ident.start_byte()..ident.end_byte()];
25
885
            Ok(Some(Expression::Metavar(
26
885
                Metadata::new(),
27
885
                Ustr::from(name_str),
28
885
            )))
29
        }
30
14533
        "identifier" => {
31
4070
            let Some(var) = parse_variable(ctx, node)? else {
32
535
                return Ok(None);
33
            };
34
3535
            Ok(Some(Expression::Atomic(Metadata::new(), var)))
35
        }
36
10463
        "from_solution" => {
37
432
            if ctx.root.kind() != "dominance_relation" {
38
                return Err(FatalParseError::internal_error(
39
                    "fromSolution only allowed inside dominance relations".to_string(),
40
                    Some(node.range()),
41
                ));
42
432
            }
43

            
44
432
            let Some(inner) = parse_variable(ctx, &field!(node, "variable"))? else {
45
                return Ok(None);
46
            };
47

            
48
432
            Ok(Some(Expression::FromSolution(
49
432
                Metadata::new(),
50
432
                Moo::new(inner),
51
432
            )))
52
        }
53
10031
        "constant" => {
54
6999
            let Some(lit) = parse_constant(ctx, node)? else {
55
68
                return Ok(None);
56
            };
57
6931
            Ok(Some(Expression::Atomic(
58
6931
                Metadata::new(),
59
6931
                Atom::Literal(lit),
60
6931
            )))
61
        }
62
3032
        "matrix" | "record" | "tuple" | "set_literal" => {
63
2596
            let Some(abs) = parse_abstract(ctx, node)? else {
64
68
                return Ok(None);
65
            };
66
2528
            Ok(Some(Expression::AbstractLiteral(Metadata::new(), abs)))
67
        }
68
436
        "flatten" => parse_flatten(ctx, node),
69
436
        "table" | "negative_table" => parse_table(ctx, node),
70
300
        "index_or_slice" => parse_index_or_slice(ctx, node),
71
        // for now, assume is binary since powerset isn't implemented
72
        // TODO: add powerset support under "set_operation"
73
204
        "set_operation" => parse_binary_expression(ctx, node),
74
        "comprehension" => parse_comprehension(ctx, node),
75
        _ => Err(FatalParseError::internal_error(
76
            format!("Expected atom, got: {}", node.kind()),
77
            Some(node.range()),
78
        )),
79
    }
80
30512
}
81

            
82
fn parse_flatten(
83
    ctx: &mut ParseContext,
84
    node: &Node,
85
) -> Result<Option<Expression>, FatalParseError> {
86
    let expr_node = field!(node, "expression");
87
    let Some(expr) = parse_atom(ctx, &expr_node)? else {
88
        return Ok(None);
89
    };
90

            
91
    if node.child_by_field_name("depth").is_some() {
92
        let depth_node = field!(node, "depth");
93
        let depth = parse_int(ctx, &depth_node)?;
94
        let depth_expression =
95
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
96
        Ok(Some(Expression::Flatten(
97
            Metadata::new(),
98
            Some(Moo::new(depth_expression)),
99
            Moo::new(expr),
100
        )))
101
    } else {
102
        Ok(Some(Expression::Flatten(
103
            Metadata::new(),
104
            None,
105
            Moo::new(expr),
106
        )))
107
    }
108
}
109

            
110
136
fn parse_table(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
111
    // the variables and rows can contain arbitrary expressions, so we temporarily set the context to Unknown to avoid typechecking errors
112
136
    let saved_context = ctx.typechecking_context;
113
136
    ctx.typechecking_context = TypecheckingContext::Unknown;
114

            
115
136
    let variables_node = field!(node, "variables");
116
136
    let Some(variables) = parse_atom(ctx, &variables_node)? else {
117
        return Ok(None);
118
    };
119

            
120
136
    let rows_node = field!(node, "rows");
121
136
    let Some(rows) = parse_atom(ctx, &rows_node)? else {
122
        return Ok(None);
123
    };
124

            
125
136
    ctx.typechecking_context = saved_context;
126

            
127
136
    match node.kind() {
128
136
        "table" => Ok(Some(Expression::Table(
129
102
            Metadata::new(),
130
102
            Moo::new(variables),
131
102
            Moo::new(rows),
132
102
        ))),
133
34
        "negative_table" => Ok(Some(Expression::NegativeTable(
134
34
            Metadata::new(),
135
34
            Moo::new(variables),
136
34
            Moo::new(rows),
137
34
        ))),
138
        _ => Err(FatalParseError::internal_error(
139
            format!(
140
                "Expected 'table' or 'negative_table', got: '{}'",
141
                node.kind()
142
            ),
143
            Some(node.range()),
144
        )),
145
    }
146
136
}
147

            
148
96
fn parse_index_or_slice(
149
96
    ctx: &mut ParseContext,
150
96
    node: &Node,
151
96
) -> Result<Option<Expression>, FatalParseError> {
152
    // Save current context and temporarily set to Unknown for the collection
153
96
    let saved_context = ctx.typechecking_context;
154
96
    ctx.typechecking_context = TypecheckingContext::Unknown;
155
96
    let Some(collection) = parse_atom(ctx, &field!(node, "collection"))? else {
156
22
        return Ok(None);
157
    };
158
74
    ctx.typechecking_context = saved_context;
159
74
    let mut indices = Vec::new();
160
114
    for idx_node in named_children(&field!(node, "indices")) {
161
114
        indices.push(parse_index(ctx, &idx_node)?);
162
    }
163

            
164
80
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
165
    // TODO: We could check whether the slice/index is safe here
166
74
    if has_null_idx {
167
        // It's a slice
168
71
        Ok(Some(Expression::UnsafeSlice(
169
71
            Metadata::new(),
170
71
            Moo::new(collection),
171
71
            indices,
172
71
        )))
173
    } else {
174
        // It's an index
175
6
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
176
3
        Ok(Some(Expression::UnsafeIndex(
177
3
            Metadata::new(),
178
3
            Moo::new(collection),
179
3
            idx_exprs,
180
3
        )))
181
    }
182
96
}
183

            
184
114
fn parse_index(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
185
114
    match node.kind() {
186
114
        "arithmetic_expr" | "atom" => {
187
111
            ctx.typechecking_context = TypecheckingContext::Arithmetic;
188
111
            let Some(expr) = parse_expression(ctx, *node)? else {
189
68
                return Ok(None);
190
            };
191
43
            Ok(Some(expr))
192
        }
193
3
        "null_index" => Ok(None),
194
        _ => Err(FatalParseError::internal_error(
195
            format!("Expected an index, got: '{}'", node.kind()),
196
            Some(node.range()),
197
        )),
198
    }
199
114
}
200

            
201
4502
fn parse_variable(ctx: &mut ParseContext, node: &Node) -> Result<Option<Atom>, FatalParseError> {
202
4502
    let raw_name = &ctx.source_code[node.start_byte()..node.end_byte()];
203
4502
    let name = Name::user(raw_name.trim());
204
4502
    if let Some(symbols) = &ctx.symbols {
205
4502
        let lookup_result = {
206
4502
            let symbols_read = symbols.read();
207
4502
            symbols_read.lookup(&name)
208
        };
209

            
210
4502
        if let Some(decl) = lookup_result {
211
4307
            let hover = HoverInfo {
212
4307
                description: format!("Variable: {name}"),
213
4307
                kind: Some(SymbolKind::Decimal),
214
4307
                ty: decl.domain().map(|d| d.to_string()),
215
4307
                decl_span: None,
216
            };
217
4307
            span_with_hover(node, ctx.source_code, ctx.source_map, hover);
218

            
219
            // Type check the variable against the expected context
220
4307
            if let Some(error_msg) = typecheck_variable(&decl, ctx.typechecking_context) {
221
340
                ctx.record_error(RecoverableParseError::new(error_msg, Some(node.range())));
222
340
                return Ok(None);
223
3967
            }
224

            
225
3967
            Ok(Some(Atom::Reference(conjure_cp_core::ast::Reference::new(
226
3967
                decl,
227
3967
            ))))
228
        } else {
229
195
            ctx.record_error(RecoverableParseError::new(
230
195
                format!("The identifier '{}' is not defined", raw_name),
231
195
                Some(node.range()),
232
            ));
233
195
            Ok(None)
234
        }
235
    } else {
236
        Err(FatalParseError::internal_error(
237
            format!("Symbol table missing when parsing variable '{raw_name}'"),
238
            Some(node.range()),
239
        ))
240
    }
241
4502
}
242

            
243
/// Type check a variable declaration against the expected expression context.
244
/// Returns an error message if the variable type doesn't match the context.
245
4307
fn typecheck_variable(decl: &DeclarationPtr, context: TypecheckingContext) -> Option<String> {
246
    // Only type check when context is known
247
4307
    if context == TypecheckingContext::Unknown {
248
2015
        return None;
249
2292
    }
250

            
251
    // Get the variable's domain and resolve it
252
2292
    let domain = decl.domain()?;
253
2292
    let ground_domain = domain.resolve()?;
254

            
255
    // Determine what type is expected
256
2292
    let expected = match context {
257
1298
        TypecheckingContext::Boolean => "bool",
258
994
        TypecheckingContext::Arithmetic => "int",
259
        TypecheckingContext::Unknown => return None, // shouldn't reach here
260
    };
261

            
262
    // Determine what type we actually have
263
2292
    let actual = match ground_domain.as_ref() {
264
1298
        GroundDomain::Bool => "bool",
265
960
        GroundDomain::Int(_) => "int",
266
34
        GroundDomain::Matrix(_, _) => "matrix",
267
        GroundDomain::Set(_, _) => "set",
268
        GroundDomain::MSet(_, _) => "mset",
269
        GroundDomain::Tuple(_) => "tuple",
270
        GroundDomain::Record(_) => "record",
271
        GroundDomain::Function(_, _, _) => "function",
272
        GroundDomain::Empty(_) => "empty",
273
    };
274

            
275
    // If types match, no error
276
2292
    if expected == actual {
277
1952
        return None;
278
340
    }
279

            
280
    // Otherwise, report the type mismatch
281
340
    Some(format!(
282
340
        "Type error:\n\tExpected: {}\n\tGot: {}",
283
340
        expected, actual
284
340
    ))
285
4307
}
286

            
287
6999
fn parse_constant(ctx: &mut ParseContext, node: &Node) -> Result<Option<Literal>, FatalParseError> {
288
6999
    let inner = named_child!(node);
289
6999
    let raw_value = &ctx.source_code[inner.start_byte()..inner.end_byte()];
290
6999
    let lit = match inner.kind() {
291
6999
        "integer" => {
292
6782
            let value = parse_int(ctx, &inner)?;
293
6782
            Literal::Int(value)
294
        }
295
217
        "TRUE" => {
296
63
            let hover = HoverInfo {
297
63
                description: format!("Boolean constant: {raw_value}"),
298
63
                kind: None,
299
63
                ty: None,
300
63
                decl_span: None,
301
63
            };
302
63
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
303
63
            Literal::Bool(true)
304
        }
305
154
        "FALSE" => {
306
154
            let hover = HoverInfo {
307
154
                description: format!("Boolean constant: {raw_value}"),
308
154
                kind: None,
309
154
                ty: None,
310
154
                decl_span: None,
311
154
            };
312
154
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
313
154
            Literal::Bool(false)
314
        }
315
        _ => {
316
            return Err(FatalParseError::internal_error(
317
                format!(
318
                    "'{}' (kind: '{}') is not a valid constant",
319
                    raw_value,
320
                    inner.kind()
321
                ),
322
                Some(inner.range()),
323
            ));
324
        }
325
    };
326

            
327
    // Type check the constant against the expected context
328
6999
    if ctx.typechecking_context != TypecheckingContext::Unknown {
329
972
        let expected = match ctx.typechecking_context {
330
34
            TypecheckingContext::Boolean => "bool",
331
938
            TypecheckingContext::Arithmetic => "int",
332
            TypecheckingContext::Unknown => "",
333
        };
334

            
335
972
        let actual = match &lit {
336
34
            Literal::Bool(_) => "bool",
337
938
            Literal::Int(_) => "int",
338
            Literal::AbstractLiteral(_) => return Ok(None), // Abstract literals aren't type-checked here
339
        };
340

            
341
972
        if expected != actual {
342
68
            ctx.record_error(RecoverableParseError::new(
343
68
                format!("Type error:\n\tExpected: {}\n\tGot: {}", expected, actual),
344
68
                Some(node.range()),
345
            ));
346
68
            return Ok(None);
347
904
        }
348
6027
    }
349
6931
    Ok(Some(lit))
350
6999
}
351

            
352
6782
pub(crate) fn parse_int(ctx: &ParseContext, node: &Node) -> Result<i32, FatalParseError> {
353
6782
    let raw_value = &ctx.source_code[node.start_byte()..node.end_byte()];
354
6782
    raw_value.parse::<i32>().map_err(|_e| {
355
        FatalParseError::internal_error("Expected an integer here".to_string(), Some(node.range()))
356
    })
357
6782
}