1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::{FatalParseError, RecoverableParseError};
4
use crate::expression::{parse_binary_expression, parse_expression};
5
use crate::parser::ParseContext;
6
use crate::parser::abstract_literal::parse_abstract;
7
use crate::parser::comprehension::parse_comprehension;
8
use crate::util::{TypecheckingContext, named_children};
9
use crate::{field, named_child};
10
use conjure_cp_core::ast::{
11
    Atom, DeclarationPtr, Expression, GroundDomain, Literal, Metadata, Moo, Name,
12
};
13
use tree_sitter::Node;
14
use ustr::Ustr;
15

            
16
7966
pub fn parse_atom(
17
7966
    ctx: &mut ParseContext,
18
7966
    node: &Node,
19
7966
) -> Result<Option<Expression>, FatalParseError> {
20
7966
    match node.kind() {
21
7966
        "atom" | "sub_atom_expr" => parse_atom(ctx, &named_child!(node)),
22
3985
        "metavar" => {
23
403
            let ident = field!(node, "identifier");
24
403
            let name_str = &ctx.source_code[ident.start_byte()..ident.end_byte()];
25
403
            Ok(Some(Expression::Metavar(
26
403
                Metadata::new(),
27
403
                Ustr::from(name_str),
28
403
            )))
29
        }
30
3582
        "identifier" => {
31
903
            let Some(var) = parse_variable(ctx, node)? else {
32
155
                return Ok(None);
33
            };
34
748
            Ok(Some(Expression::Atomic(Metadata::new(), var)))
35
        }
36
2679
        "from_solution" => {
37
            if ctx.root.kind() != "dominance_relation" {
38
                return Err(FatalParseError::internal_error(
39
                    "fromSolution only allowed inside dominance relations".to_string(),
40
                    Some(node.range()),
41
                ));
42
            }
43

            
44
            let Some(inner) = parse_variable(ctx, &field!(node, "variable"))? else {
45
                return Ok(None);
46
            };
47

            
48
            Ok(Some(Expression::FromSolution(
49
                Metadata::new(),
50
                Moo::new(inner),
51
            )))
52
        }
53
2679
        "constant" => {
54
1958
            let Some(lit) = parse_constant(ctx, node)? else {
55
22
                return Ok(None);
56
            };
57
1936
            Ok(Some(Expression::Atomic(
58
1936
                Metadata::new(),
59
1936
                Atom::Literal(lit),
60
1936
            )))
61
        }
62
721
        "matrix" | "record" | "tuple" | "set_literal" => {
63
631
            let Some(abs) = parse_abstract(ctx, node)? else {
64
22
                return Ok(None);
65
            };
66
609
            Ok(Some(Expression::AbstractLiteral(Metadata::new(), abs)))
67
        }
68
90
        "flatten" => parse_flatten(ctx, node),
69
90
        "index_or_slice" => parse_index_or_slice(ctx, node),
70
        // for now, assume is binary since powerset isn't implemented
71
        // TODO: add powerset support under "set_operation"
72
66
        "set_operation" => parse_binary_expression(ctx, node),
73
        "comprehension" => parse_comprehension(ctx, node),
74
        _ => Err(FatalParseError::internal_error(
75
            format!("Expected atom, got: {}", node.kind()),
76
            Some(node.range()),
77
        )),
78
    }
79
7966
}
80

            
81
fn parse_flatten(
82
    ctx: &mut ParseContext,
83
    node: &Node,
84
) -> Result<Option<Expression>, FatalParseError> {
85
    let expr_node = field!(node, "expression");
86
    let Some(expr) = parse_atom(ctx, &expr_node)? else {
87
        return Ok(None);
88
    };
89

            
90
    if node.child_by_field_name("depth").is_some() {
91
        let depth_node = field!(node, "depth");
92
        let depth = parse_int(ctx, &depth_node)?;
93
        let depth_expression =
94
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
95
        Ok(Some(Expression::Flatten(
96
            Metadata::new(),
97
            Some(Moo::new(depth_expression)),
98
            Moo::new(expr),
99
        )))
100
    } else {
101
        Ok(Some(Expression::Flatten(
102
            Metadata::new(),
103
            None,
104
            Moo::new(expr),
105
        )))
106
    }
107
}
108

            
109
24
fn parse_index_or_slice(
110
24
    ctx: &mut ParseContext,
111
24
    node: &Node,
112
24
) -> Result<Option<Expression>, FatalParseError> {
113
    // Save current context and temporarily set to Unknown for the collection
114
24
    let saved_context = ctx.typechecking_context;
115
24
    ctx.typechecking_context = TypecheckingContext::Unknown;
116
24
    let Some(collection) = parse_atom(ctx, &field!(node, "collection"))? else {
117
        return Ok(None);
118
    };
119
24
    ctx.typechecking_context = saved_context;
120
24
    let mut indices = Vec::new();
121
37
    for idx_node in named_children(&field!(node, "indices")) {
122
37
        indices.push(parse_index(ctx, &idx_node)?);
123
    }
124

            
125
26
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
126
    // TODO: We could check whether the slice/index is safe here
127
24
    if has_null_idx {
128
        // It's a slice
129
23
        Ok(Some(Expression::UnsafeSlice(
130
23
            Metadata::new(),
131
23
            Moo::new(collection),
132
23
            indices,
133
23
        )))
134
    } else {
135
        // It's an index
136
2
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
137
1
        Ok(Some(Expression::UnsafeIndex(
138
1
            Metadata::new(),
139
1
            Moo::new(collection),
140
1
            idx_exprs,
141
1
        )))
142
    }
143
24
}
144

            
145
37
fn parse_index(ctx: &mut ParseContext, node: &Node) -> Result<Option<Expression>, FatalParseError> {
146
37
    match node.kind() {
147
37
        "arithmetic_expr" | "atom" => {
148
36
            ctx.typechecking_context = TypecheckingContext::Arithmetic;
149
36
            let Some(expr) = parse_expression(ctx, *node)? else {
150
22
                return Ok(None);
151
            };
152
14
            Ok(Some(expr))
153
        }
154
1
        "null_index" => Ok(None),
155
        _ => Err(FatalParseError::internal_error(
156
            format!("Expected an index, got: '{}'", node.kind()),
157
            Some(node.range()),
158
        )),
159
    }
160
37
}
161

            
162
903
fn parse_variable(ctx: &mut ParseContext, node: &Node) -> Result<Option<Atom>, FatalParseError> {
163
903
    let raw_name = &ctx.source_code[node.start_byte()..node.end_byte()];
164
903
    let name = Name::user(raw_name.trim());
165
903
    if let Some(symbols) = &ctx.symbols {
166
903
        let lookup_result = {
167
903
            let symbols_read = symbols.read();
168
903
            symbols_read.lookup(&name)
169
        };
170

            
171
903
        if let Some(decl) = lookup_result {
172
858
            let hover = HoverInfo {
173
858
                description: format!("Variable: {name}"),
174
858
                kind: Some(SymbolKind::Decimal),
175
858
                ty: decl.domain().map(|d| d.to_string()),
176
858
                decl_span: None,
177
            };
178
858
            span_with_hover(node, ctx.source_code, ctx.source_map, hover);
179

            
180
            // Type check the variable against the expected context
181
858
            if let Some(error_msg) = typecheck_variable(&decl, ctx.typechecking_context) {
182
110
                ctx.record_error(RecoverableParseError::new(error_msg, Some(node.range())));
183
110
                return Ok(None);
184
748
            }
185

            
186
748
            Ok(Some(Atom::Reference(conjure_cp_core::ast::Reference::new(
187
748
                decl,
188
748
            ))))
189
        } else {
190
45
            ctx.record_error(RecoverableParseError::new(
191
45
                format!("The identifier '{}' is not defined", raw_name),
192
45
                Some(node.range()),
193
            ));
194
45
            Ok(None)
195
        }
196
    } else {
197
        Err(FatalParseError::internal_error(
198
            format!("Symbol table missing when parsing variable '{raw_name}'"),
199
            Some(node.range()),
200
        ))
201
    }
202
903
}
203

            
204
/// Type check a variable declaration against the expected expression context.
205
/// Returns an error message if the variable type doesn't match the context.
206
858
fn typecheck_variable(decl: &DeclarationPtr, context: TypecheckingContext) -> Option<String> {
207
    // Only type check when context is known
208
858
    if context == TypecheckingContext::Unknown {
209
542
        return None;
210
316
    }
211

            
212
    // Get the variable's domain and resolve it
213
316
    let domain = decl.domain()?;
214
316
    let ground_domain = domain.resolve()?;
215

            
216
    // Determine what type is expected
217
316
    let expected = match context {
218
55
        TypecheckingContext::Boolean => "bool",
219
261
        TypecheckingContext::Arithmetic => "int",
220
        TypecheckingContext::Unknown => return None, // shouldn't reach here
221
    };
222

            
223
    // Determine what type we actually have
224
316
    let actual = match ground_domain.as_ref() {
225
55
        GroundDomain::Bool => "bool",
226
250
        GroundDomain::Int(_) => "int",
227
11
        GroundDomain::Matrix(_, _) => "matrix",
228
        GroundDomain::Set(_, _) => "set",
229
        GroundDomain::MSet(_, _) => "mset",
230
        GroundDomain::Tuple(_) => "tuple",
231
        GroundDomain::Record(_) => "record",
232
        GroundDomain::Function(_, _, _) => "function",
233
        GroundDomain::Empty(_) => "empty",
234
    };
235

            
236
    // If types match, no error
237
316
    if expected == actual {
238
206
        return None;
239
110
    }
240

            
241
    // Otherwise, report the type mismatch
242
110
    Some(format!(
243
110
        "Type error:\n\tExpected: {}\n\tGot: {}",
244
110
        expected, actual
245
110
    ))
246
858
}
247

            
248
1958
fn parse_constant(ctx: &mut ParseContext, node: &Node) -> Result<Option<Literal>, FatalParseError> {
249
1958
    let inner = named_child!(node);
250
1958
    let raw_value = &ctx.source_code[inner.start_byte()..inner.end_byte()];
251
1958
    let lit = match inner.kind() {
252
1958
        "integer" => {
253
1882
            let value = parse_int(ctx, &inner)?;
254
1882
            Literal::Int(value)
255
        }
256
76
        "TRUE" => {
257
24
            let hover = HoverInfo {
258
24
                description: format!("Boolean constant: {raw_value}"),
259
24
                kind: None,
260
24
                ty: None,
261
24
                decl_span: None,
262
24
            };
263
24
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
264
24
            Literal::Bool(true)
265
        }
266
52
        "FALSE" => {
267
52
            let hover = HoverInfo {
268
52
                description: format!("Boolean constant: {raw_value}"),
269
52
                kind: None,
270
52
                ty: None,
271
52
                decl_span: None,
272
52
            };
273
52
            span_with_hover(&inner, ctx.source_code, ctx.source_map, hover);
274
52
            Literal::Bool(false)
275
        }
276
        _ => {
277
            return Err(FatalParseError::internal_error(
278
                format!(
279
                    "'{}' (kind: '{}') is not a valid constant",
280
                    raw_value,
281
                    inner.kind()
282
                ),
283
                Some(inner.range()),
284
            ));
285
        }
286
    };
287

            
288
    // Type check the constant against the expected context
289
1958
    if ctx.typechecking_context != TypecheckingContext::Unknown {
290
274
        let expected = match ctx.typechecking_context {
291
11
            TypecheckingContext::Boolean => "bool",
292
263
            TypecheckingContext::Arithmetic => "int",
293
            TypecheckingContext::Unknown => "",
294
        };
295

            
296
274
        let actual = match &lit {
297
11
            Literal::Bool(_) => "bool",
298
263
            Literal::Int(_) => "int",
299
            Literal::AbstractLiteral(_) => return Ok(None), // Abstract literals aren't type-checked here
300
        };
301

            
302
274
        if expected != actual {
303
22
            ctx.record_error(RecoverableParseError::new(
304
22
                format!("Type error:\n\tExpected: {}\n\tGot: {}", expected, actual),
305
22
                Some(node.range()),
306
            ));
307
22
            return Ok(None);
308
252
        }
309
1684
    }
310
1936
    Ok(Some(lit))
311
1958
}
312

            
313
1882
pub(crate) fn parse_int(ctx: &ParseContext, node: &Node) -> Result<i32, FatalParseError> {
314
1882
    let raw_value = &ctx.source_code[node.start_byte()..node.end_byte()];
315
1882
    raw_value.parse::<i32>().map_err(|_e| {
316
        FatalParseError::internal_error("Expected an integer here".to_string(), Some(node.range()))
317
    })
318
1882
}