1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::FatalParseError;
4
use crate::parser::ParseContext;
5
use crate::parser::atom::parse_atom;
6
use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
7
use crate::util::TypecheckingContext;
8
use crate::{field, named_child};
9
use conjure_cp_core::ast::{Expression, Metadata, Moo};
10
use conjure_cp_core::{domain_int, matrix_expr, range};
11
use tree_sitter::Node;
12

            
13
46431
pub fn parse_expression(
14
46431
    ctx: &mut ParseContext,
15
46431
    node: Node,
16
46431
) -> Result<Option<Expression>, FatalParseError> {
17
46431
    match node.kind() {
18
46431
        "atom" => parse_atom(ctx, &node),
19
18247
        "bool_expr" => parse_boolean_expression(ctx, &node),
20
13532
        "arithmetic_expr" => parse_arithmetic_expression(ctx, &node),
21
7744
        "comparison_expr" => parse_comparison_expression(ctx, &node),
22
44
        "dominance_relation" => parse_dominance_relation(ctx, &node),
23
44
        "all_diff_comparison" => parse_all_diff_comparison(ctx, &node),
24
        _ => Err(FatalParseError::internal_error(
25
            format!("Unexpected expression type: '{}'", node.kind()),
26
            Some(node.range()),
27
        )),
28
    }
29
46431
}
30

            
31
fn parse_dominance_relation(
32
    ctx: &mut ParseContext,
33
    node: &Node,
34
) -> Result<Option<Expression>, FatalParseError> {
35
    if ctx.root.kind() == "dominance_relation" {
36
        return Err(FatalParseError::internal_error(
37
            "Nested dominance relations are not allowed".to_string(),
38
            Some(node.range()),
39
        ));
40
    }
41

            
42
    // NB: In all other cases, we keep the root the same;
43
    // However, here we create a new context with the new root so downstream functions
44
    // know we are inside a dominance relation
45
    let mut inner_ctx = ParseContext {
46
        source_code: ctx.source_code,
47
        root: node,
48
        symbols: ctx.symbols.clone(),
49
        errors: ctx.errors,
50
        source_map: &mut *ctx.source_map,
51
        typechecking_context: ctx.typechecking_context,
52
    };
53

            
54
    let Some(inner) = parse_expression(&mut inner_ctx, field!(node, "expression"))? else {
55
        return Ok(None);
56
    };
57

            
58
    Ok(Some(Expression::DominanceRelation(
59
        Metadata::new(),
60
        Moo::new(inner),
61
    )))
62
}
63

            
64
5788
fn parse_arithmetic_expression(
65
5788
    ctx: &mut ParseContext,
66
5788
    node: &Node,
67
5788
) -> Result<Option<Expression>, FatalParseError> {
68
5788
    ctx.typechecking_context = TypecheckingContext::Arithmetic;
69
5788
    let inner = named_child!(node);
70
5788
    match inner.kind() {
71
5788
        "atom" => parse_atom(ctx, &inner),
72
5788
        "negative_expr" | "abs_value" | "sub_arith_expr" => parse_unary_expression(ctx, &inner),
73
4245
        "toInt_expr" => {
74
            // add special handling for toInt, as it is arithmetic but takes a non-arithmetic operand
75
54
            ctx.typechecking_context = TypecheckingContext::Unknown;
76
54
            parse_unary_expression(ctx, &inner)
77
        }
78
4191
        "exponent" | "product_expr" | "sum_expr" => parse_binary_expression(ctx, &inner),
79
764
        "list_combining_expr_arith" => parse_list_combining_expression(ctx, &inner),
80
121
        "aggregate_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
81
        _ => Err(FatalParseError::internal_error(
82
            format!("Expected arithmetic expression, found: {}", inner.kind()),
83
            Some(inner.range()),
84
        )),
85
    }
86
5788
}
87

            
88
7700
fn parse_comparison_expression(
89
7700
    ctx: &mut ParseContext,
90
7700
    node: &Node,
91
7700
) -> Result<Option<Expression>, FatalParseError> {
92
7700
    let inner = named_child!(node);
93
7700
    match inner.kind() {
94
7700
        "arithmetic_comparison" => {
95
            // Arithmetic comparisons require arithmetic operands
96
2496
            ctx.typechecking_context = TypecheckingContext::Arithmetic;
97
2496
            parse_binary_expression(ctx, &inner)
98
        }
99
5204
        "lex_comparison" => {
100
            // TODO: check that both operands are comparable collections.
101
220
            ctx.typechecking_context = TypecheckingContext::Unknown;
102
220
            parse_binary_expression(ctx, &inner)
103
        }
104
4984
        "equality_comparison" => {
105
            // Equality works on any type
106
            // TODO: add type checking to ensure both sides have the same type
107
4195
            ctx.typechecking_context = TypecheckingContext::Unknown;
108
4195
            parse_binary_expression(ctx, &inner)
109
        }
110
789
        "set_comparison" => {
111
            // Set comparisons require set operands (no specific type checking for now)
112
            // TODO: add typechecking for sets
113
282
            ctx.typechecking_context = TypecheckingContext::Unknown;
114
282
            parse_binary_expression(ctx, &inner)
115
        }
116
507
        "all_diff_comparison" => {
117
            // TODO: check that operand is a collection with compatible element type.
118
507
            ctx.typechecking_context = TypecheckingContext::Unknown;
119
507
            parse_all_diff_comparison(ctx, &inner)
120
        }
121
        _ => Err(FatalParseError::internal_error(
122
            format!("Expected comparison expression, found '{}'", inner.kind()),
123
            Some(inner.range()),
124
        )),
125
    }
126
7700
}
127

            
128
4715
fn parse_boolean_expression(
129
4715
    ctx: &mut ParseContext,
130
4715
    node: &Node,
131
4715
) -> Result<Option<Expression>, FatalParseError> {
132
4715
    ctx.typechecking_context = TypecheckingContext::Boolean;
133
4715
    let inner = named_child!(node);
134
4715
    match inner.kind() {
135
4715
        "atom" => parse_atom(ctx, &inner),
136
4715
        "not_expr" | "sub_bool_expr" => parse_unary_expression(ctx, &inner),
137
2491
        "and_expr" | "or_expr" | "implication" | "iff_expr" => parse_binary_expression(ctx, &inner),
138
995
        "list_combining_expr_bool" => parse_list_combining_expression(ctx, &inner),
139
770
        "quantifier_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
140
        _ => Err(FatalParseError::internal_error(
141
            format!("Expected boolean expression, found '{}'", inner.kind()),
142
            Some(inner.range()),
143
        )),
144
    }
145
4715
}
146

            
147
868
fn parse_list_combining_expression(
148
868
    ctx: &mut ParseContext,
149
868
    node: &Node,
150
868
) -> Result<Option<Expression>, FatalParseError> {
151
868
    let operator_node = field!(node, "operator");
152
868
    let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
153

            
154
868
    let Some(inner) = parse_atom(ctx, &field!(node, "arg"))? else {
155
22
        return Ok(None);
156
    };
157

            
158
846
    match operator_str {
159
846
        "and" => Ok(Some(Expression::And(Metadata::new(), Moo::new(inner)))),
160
665
        "or" => Ok(Some(Expression::Or(Metadata::new(), Moo::new(inner)))),
161
632
        "sum" => Ok(Some(Expression::Sum(Metadata::new(), Moo::new(inner)))),
162
539
        "product" => Ok(Some(Expression::Product(Metadata::new(), Moo::new(inner)))),
163
539
        "min" => Ok(Some(Expression::Min(Metadata::new(), Moo::new(inner)))),
164
253
        "max" => Ok(Some(Expression::Max(Metadata::new(), Moo::new(inner)))),
165
        _ => Err(FatalParseError::internal_error(
166
            format!("Invalid operator: '{operator_str}'"),
167
            Some(operator_node.range()),
168
        )),
169
    }
170
868
}
171

            
172
551
fn parse_all_diff_comparison(
173
551
    ctx: &mut ParseContext,
174
551
    node: &Node,
175
551
) -> Result<Option<Expression>, FatalParseError> {
176
551
    let Some(inner) = parse_expression(ctx, field!(node, "arg"))? else {
177
        return Ok(None);
178
    };
179

            
180
551
    Ok(Some(Expression::AllDiff(Metadata::new(), Moo::new(inner))))
181
551
}
182

            
183
3821
fn parse_unary_expression(
184
3821
    ctx: &mut ParseContext,
185
3821
    node: &Node,
186
3821
) -> Result<Option<Expression>, FatalParseError> {
187
3821
    let Some(inner) = parse_expression(ctx, field!(node, "expression"))? else {
188
        return Ok(None);
189
    };
190
3821
    match node.kind() {
191
3821
        "negative_expr" => Ok(Some(Expression::Neg(Metadata::new(), Moo::new(inner)))),
192
3426
        "abs_value" => Ok(Some(Expression::Abs(Metadata::new(), Moo::new(inner)))),
193
3283
        "not_expr" => Ok(Some(Expression::Not(Metadata::new(), Moo::new(inner)))),
194
2715
        "toInt_expr" => Ok(Some(Expression::ToInt(Metadata::new(), Moo::new(inner)))),
195
2661
        "sub_bool_expr" | "sub_arith_expr" => Ok(Some(inner)),
196
        _ => Err(FatalParseError::internal_error(
197
            format!("Unrecognised unary operation: '{}'", node.kind()),
198
            Some(node.range()),
199
        )),
200
    }
201
3821
}
202

            
203
12182
pub fn parse_binary_expression(
204
12182
    ctx: &mut ParseContext,
205
12182
    node: &Node,
206
12182
) -> Result<Option<Expression>, FatalParseError> {
207
24243
    let mut parse_subexpr = |expr: Node| parse_expression(ctx, expr);
208

            
209
12182
    let Some(left) = parse_subexpr(field!(node, "left"))? else {
210
121
        return Ok(None);
211
    };
212
12061
    let Some(right) = parse_subexpr(field!(node, "right"))? else {
213
44
        return Ok(None);
214
    };
215

            
216
12017
    let op_node = field!(node, "operator");
217
12017
    let op_str = &ctx.source_code[op_node.start_byte()..op_node.end_byte()];
218

            
219
12017
    let mut description = format!("Operator '{op_str}'");
220
12017
    let expr = match op_str {
221
        // NB: We are deliberately setting the index domain to 1.., not 1..2.
222
        // Semantically, this means "a list that can grow/shrink arbitrarily".
223
        // This is expected by rules which will modify the terms of the sum expression
224
        // (e.g. by partially evaluating them).
225
12017
        "+" => Ok(Some(Expression::Sum(
226
1299
            Metadata::new(),
227
1299
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
228
1299
        ))),
229
10718
        "-" => Ok(Some(Expression::Minus(
230
461
            Metadata::new(),
231
461
            Moo::new(left),
232
461
            Moo::new(right),
233
461
        ))),
234
10257
        "*" => Ok(Some(Expression::Product(
235
411
            Metadata::new(),
236
411
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
237
411
        ))),
238
9846
        "/\\" => Ok(Some(Expression::And(
239
405
            Metadata::new(),
240
405
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
241
405
        ))),
242
9441
        "\\/" => Ok(Some(Expression::Or(
243
478
            Metadata::new(),
244
478
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
245
478
        ))),
246
8963
        "**" => Ok(Some(Expression::UnsafePow(
247
341
            Metadata::new(),
248
341
            Moo::new(left),
249
341
            Moo::new(right),
250
341
        ))),
251
8622
        "/" => {
252
            //TODO: add checks for if division is safe or not
253
618
            Ok(Some(Expression::UnsafeDiv(
254
618
                Metadata::new(),
255
618
                Moo::new(left),
256
618
                Moo::new(right),
257
618
            )))
258
        }
259
8004
        "%" => {
260
            //TODO: add checks for if mod is safe or not
261
275
            Ok(Some(Expression::UnsafeMod(
262
275
                Metadata::new(),
263
275
                Moo::new(left),
264
275
                Moo::new(right),
265
275
            )))
266
        }
267
7729
        "=" => Ok(Some(Expression::Eq(
268
3415
            Metadata::new(),
269
3415
            Moo::new(left),
270
3415
            Moo::new(right),
271
3415
        ))),
272
4314
        "!=" => Ok(Some(Expression::Neq(
273
703
            Metadata::new(),
274
703
            Moo::new(left),
275
703
            Moo::new(right),
276
703
        ))),
277
3611
        "<=" => Ok(Some(Expression::Leq(
278
933
            Metadata::new(),
279
933
            Moo::new(left),
280
933
            Moo::new(right),
281
933
        ))),
282
2678
        ">=" => Ok(Some(Expression::Geq(
283
468
            Metadata::new(),
284
468
            Moo::new(left),
285
468
            Moo::new(right),
286
468
        ))),
287
2210
        "<" => Ok(Some(Expression::Lt(
288
644
            Metadata::new(),
289
644
            Moo::new(left),
290
644
            Moo::new(right),
291
644
        ))),
292
1566
        ">" => Ok(Some(Expression::Gt(
293
418
            Metadata::new(),
294
418
            Moo::new(left),
295
418
            Moo::new(right),
296
418
        ))),
297
1148
        "->" => Ok(Some(Expression::Imply(
298
520
            Metadata::new(),
299
520
            Moo::new(left),
300
520
            Moo::new(right),
301
520
        ))),
302
628
        "<->" => Ok(Some(Expression::Iff(
303
60
            Metadata::new(),
304
60
            Moo::new(left),
305
60
            Moo::new(right),
306
60
        ))),
307
568
        "<lex" => Ok(Some(Expression::LexLt(
308
66
            Metadata::new(),
309
66
            Moo::new(left),
310
66
            Moo::new(right),
311
66
        ))),
312
502
        ">lex" => Ok(Some(Expression::LexGt(
313
22
            Metadata::new(),
314
22
            Moo::new(left),
315
22
            Moo::new(right),
316
22
        ))),
317
480
        "<=lex" => Ok(Some(Expression::LexLeq(
318
99
            Metadata::new(),
319
99
            Moo::new(left),
320
99
            Moo::new(right),
321
99
        ))),
322
381
        ">=lex" => Ok(Some(Expression::LexGeq(
323
33
            Metadata::new(),
324
33
            Moo::new(left),
325
33
            Moo::new(right),
326
33
        ))),
327
348
        "in" => Ok(Some(Expression::In(
328
95
            Metadata::new(),
329
95
            Moo::new(left),
330
95
            Moo::new(right),
331
95
        ))),
332
253
        "subset" => Ok(Some(Expression::Subset(
333
55
            Metadata::new(),
334
55
            Moo::new(left),
335
55
            Moo::new(right),
336
55
        ))),
337
198
        "subsetEq" => Ok(Some(Expression::SubsetEq(
338
44
            Metadata::new(),
339
44
            Moo::new(left),
340
44
            Moo::new(right),
341
44
        ))),
342
154
        "supset" => Ok(Some(Expression::Supset(
343
44
            Metadata::new(),
344
44
            Moo::new(left),
345
44
            Moo::new(right),
346
44
        ))),
347
110
        "supsetEq" => Ok(Some(Expression::SupsetEq(
348
44
            Metadata::new(),
349
44
            Moo::new(left),
350
44
            Moo::new(right),
351
44
        ))),
352
66
        "union" => {
353
33
            description = "set union: combines the elements from both operands".to_string();
354
33
            Ok(Some(Expression::Union(
355
33
                Metadata::new(),
356
33
                Moo::new(left),
357
33
                Moo::new(right),
358
33
            )))
359
        }
360
33
        "intersect" => {
361
33
            description =
362
33
                "set intersection: keeps only elements common to both operands".to_string();
363
33
            Ok(Some(Expression::Intersect(
364
33
                Metadata::new(),
365
33
                Moo::new(left),
366
33
                Moo::new(right),
367
33
            )))
368
        }
369
        _ => Err(FatalParseError::internal_error(
370
            format!("Invalid operator: '{op_str}'"),
371
            Some(op_node.range()),
372
        )),
373
    };
374

            
375
12017
    if expr.is_ok() {
376
12017
        let hover = HoverInfo {
377
12017
            description,
378
12017
            kind: Some(SymbolKind::Function),
379
12017
            ty: None,
380
12017
            decl_span: None,
381
12017
        };
382
12017
        span_with_hover(&op_node, ctx.source_code, ctx.source_map, hover);
383
12017
    }
384

            
385
12017
    expr
386
12182
}