1
use crate::diagnostics::diagnostics_api::SymbolKind;
2
use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3
use crate::errors::FatalParseError;
4
use crate::parser::ParseContext;
5
use crate::parser::atom::parse_atom;
6
use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
7
use crate::util::TypecheckingContext;
8
use crate::{field, named_child};
9
use conjure_cp_core::ast::{Expression, Metadata, Moo};
10
use conjure_cp_core::{domain_int, matrix_expr, range};
11
use tree_sitter::Node;
12

            
13
21211
pub fn parse_expression(
14
21211
    ctx: &mut ParseContext,
15
21211
    node: Node,
16
21211
) -> Result<Option<Expression>, FatalParseError> {
17
21211
    match node.kind() {
18
21211
        "atom" => parse_atom(ctx, &node),
19
6288
        "bool_expr" => parse_boolean_expression(ctx, &node),
20
3297
        "arithmetic_expr" => parse_arithmetic_expression(ctx, &node),
21
2575
        "comparison_expr" => parse_comparison_expression(ctx, &node),
22
60
        "dominance_relation" => parse_dominance_relation(ctx, &node),
23
        _ => Err(FatalParseError::internal_error(
24
            format!("Unexpected expression type: '{}'", node.kind()),
25
            Some(node.range()),
26
        )),
27
    }
28
21211
}
29

            
30
60
fn parse_dominance_relation(
31
60
    ctx: &mut ParseContext,
32
60
    node: &Node,
33
60
) -> Result<Option<Expression>, FatalParseError> {
34
60
    if ctx.root.kind() == "dominance_relation" {
35
        return Err(FatalParseError::internal_error(
36
            "Nested dominance relations are not allowed".to_string(),
37
            Some(node.range()),
38
        ));
39
60
    }
40

            
41
    // NB: In all other cases, we keep the root the same;
42
    // However, here we create a new context with the new root so downstream functions
43
    // know we are inside a dominance relation
44
60
    let mut inner_ctx = ParseContext {
45
60
        source_code: ctx.source_code,
46
60
        root: node,
47
60
        symbols: ctx.symbols.clone(),
48
60
        errors: ctx.errors,
49
60
        source_map: &mut *ctx.source_map,
50
60
        typechecking_context: ctx.typechecking_context,
51
60
    };
52

            
53
60
    let Some(inner) = parse_expression(&mut inner_ctx, field!(node, "expression"))? else {
54
        return Ok(None);
55
    };
56

            
57
60
    Ok(Some(Expression::DominanceRelation(
58
60
        Metadata::new(),
59
60
        Moo::new(inner),
60
60
    )))
61
60
}
62

            
63
722
fn parse_arithmetic_expression(
64
722
    ctx: &mut ParseContext,
65
722
    node: &Node,
66
722
) -> Result<Option<Expression>, FatalParseError> {
67
722
    ctx.typechecking_context = TypecheckingContext::Arithmetic;
68
722
    let inner = named_child!(node);
69
722
    match inner.kind() {
70
722
        "atom" => parse_atom(ctx, &inner),
71
722
        "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
72
111
            parse_unary_expression(ctx, &inner)
73
        }
74
611
        "exponent" | "product_expr" | "sum_expr" => parse_binary_expression(ctx, &inner),
75
79
        "list_combining_expr_arith" => parse_list_combining_expression(ctx, &inner),
76
        "aggregate_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
77
        _ => Err(FatalParseError::internal_error(
78
            format!("Expected arithmetic expression, found: {}", inner.kind()),
79
            Some(inner.range()),
80
        )),
81
    }
