1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::{FatalParseError, RecoverableParseError};
4
use crate::expression::{parse_binary_expression, parse_expression};
5
use crate::parser::ParseContext;
6
use crate::parser::abstract_literal::parse_abstract;
7
use crate::parser::comprehension::parse_comprehension;
8
use crate::util::{TypecheckingContext, named_children};
9
use crate::{field, named_child};
10
use conjure_cp_core::ast::{
11
    Atom, DeclarationPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
12
};
13
use tree_sitter::Node;
14
use ustr::Ustr;
15

            
16
16964
pub fn parse_atom(
17
16964
    ctx: &mut ParseContext,
18
16964
    node: &Node,
19
16964
) -> Result<Option<Expression>, FatalParseError> {
20
16964
    match node.kind() {
21
16964
        "atom" | "sub_atom_expr" => parse_atom(ctx, &named_child!(node)),
22
8567
        "metavar" => {
23
644
            let ident = field!(node, "identifier");
24
644
            let name_str = &ctx.source_code[ident.start_byte()..ident.end_byte()];
25
644
            Ok(Some(Expression::Metavar(
26
644
                Metadata::new(),
27
644
                Ustr::from(name_str),
28
644
            )))
29
        }
30
7923
        "identifier" => {
31
1817
            let Some(var) = parse_variable(ctx, node)? else {
32
189
                return Ok(None);
33
            };
34
1628
            Ok(Some(Expression::Atomic(Metadata::new(), var)))
35
        }
36
6106
        "from_solution" => {
37
            if ctx.root.kind() != "dominance_relation" {
38
                return Err(FatalParseError::internal_error(
39
                    "fromSolution only allowed inside dominance relations".to_string(),
40
                    Some(node.range()),
41
                ));
42
            }
43

            
44
            let Some(inner) = parse_variable(ctx, &field!(node, "variable"))? else {
45
                return Ok(None);
46
            };
47

            
48
            Ok(Some(Expression::FromSolution(
49
                Metadata::new(),
50
                Moo::new(inner),
51
            )))
52
        }
53
6106
        "constant" => {
54
4224
            let Some(lit) = parse_constant(ctx, node)? else {
55
22
                return Ok(None);
56
            };
57
4202
            Ok(Some(Expression::Atomic(
58
4202
                Metadata::new(),
59
4202
                Atom::Literal(lit),
60
4202
            )))
61
        }
62
1882
        "matrix" | "record" | "tuple" | "set_literal" => {
63
1636
            let Some(abs) = parse_abstract(ctx, node)? else {
64
22
                return Ok(None);
65
            };
66
1614
            Ok(Some(Expression::AbstractLiteral(Metadata::new(), abs)))
67
        }
68
246
        "flatten" => parse_flatten(ctx, node),
69
246
        "table" | "negative_table" => parse_table(ctx, node),
70
158
        "index_or_slice" => parse_index_or_slice(ctx, node),
71
        // for now, assume is binary since powerset isn't implemented
72
        // TODO: add powerset support under "set_operation"
73
132
        "set_operation" => parse_binary_expression(ctx, node),
74
        "comprehension" => parse_comprehension(ctx, node),
75
        _ => Err(FatalParseError::internal_error(
76
            format!("Expected atom, got: {}", node.kind()),
77
            Some(node.range()),
78
        )),
79
    }
80
16964
}
81

            
82
fn parse_flatten(
83
    ctx: &mut ParseContext,
84
    node: &Node,
85
) -> Result<Option<Expression>, FatalParseError> {
86
    let expr_node = field!(node, "expression");
87
    let Some(expr) = parse_atom(ctx, &expr_node)? else {
88
        return Ok(None);
89
    };
90

            
91
    if node.child_by_field_name("depth").is_some() {
92
        let depth_node = field!(node, "depth");
93
        let depth = parse_int(ctx, &depth_node)?;
94
        let depth_expression =
95
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
96
        Ok(Some(Expression::Flatten(
97
            Metadata::new(),
98
            Some(Moo::new(depth_expression)),
99
            Moo::new(expr),
100
        )))
101
    } else {
102
        Ok(Some(Expression::Flatten(
103
            Metadata::new(),
104
            None,
105
            Moo::new(expr),
106
        )))
107
    }
108
}
109

            
110
88
fn parse_table(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
111
    // the variables and rows can contain arbitrary expressions, so we temporarily set the context to Unknown to avoid typechecking errors
112
88
    let saved_context = ctx.typechecking_context;
113
88
    ctx.typechecking_context = TypecheckingContext::Unknown;
114

            
115
88
    let variables_node = field!(node, "variables");
116
88
    let Some(variables) = parse_atom(ctx, &variables_node)? else {
117
        return Ok(None);
118
    };
119

            
120
88
    let rows_node = field!(node, "rows");
121
88
    let Some(rows) = parse_atom(ctx, &rows_node)? else {
122
        return Ok(None);
123
    };
124

            
125
88
    ctx.typechecking_context = saved_context;
126

            
127
88
    match node.kind() {
128
88
        "table" => Ok(Some(Expression::Table(
129
66
            Metadata::new(),
130
66
            Moo::new(variables),
131
66
            Moo::new(rows),
132
66
        ))),
133
22
        "negative_table" => Ok(Some(Expression::NegativeTable(
134
22
            Metadata::new(),
135
22
            Moo::new(variables),
136
22
            Moo::new(rows),
137
22
        ))),
138
        _ => Err(FatalParseError::internal_error(
139
            format!(
140
                "Expected 'table' or 'negative_table', got: '{}'",
141
                node.kind()
142
            ),
143
            Some(node.range()),
144
        )),
145
    }
146
88
}
147

            
148
26
fn parse_index_or_slice(
149
26
    ctx: &mut ParseContext,
150
26
    node: &Node,
151
26
) -> Result<Option<Expression>, FatalParseError> {
152
    // Save current context and temporarily set to Unknown for the collection
153
26
    let saved_context = ctx.typechecking_context;
154
26
    ctx.typechecking_context = TypecheckingContext::Unknown;
155
26
    let Some(collection) = parse_atom(ctx, &field!(node, "collection"))? else {
156
        return Ok(None);
157
    };
158
26
    ctx.typechecking_context = saved_context;
159
26
    let mut indices = Vec::new();
160
41
    for idx_node in named_children(&field!(node, "indices")) {
161
41
        indices.push(parse_index(ctx, &idx_node)?);
162
    }
163

            
164
30
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
165
    // TODO: We could check whether the slice/index is safe here
166
26
    if has_null_idx {
167
        // It's a slice
168
24
        Ok(Some(Expression::UnsafeSlice(
169
24
            Metadata::new(),
170
24
            Moo::new(collection),
171
24
            indices,
172
24
        )))
173
    } else {
174
        // It's an index
175
4
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
176
2
        Ok(Some(Expression::UnsafeIndex(
177
2
            Metadata::new(),
178
2
            Moo::new(collection),
179
2
            idx_exprs,
180
2
        )))
181
    }
182
26
}
183

            
184
41
fn parse_index(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
185
41
    match node.kind() {
186
41
        "arithmetic_expr" | "atom" => {
187
39
            ctx.typechecking_context = TypecheckingContext::Arithmetic;
188
39
            let Some(expr) = parse_expression(ctx, *node)? else {
189
22
                return Ok(None);
190
            };
191
17
            Ok(Some(expr))
192
        }
193
2
        "null_index" => Ok(None),
194
        _ => Err(FatalParseError::internal_error(
195
            format!("Expected an index, got: '{}'", node.kind()),
196
            Some(node.range()),
197
        )),
198
    }
199
41
}
200

            
201
1817
fn parse_variable(ctx: &mut ParseContext, node: &Node) -> Result<Option<Atom>, FatalParseError> {
202
1817
    let raw_name = &ctx.source_code[node.start_byte()..node.end_byte()];
203
1817
    let name = Name::user(raw_name.trim());
204
1817
    if let Some(symbols) = &ctx.symbols {
205
1817
        let lookup_result = {
206
1817
            let symbols_read = symbols.read();
207
1817
            symbols_read.lookup(&name)
208
        };
209

            
210
1817
        if let Some(decl) = lookup_result {
211
1738
            let hover = HoverInfo {
212
1738
                description: format!("Variable: {name}"),
213
1738
                kind: Some(SymbolKind::Decimal),
214
1738
                ty: decl.domain().map(|d| d.to_string()),
215
1738
                decl_span: None,
216
            };
217
1738
            span_with_hover(node, ctx.source_code, ctx.source_map, hover);
218

            
219
            // Type check the variable against the expected context
220
1738
            if let Some(error_msg) = typecheck_variable(&decl, ctx.typechecking_context) {
221
110
                ctx.record_error(RecoverableParseError::new(error_msg, Some(node.range())));
222
110
                return Ok(None);
223
1628
            }
224

            
225
1628
            Ok(Some(Atom::Reference(conjure_cp_core::ast::Reference::new(
226
1628
                decl,
227
1628
            ))))
228
        } else {
229
79
            ctx.record_error(RecoverableParseError::new(
230
79
                format!("The identifier '{}' is not defined", raw_name),
231
79
                Some(node.range()),
232
            ));
233
79
            Ok(None)
234
        }
235
    } else {
236
        Err(FatalParseError::internal_error(
237
            format!("Symbol table missing when parsing variable '{raw_name}'"),
238
            Some(node.range()),
239
        ))
240
    }
241
1817
}
242

            
243
/// Type check a variable declaration against the expected expression context.
244
/// Returns an error message if the variable type doesn't match the context.
245
1738
fn typecheck_variable(decl: &DeclarationPtr, context: TypecheckingContext) -> Option<String> {
246
    // Only type check when context is known
247
1738
    if context == TypecheckingContext::Unknown {
248
1227
        return None;
249
511
    }
250

            
251
    // Get the variable's domain and resolve it
252
511
    let domain = decl.domain()?;
253
511
    let ground_domain = domain.resolve()?;
254

            
255
    // Determine what type is expected
256
511
    let expected = match context {
257
55
        TypecheckingContext::Boolean => "bool",
258
456
        TypecheckingContext::Arithmetic => "int",
259
        TypecheckingContext::Unknown => return None, // shouldn't reach here
260
    };
261

            
262
    // Determine what type we actually have
263
511
    let actual = match ground_domain.as_ref() {
264
55
        GroundDomain::Bool => "bool",
265
445
        GroundDomain::Int(_) => "int",
266
11
        GroundDomain::Matrix(_, _) => "matrix",
267
        GroundDomain::Set(_, _) => "set",
268
        GroundDomain::MSet(_, _) => "mset",
269
        GroundDomain::Tuple(_) => "tuple",
270
        GroundDomain::Record(_) => "record",
271
        GroundDomain::Function(_, _, _) => "function",
272
        GroundDomain::Empty(_) => "empty",
273
    };
274

            
275
    // If types match, no error
276
511
    if expected == actual {
277
401
        return None;
278
110
    }
279

            
280
    // Otherwise, report the type mismatch
281
110
    Some(format!(
282
110
        "Type error:\n\tExpected: {}\n\tGot: {}",
283
110
        expected, actual
284
110
    ))
285
1738
}
286

            
287
4224
fn parse_constant(ctx: &mut ParseContext, node: &Node) -> Result<Option<Literal>, FatalParseError> {
288
4224
    let inner = named_child!(node);
289
4224
    let raw_value = &ctx.source_code[inner.start_byte()..inner.end_byte()];
290
4224
    let lit = match inner.kind() {
291
4224
        "integer" => {
292
4113
            let value = parse_int(ctx, &inner)?;
293
4113
            Literal::Int(value)
294
        }
295
111
        "TRUE" => {
296
43
            let hover = HoverInfo {
297
43
                description: format!("Boolean constant: {raw_value}"),
298
43
                kind: None,
299
43
                ty: None,
300
43
                decl_span: None,
301
43
            };
302
43
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
303
43
            Literal::Bool(true)
304
        }
305
68
        "FALSE" => {
306
68
            let hover = HoverInfo {
307
68
                description: format!("Boolean constant: {raw_value}"),
308
68
                kind: None,
309
68
                ty: None,
310
68
                decl_span: None,
311
68
            };
312
68
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
313
68
            Literal::Bool(false)
314
        }
315
        _ => {
316
            return Err(FatalParseError::internal_error(
317
                format!(
318
                    "'{}' (kind: '{}') is not a valid constant",
319
                    raw_value,
320
                    inner.kind()
321
                ),
322
                Some(inner.range()),
323
            ));
324
        }
325
    };
326

            
327
    // Type check the constant against the expected context
328
4224
    if ctx.typechecking_context != TypecheckingContext::Unknown {
329
439
        let expected = match ctx.typechecking_context {
330
11
            TypecheckingContext::Boolean => "bool",
331
428
            TypecheckingContext::Arithmetic => "int",
332
            TypecheckingContext::Unknown => "",
333
        };
334

            
335
439
        let actual = match &lit {
336
11
            Literal::Bool(_) => "bool",
337
428
            Literal::Int(_) => "int",
338
            Literal::AbstractLiteral(_) => return Ok(None), // Abstract literals aren't type-checked here
339
        };
340

            
341
439
        if expected != actual {
342
22
            ctx.record_error(RecoverableParseError::new(
343
22
                format!("Type error:\n\tExpected: {}\n\tGot: {}", expected, actual),
344
22
                Some(node.range()),
345
            ));
346
22
            return Ok(None);
347
417
        }
348
3785
    }
349
4202
    Ok(Some(lit))
350
4224
}
351

            
352
4113
pub(crate) fn parse_int(ctx: &ParseContext, node: &Node) -> Result<i32, FatalParseError> {
353
4113
    let raw_value = &ctx.source_code[node.start_byte()..node.end_byte()];
354
4113
    raw_value.parse::<i32>().map_err(|_e| {
355
        FatalParseError::internal_error("Expected an integer here".to_string(), Some(node.range()))
356
    })
357
4113
}