1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::{FatalParseError, RecoverableParseError};
4
use crate::expression::{parse_binary_expression, parse_expression};
5
use crate::parser::ParseContext;
6
use crate::parser::abstract_literal::parse_abstract;
7
use crate::parser::comprehension::parse_comprehension;
8
use crate::util::{TypecheckingContext, named_children};
9
use crate::{field, named_child};
10
use conjure_cp_core::ast::{
11
    Atom, DeclarationPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
12
};
13
use tree_sitter::Node;
14
use ustr::Ustr;
15

            
16
9451
pub fn parse_atom(
17
9451
    ctx: &mut ParseContext,
18
9451
    node: &Node,
19
9451
) -> Result<Option<Expression>, FatalParseError> {
20
9451
    match node.kind() {
21
9451
        "atom" | "sub_atom_expr" => parse_atom(ctx, &named_child!(node)),
22
4788
        "metavar" => {
23
403
            let ident = field!(node, "identifier");
24
403
            let name_str = &ctx.source_code[ident.start_byte()..ident.end_byte()];
25
403
            Ok(Some(Expression::Metavar(
26
403
                Metadata::new(),
27
403
                Ustr::from(name_str),
28
403
            )))
29
        }
30
4385
        "identifier" => {
31
1101
            let Some(var) = parse_variable(ctx, node)? else {
32
199
                return Ok(None);
33
            };
34
902
            Ok(Some(Expression::Atomic(Metadata::new(), var)))
35
        }
36
3284
        "from_solution" => {
37
            if ctx.root.kind() != "dominance_relation" {
38
                return Err(FatalParseError::internal_error(
39
                    "fromSolution only allowed inside dominance relations".to_string(),
40
                    Some(node.range()),
41
                ));
42
            }
43

            
44
            let Some(inner) = parse_variable(ctx, &field!(node, "variable"))? else {
45
                return Ok(None);
46
            };
47

            
48
            Ok(Some(Expression::FromSolution(
49
                Metadata::new(),
50
                Moo::new(inner),
51
            )))
52
        }
53
3284
        "constant" => {
54
2277
            let Some(lit) = parse_constant(ctx, node)? else {
55
22
                return Ok(None);
56
            };
57
2255
            Ok(Some(Expression::Atomic(
58
2255
                Metadata::new(),
59
2255
                Atom::Literal(lit),
60
2255
            )))
61
        }
62
1007
        "matrix" | "record" | "tuple" | "set_literal" => {
63
840
            let Some(abs) = parse_abstract(ctx, node)? else {
64
22
                return Ok(None);
65
            };
66
818
            Ok(Some(Expression::AbstractLiteral(Metadata::new(), abs)))
67
        }
68
167
        "flatten" => parse_flatten(ctx, node),
69
167
        "table" | "negative_table" => parse_table(ctx, node),
70
123
        "index_or_slice" => parse_index_or_slice(ctx, node),
71
        // for now, assume is binary since powerset isn't implemented
72
        // TODO: add powerset support under "set_operation"
73
66
        "set_operation" => parse_binary_expression(ctx, node),
74
        "comprehension" => parse_comprehension(ctx, node),
75
        _ => Err(FatalParseError::internal_error(
76
            format!("Expected atom, got: {}", node.kind()),
77
            Some(node.range()),
78
        )),
79
    }
80
9451
}
81

            
82
fn parse_flatten(
83
    ctx: &mut ParseContext,
84
    node: &Node,
85
) -> Result<Option<Expression>, FatalParseError> {
86
    let expr_node = field!(node, "expression");
87
    let Some(expr) = parse_atom(ctx, &expr_node)? else {
88
        return Ok(None);
89
    };
90

            
91
    if node.child_by_field_name("depth").is_some() {
92
        let depth_node = field!(node, "depth");
93
        let depth = parse_int(ctx, &depth_node)?;
94
        let depth_expression =
95
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
96
        Ok(Some(Expression::Flatten(
97
            Metadata::new(),
98
            Some(Moo::new(depth_expression)),
99
            Moo::new(expr),
100
        )))
101
    } else {
102
        Ok(Some(Expression::Flatten(
103
            Metadata::new(),
104
            None,
105
            Moo::new(expr),
106
        )))
107
    }
108
}
109

            
110
44
fn parse_table(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
111
    // the variables and rows can contain arbitrary expressions, so we temporarily set the context to Unknown to avoid typechecking errors
112
44
    let saved_context = ctx.typechecking_context;
113
44
    ctx.typechecking_context = TypecheckingContext::Unknown;
114

            
115
44
    let variables_node = field!(node, "variables");
116
44
    let Some(variables) = parse_atom(ctx, &variables_node)? else {
117
        return Ok(None);
118
    };
119

            
120
44
    let rows_node = field!(node, "rows");
121
44
    let Some(rows) = parse_atom(ctx, &rows_node)? else {
122
        return Ok(None);
123
    };
124

            
125
44
    ctx.typechecking_context = saved_context;
126

            
127
44
    match node.kind() {
128
44
        "table" => Ok(Some(Expression::Table(
129
33
            Metadata::new(),
130
33
            Moo::new(variables),
131
33
            Moo::new(rows),
132
33
        ))),
133
11
        "negative_table" => Ok(Some(Expression::NegativeTable(
134
11
            Metadata::new(),
135
11
            Moo::new(variables),
136
11
            Moo::new(rows),
137
11
        ))),
138
        _ => Err(FatalParseError::internal_error(
139
            format!(
140
                "Expected 'table' or 'negative_table', got: '{}'",
141
                node.kind()
142
            ),
143
            Some(node.range()),
144
        )),
145
    }
146
44
}
147

            
148
57
fn parse_index_or_slice(
149
57
    ctx: &mut ParseContext,
150
57
    node: &Node,
151
57
) -> Result<Option<Expression>, FatalParseError> {
152
    // Save current context and temporarily set to Unknown for the collection
153
57
    let saved_context = ctx.typechecking_context;
154
57
    ctx.typechecking_context = TypecheckingContext::Unknown;
155
57
    let Some(collection) = parse_atom(ctx, &field!(node, "collection"))? else {
156
11
        return Ok(None);
157
    };
158
46
    ctx.typechecking_context = saved_context;
159
46
    let mut indices = Vec::new();
160
59
    for idx_node in named_children(&field!(node, "indices")) {
161
59
        indices.push(parse_index(ctx, &idx_node)?);
162
    }
163

            
164
48
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
165
    // TODO: We could check whether the slice/index is safe here
166
46
    if has_null_idx {
167
        // It's a slice
168
45
        Ok(Some(Expression::UnsafeSlice(
169
45
            Metadata::new(),
170
45
            Moo::new(collection),
171
45
            indices,
172
45
        )))
173
    } else {
174
        // It's an index
175
2
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
176
1
        Ok(Some(Expression::UnsafeIndex(
177
1
            Metadata::new(),
178
1
            Moo::new(collection),
179
1
            idx_exprs,
180
1
        )))
181
    }
182
57
}
183

            
184
59
fn parse_index(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
185
59
    match node.kind() {
186
59
        "arithmetic_expr" | "atom" => {
187
58
            ctx.typechecking_context = TypecheckingContext::Arithmetic;
188
58
            let Some(expr) = parse_expression(ctx, *node)? else {
189
44
                return Ok(None);
190
            };
191
14
            Ok(Some(expr))
192
        }
193
1
        "null_index" => Ok(None),
194
        _ => Err(FatalParseError::internal_error(
195
            format!("Expected an index, got: '{}'", node.kind()),
196
            Some(node.range()),
197
        )),
198
    }
199
59
}
200

            
201
1101
fn parse_variable(ctx: &mut ParseContext, node: &Node) -> Result<Option<Atom>, FatalParseError> {
202
1101
    let raw_name = &ctx.source_code[node.start_byte()..node.end_byte()];
203
1101
    let name = Name::user(raw_name.trim());
204
1101
    if let Some(symbols) = &ctx.symbols {
205
1101
        let lookup_result = {
206
1101
            let symbols_read = symbols.read();
207
1101
            symbols_read.lookup(&name)
208
        };
209

            
210
1101
        if let Some(decl) = lookup_result {
211
1034
            let hover = HoverInfo {
212
1034
                description: format!("Variable: {name}"),
213
1034
                kind: Some(SymbolKind::Decimal),
214
1034
                ty: decl.domain().map(|d| d.to_string()),
215
1034
                decl_span: None,
216
            };
217
1034
            span_with_hover(node, ctx.source_code, ctx.source_map, hover);
218

            
219
            // Type check the variable against the expected context
220
1034
            if let Some(error_msg) = typecheck_variable(&decl, ctx.typechecking_context, raw_name) {
221
132
                ctx.record_error(RecoverableParseError::new(error_msg, Some(node.range())));
222
132
                return Ok(None);
223
902
            }
224

            
225
902
            Ok(Some(Atom::Reference(conjure_cp_core::ast::Reference::new(
226
902
                decl,
227
902
            ))))
228
        } else {
229
67
            ctx.record_error(RecoverableParseError::new(
230
67
                format!("The identifier '{}' is not defined", raw_name),
231
67
                Some(node.range()),
232
            ));
233
67
            Ok(None)
234
        }
235
    } else {
236
        Err(FatalParseError::internal_error(
237
            format!("Symbol table missing when parsing variable '{raw_name}'"),
238
            Some(node.range()),
239
        ))
240
    }
241
1101
}
242

            
243
/// Type check a variable declaration against the expected expression context.
244
/// Returns an error message if the variable type doesn't match the context.
245
1034
fn typecheck_variable(
246
1034
    decl: &DeclarationPtr,
247
1034
    context: TypecheckingContext,
248
1034
    raw_name: &str,
249
1034
) -> Option<String> {
250
    // Only type check when context is known
251
1034
    if context == TypecheckingContext::Unknown {
252
674
        return None;
253
360
    }
254

            
255
    // Get the variable's domain and resolve it
256
360
    let domain = decl.domain()?;
257
360
    let ground_domain = domain.resolve()?;
258

            
259
    // Determine what type is expected
260
360
    let expected = match context {
261
55
        TypecheckingContext::Boolean => "bool",
262
305
        TypecheckingContext::Arithmetic => "int",
263
        TypecheckingContext::Unknown => return None, // shouldn't reach here
264
    };
265

            
266
    // Determine what type we actually have
267
360
    let actual = match ground_domain.as_ref() {
268
55
        GroundDomain::Bool => "bool",
269
272
        GroundDomain::Int(_) => "int",
270
11
        GroundDomain::Matrix(_, _) => "matrix",
271
        GroundDomain::Set(_, _) => "set",
272
        GroundDomain::MSet(_, _) => "mset",
273
22
        GroundDomain::Tuple(_) => "tuple",
274
        GroundDomain::Record(_) => "record",
275
        GroundDomain::Function(_, _, _) => "function",
276
        GroundDomain::Empty(_) => "empty",
277
    };
278

            
279
    // If types match, no error
280
360
    if expected == actual {
281
228
        return None;
282
132
    }
283

            
284
    // Otherwise, report the type mismatch
285
132
    Some(format!(
286
132
        "Type error: {}\n\tExpected: {}\n\tGot: {}",
287
132
        raw_name, expected, actual
288
132
    ))
289
1034
}
290

            
291
2277
fn parse_constant(ctx: &mut ParseContext, node: &Node) -> Result<Option<Literal>, FatalParseError> {
292
2277
    let inner = named_child!(node);
293
2277
    let raw_value = &ctx.source_code[inner.start_byte()..inner.end_byte()];
294
2277
    let lit = match inner.kind() {
295
2277
        "integer" => {
296
2201
            let value = parse_int(ctx, &inner)?;
297
2201
            Literal::Int(value)
298
        }
299
76
        "TRUE" => {
300
24
            let hover = HoverInfo {
301
24
                description: format!("Boolean constant: {raw_value}"),
302
24
                kind: None,
303
24
                ty: None,
304
24
                decl_span: None,
305
24
            };
306
24
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
307
24
            Literal::Bool(true)
308
        }
309
52
        "FALSE" => {
310
52
            let hover = HoverInfo {
311
52
                description: format!("Boolean constant: {raw_value}"),
312
52
                kind: None,
313
52
                ty: None,
314
52
                decl_span: None,
315
52
            };
316
52
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
317
52
            Literal::Bool(false)
318
        }
319
        _ => {
320
            return Err(FatalParseError::internal_error(
321
                format!(
322
                    "'{}' (kind: '{}') is not a valid constant",
323
                    raw_value,
324
                    inner.kind()
325
                ),
326
                Some(inner.range()),
327
            ));
328
        }
329
    };
330

            
331
    // Type check the constant against the expected context
332
2277
    if ctx.typechecking_context != TypecheckingContext::Unknown {
333
318
        let expected = match ctx.typechecking_context {
334
11
            TypecheckingContext::Boolean => "bool",
335
307
            TypecheckingContext::Arithmetic => "int",
336
            TypecheckingContext::Unknown => "",
337
        };
338

            
339
318
        let actual = match &lit {
340
11
            Literal::Bool(_) => "bool",
341
307
            Literal::Int(_) => "int",
342
            Literal::AbstractLiteral(_) => return Ok(None), // Abstract literals aren't type-checked here
343
        };
344

            
345
318
        if expected != actual {
346
22
            ctx.record_error(RecoverableParseError::new(
347
22
                format!(
348
                    "Type error: {}\n\tExpected: {}\n\tGot: {}",
349
                    raw_value, expected, actual
350
                ),
351
22
                Some(node.range()),
352
            ));
353
22
            return Ok(None);
354
296
        }
355
1959
    }
356
2255
    Ok(Some(lit))
357
2277
}
358

            
359
2201
pub(crate) fn parse_int(ctx: &ParseContext, node: &Node) -> Result<i32, FatalParseError> {
360
2201
    let raw_value = &ctx.source_code[node.start_byte()..node.end_byte()];
361
2201
    raw_value.parse::<i32>().map_err(|_e| {
362
        FatalParseError::internal_error("Expected an integer here".to_string(), Some(node.range()))
363
    })
364
2201
}