82
722
}
83

            
84
2515
fn parse_comparison_expression(
85
2515
    ctx: &mut ParseContext,
86
2515
    node: &Node,
87
2515
) -> Result<Option<Expression>, FatalParseError> {
88
2515
    let inner = named_child!(node);
89
2515
    match inner.kind() {
90
2515
        "arithmetic_comparison" => {
91
            // Arithmetic comparisons require arithmetic operands
92
463
            ctx.typechecking_context = TypecheckingContext::Arithmetic;
93
463
            parse_binary_expression(ctx, &inner)
94
        }
95
2052
        "equality_comparison" => {
96
            // Equality works on any type
97
            // TODO: add type checking to ensure both sides have the same type
98
1216
            ctx.typechecking_context = TypecheckingContext::Unknown;
99
1216
            parse_binary_expression(ctx, &inner)
100
        }
101
836
        "set_comparison" => {
102
            // Set comparisons require set operands (no specific type checking for now)
103
            // TODO: add typechecking for sets
104
836
            ctx.typechecking_context = TypecheckingContext::Unknown;
105
836
            parse_binary_expression(ctx, &inner)
106
        }
107
        _ => Err(FatalParseError::internal_error(
108
            format!("Expected comparison expression, found '{}'", inner.kind()),
109
            Some(inner.range()),
110
        )),
111
    }
112
2515
}
113

            
114
2991
fn parse_boolean_expression(
115
2991
    ctx: &mut ParseContext,
116
2991
    node: &Node,
117
2991
) -> Result<Option<Expression>, FatalParseError> {
118
2991
    ctx.typechecking_context = TypecheckingContext::Boolean;
119
2991
    let inner = named_child!(node);
120
2991
    match inner.kind() {
121
2991
        "atom" => parse_atom(ctx, &inner),
122
2991
        "not_expr" | "sub_bool_expr" => parse_unary_expression(ctx, &inner),
123
1287
        "and_expr" | "or_expr" | "implication" | "iff_expr" => parse_binary_expression(ctx, &inner),
124
116
        "list_combining_expr_bool" => parse_list_combining_expression(ctx, &inner),
125
68
        "quantifier_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
126
        _ => Err(FatalParseError::internal_error(
127
            format!("Expected boolean expression, found '{}'", inner.kind()),
128
            Some(inner.range()),
129
        )),
130
    }
131
2991
}
132

            
133
127
fn parse_list_combining_expression(
134
127
    ctx: &mut ParseContext,
135
127
    node: &Node,
136
127
) -> Result<Option<Expression>, FatalParseError> {
137
127
    let operator_node = field!(node, "operator");
138
127
    let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
139

            
140
127
    let Some(inner) = parse_atom(ctx, &field!(node, "arg"))? else {
141
68
        return Ok(None);
142
    };
143

            
144
59
    match operator_str {
145
59
        "and" => Ok(Some(Expression::And(Metadata::new(), Moo::new(inner)))),
146
48
        "or" => Ok(Some(Expression::Or(Metadata::new(), Moo::new(inner)))),
147
48
        "sum" => Ok(Some(Expression::Sum(Metadata::new(), Moo::new(inner)))),
148
3
        "product" => Ok(Some(Expression::Product(Metadata::new(), Moo::new(inner)))),
149
3
        "min" => Ok(Some(Expression::Min(Metadata::new(), Moo::new(inner)))),
150
3
        "max" => Ok(Some(Expression::Max(Metadata::new(), Moo::new(inner)))),
151
3
        "allDiff" => Ok(Some(Expression::AllDiff(Metadata::new(), Moo::new(inner)))),
152
        _ => Err(FatalParseError::internal_error(
153
            format!("Invalid operator: '{operator_str}'"),
154
            Some(operator_node.range()),
155
        )),
156
    }
157
127
}
158

            
159
1815
fn parse_unary_expression(
160
1815
    ctx: &mut ParseContext,
161
1815
    node: &Node,
162
1815
) -> Result<Option<Expression>, FatalParseError> {
163
1815
    let Some(inner) = parse_expression(ctx, field!(node, "expression"))? else {
164
        return Ok(None);
165
    };
166
1815
    match node.kind() {
167
1815
        "negative_expr" => Ok(Some(Expression::Neg(Metadata::new(), Moo::new(inner)))),
168
1793
        "abs_value" => Ok(Some(Expression::Abs(Metadata::new(), Moo::new(inner)))),
169
1759
        "not_expr" => Ok(Some(Expression::Not(Metadata::new(), Moo::new(inner)))),
170
1419
        "toInt_expr" => Ok(Some(Expression::ToInt(Metadata::new(), Moo::new(inner)))),
171
1397
        "sub_bool_expr" | "sub_arith_expr" => Ok(Some(inner)),
172
        _ => Err(FatalParseError::internal_error(
173
            format!("Unrecognised unary operation: '{}'", node.kind()),
174
            Some(node.range()),
175
        )),
176
    }
177
1815
}
178

            
179
4422
pub fn parse_binary_expression(
180
4422
    ctx: &mut ParseContext,
181
4422
    node: &Node,
182
4422
) -> Result<Option<Expression>, FatalParseError> {
183
8482
    let mut parse_subexpr = |expr: Node| parse_expression(ctx, expr);
184

            
185
4422
    let Some(left) = parse_subexpr(field!(node, "left"))? else {
186
362
        return Ok(None);
187
    };
188
4060
    let Some(right) = parse_subexpr(field!(node, "right"))? else {
189
136
        return Ok(None);
190
    };
191

            
192
3924
    let op_node = field!(node, "operator");
193
3924
    let op_str = &ctx.source_code[op_node.start_byte()..op_node.end_byte()];
194

            
195
3924
    let mut description = format!("Operator '{op_str}'");
196
3924
    let expr = match op_str {
197
        // NB: We are deliberately setting the index domain to 1.., not 1..2.
198
        // Semantically, this means "a list that can grow/shrink arbitrarily".
199
        // This is expected by rules which will modify the terms of the sum expression
200
        // (e.g. by partially evaluating them).
201
3924
        "+" => Ok(Some(Expression::Sum(
202
335
            Metadata::new(),
203
335
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
204
335
        ))),
205
3589
        "-" => Ok(Some(Expression::Minus(
206
90
            Metadata::new(),
207
90
            Moo::new(left),
208
90
            Moo::new(right),
209
90
        ))),
210
3499
        "*" => Ok(Some(Expression::Product(
211
33
            Metadata::new(),
212
33
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
213
33
        ))),
214
3466
        "/\\" => Ok(Some(Expression::And(
215
512
            Metadata::new(),
216
512
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
217
512
        ))),
218
2954
        "\\/" => Ok(Some(Expression::Or(
219
299
            Metadata::new(),
220
299
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
221
299
        ))),
222
2655
        "**" => Ok(Some(Expression::UnsafePow(
223
            Metadata::new(),
224
            Moo::new(left),
225
            Moo::new(right),
226
        ))),
227
2655
        "/" => {
228
            //TODO: add checks for if division is safe or not
229
6
            Ok(Some(Expression::UnsafeDiv(
230
6
                Metadata::new(),
231
6
                Moo::new(left),
232
6
                Moo::new(right),
233
6
            )))
234
        }
235
2649
        "%" => {
236
            //TODO: add checks for if mod is safe or not
237
            Ok(Some(Expression::UnsafeMod(
238
                Metadata::new(),
239
                Moo::new(left),
240
                Moo::new(right),
241
            )))
242
        }
243
2649
        "=" => Ok(Some(Expression::Eq(
244
900
            Metadata::new(),
245
900
            Moo::new(left),
246
900
            Moo::new(right),
247
900
        ))),
248
1749
        "!=" => Ok(Some(Expression::Neq(
249
90
            Metadata::new(),
250
90
            Moo::new(left),
251
90
            Moo::new(right),
252
90
        ))),
253
1659
        "<=" => Ok(Some(Expression::Leq(
254
68
            Metadata::new(),
255
68
            Moo::new(left),
256
68
            Moo::new(right),
257
68
        ))),
258
1591
        ">=" => Ok(Some(Expression::Geq(
259
51
            Metadata::new(),
260
51
            Moo::new(left),
261
51
            Moo::new(right),
262
51
        ))),
263
1540
        "<" => Ok(Some(Expression::Lt(
264
38
            Metadata::new(),
265
38
            Moo::new(left),
266
38
            Moo::new(right),
267
38
        ))),
268
1502
        ">" => Ok(Some(Expression::Gt(
269
204
            Metadata::new(),
270
204
            Moo::new(left),
271
204
            Moo::new(right),
272
204
        ))),
273
1298
        "->" => Ok(Some(Expression::Imply(
274
247
            Metadata::new(),
275
247
            Moo::new(left),
276
247
            Moo::new(right),
277
247
        ))),
278
1051
        "<->" => Ok(Some(Expression::Iff(
279
11
            Metadata::new(),
280
11
            Moo::new(left),
281
11
            Moo::new(right),
282
11
        ))),
283
1040
        "<lex" => Ok(Some(Expression::LexLt(
284
            Metadata::new(),
285
            Moo::new(left),
286
            Moo::new(right),
287
        ))),
288
1040
        ">lex" => Ok(Some(Expression::LexGt(
289
            Metadata::new(),
290
            Moo::new(left),
291
            Moo::new(right),
292
        ))),
293
1040
        "<=lex" => Ok(Some(Expression::LexLeq(
294
            Metadata::new(),
295
            Moo::new(left),
296
            Moo::new(right),
297
        ))),
298
1040
        ">=lex" => Ok(Some(Expression::LexGeq(
299
            Metadata::new(),
300
            Moo::new(left),
301
            Moo::new(right),
302
        ))),
303
1040
        "in" => Ok(Some(Expression::In(
304
258
            Metadata::new(),
305
258
            Moo::new(left),
306
258
            Moo::new(right),
307
258
        ))),
308
782
        "subset" => Ok(Some(Expression::Subset(
309
170
            Metadata::new(),
310
170
            Moo::new(left),
311
170
            Moo::new(right),
312
170
        ))),
313
612
        "subsetEq" => Ok(Some(Expression::SubsetEq(
314
136
            Metadata::new(),
315
136
            Moo::new(left),
316
136
            Moo::new(right),
317
136
        ))),
318
476
        "supset" => Ok(Some(Expression::Supset(
319
136
            Metadata::new(),
320
136
            Moo::new(left),
321
136
            Moo::new(right),
322
136
        ))),
323
340
        "supsetEq" => Ok(Some(Expression::SupsetEq(
324
136
            Metadata::new(),
325
136
            Moo::new(left),
326
136
            Moo::new(right),
327
136
        ))),
328
204
        "union" => {
329
102
            description = "set union: combines the elements from both operands".to_string();
330
102
            Ok(Some(Expression::Union(
331
102
                Metadata::new(),
332
102
                Moo::new(left),
333
102
                Moo::new(right),
334
102
            )))
335
        }
336
102
        "intersect" => {
337
102
            description =
338
102
                "set intersection: keeps only elements common to both operands".to_string();
339
102
            Ok(Some(Expression::Intersect(
340
102
                Metadata::new(),
341
102
                Moo::new(left),
342
102
                Moo::new(right),
343
102
            )))
344
        }
345
        _ => Err(FatalParseError::internal_error(
346
            format!("Invalid operator: '{op_str}'"),
347
            Some(op_node.range()),
348
        )),
349
    };
350

            
351
3924
    if expr.is_ok() {
352
3924
        let hover = HoverInfo {
353
3924
            description,
354
3924
            kind: Some(SymbolKind::Function),
355
3924
            ty: None,
356
3924
            decl_span: None,
357
3924
        };
358
3924
        span_with_hover(&op_node, ctx.source_code, ctx.source_map, hover);
359
3924
    }
360

            
361
3924
    expr
362
4422
}