Skip to main content

conjure_cp_essence_parser/parser/
expression.rs

1use crate::diagnostics::diagnostics_api::SymbolKind;
2use crate::errors::{FatalParseError, RecoverableParseError};
3use crate::parser::ParseContext;
4use crate::parser::atom::parse_atom;
5use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
6use crate::util::TypecheckingContext;
7use crate::{child, field, named_child};
8use conjure_cp_core::ast::{Expression, GroundDomain, Metadata, Moo};
9use conjure_cp_core::{domain_int, matrix_expr, range};
10use tree_sitter::Node;
11
12pub fn parse_expression(
13    ctx: &mut ParseContext,
14    node: Node,
15) -> Result<Option<Expression>, FatalParseError> {
16    match node.kind() {
17        "atom" => parse_atom(ctx, &node),
18        "bool_expr" => {
19            if ctx.typechecking_context == TypecheckingContext::Arithmetic {
20                ctx.record_error(RecoverableParseError::new(
21                    format!(
22                        "Type error: {}\n\tExepected: int\n\tGot: boolean expression",
23                        &ctx.source_code[node.start_byte()..node.end_byte()]
24                    ),
25                    Some(node.range()),
26                ));
27                return Ok(None);
28            }
29            parse_boolean_expression(ctx, &node)
30        }
31        "arithmetic_expr" => {
32            if ctx.typechecking_context == TypecheckingContext::Boolean {
33                ctx.record_error(RecoverableParseError::new(
34                    format!(
35                        "Type error: {}\n\tExepected: bool\n\tGot: arithmetic expression",
36                        &ctx.source_code[node.start_byte()..node.end_byte()]
37                    ),
38                    Some(node.range()),
39                ));
40                return Ok(None);
41            }
42            parse_arithmetic_expression(ctx, &node)
43        }
44        "comparison_expr" => {
45            if ctx.typechecking_context == TypecheckingContext::Arithmetic {
46                ctx.record_error(RecoverableParseError::new(
47                    format!(
48                        "Type error: {}\n\tExepected: int\n\tGot: comparison expression",
49                        &ctx.source_code[node.start_byte()..node.end_byte()]
50                    ),
51                    Some(node.range()),
52                ));
53                return Ok(None);
54            }
55            parse_comparison_expression(ctx, &node)
56        }
57        "all_diff_comparison" => {
58            if ctx.typechecking_context == TypecheckingContext::Arithmetic {
59                ctx.record_error(RecoverableParseError::new(
60                    format!("Type error: {}\n\tExepected: arithmetic expression\n\tFound: comparison expression", &ctx.source_code[node.start_byte()..node.end_byte()]),
61                    Some(node.range()),
62                ));
63                return Ok(None);
64            }
65            ctx.typechecking_context = TypecheckingContext::Matrix;
66            parse_all_diff_comparison(ctx, &node)
67        }
68        _ => {
69            ctx.record_error(RecoverableParseError::new(
70                format!("Unexpected expression type: '{}'", node.kind()),
71                Some(node.range()),
72            ));
73            Ok(None)
74        }
75    }
76}
77
78fn parse_arithmetic_expression(
79    ctx: &mut ParseContext,
80    node: &Node,
81) -> Result<Option<Expression>, FatalParseError> {
82    ctx.typechecking_context = TypecheckingContext::Arithmetic;
83    ctx.inner_typechecking_context = TypecheckingContext::Unknown;
84    let Some(inner) = named_child!(recover, ctx, node) else {
85        return Ok(None);
86    };
87    match inner.kind() {
88        "atom" => parse_atom(ctx, &inner),
89        "negative_expr" | "abs_value" | "sub_arith_expr" | "factorial_expr" => {
90            parse_unary_expression(ctx, &inner)
91        }
92        "toInt_expr" => {
93            // add special handling for toInt, as it is arithmetic but takes a non-arithmetic operand
94            ctx.typechecking_context = TypecheckingContext::Unknown;
95            parse_unary_expression(ctx, &inner)
96        }
97        "exponent" | "product_expr" | "sum_expr" => parse_binary_expression(ctx, &inner),
98        "list_combining_expr_arith" => {
99            // list-combining arithmetic operators accept either set or matrix operands
100            ctx.typechecking_context = TypecheckingContext::SetOrMatrix;
101
102            // set inner context to arithmetic to ensure elements of list are arithmetic expressions
103            ctx.inner_typechecking_context = TypecheckingContext::Arithmetic;
104            parse_list_combining_expression(ctx, &inner)
105        }
106        "aggregate_expr" => {
107            ctx.inner_typechecking_context = TypecheckingContext::Arithmetic;
108            parse_quantifier_or_aggregate_expr(ctx, &inner)
109        }
110        _ => {
111            ctx.record_error(RecoverableParseError::new(
112                format!("Expected arithmetic expression, found: {}", inner.kind()),
113                Some(inner.range()),
114            ));
115            Ok(None)
116        }
117    }
118}
119
120fn parse_comparison_expression(
121    ctx: &mut ParseContext,
122    node: &Node,
123) -> Result<Option<Expression>, FatalParseError> {
124    let Some(inner) = named_child!(recover, ctx, node) else {
125        return Ok(None);
126    };
127    match inner.kind() {
128        "arithmetic_comparison" => {
129            // Arithmetic comparisons require arithmetic operands
130            ctx.typechecking_context = TypecheckingContext::Arithmetic;
131            parse_binary_expression(ctx, &inner)
132        }
133        "lex_comparison" => {
134            // TODO: check that both operands are comparable collections.
135            ctx.typechecking_context = TypecheckingContext::Unknown;
136            parse_binary_expression(ctx, &inner)
137        }
138        "equality_comparison" => {
139            // Equality works on any type, typechecking of operands will be handled within parse_binary_expression
140            ctx.typechecking_context = TypecheckingContext::Unknown;
141            parse_binary_expression(ctx, &inner)
142        }
143        "set_comparison" => {
144            // Set comparisons require set operands (except 'in', which is hadled later)
145            ctx.typechecking_context = TypecheckingContext::Set;
146            parse_binary_expression(ctx, &inner)
147        }
148        "all_diff_comparison" => {
149            ctx.typechecking_context = TypecheckingContext::Matrix;
150            parse_all_diff_comparison(ctx, &inner)
151        }
152        _ => {
153            ctx.record_error(RecoverableParseError::new(
154                format!("Expected comparison expression, found '{}'", inner.kind()),
155                Some(inner.range()),
156            ));
157            Ok(None)
158        }
159    }
160}
161
162fn parse_boolean_expression(
163    ctx: &mut ParseContext,
164    node: &Node,
165) -> Result<Option<Expression>, FatalParseError> {
166    ctx.typechecking_context = TypecheckingContext::Boolean;
167    ctx.inner_typechecking_context = TypecheckingContext::Unknown;
168    let Some(inner) = named_child!(recover, ctx, node) else {
169        return Ok(None);
170    };
171    match inner.kind() {
172        "atom" => parse_atom(ctx, &inner),
173        "not_expr" | "sub_bool_expr" => parse_unary_expression(ctx, &inner),
174        "and_expr" | "or_expr" | "implication" | "iff_expr" => parse_binary_expression(ctx, &inner),
175        "list_combining_expr_bool" => {
176            // list-combining boolean operators accept either set or matrix operands
177            ctx.typechecking_context = TypecheckingContext::SetOrMatrix;
178
179            // set inner context to boolean to ensure elements of list are boolean expressions
180            ctx.inner_typechecking_context = TypecheckingContext::Boolean;
181            parse_list_combining_expression(ctx, &inner)
182        }
183        "quantifier_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
184        _ => {
185            ctx.record_error(RecoverableParseError::new(
186                format!("Expected boolean expression, found '{}'", inner.kind()),
187                Some(inner.range()),
188            ));
189            Ok(None)
190        }
191    }
192}
193
194fn parse_list_combining_expression(
195    ctx: &mut ParseContext,
196    node: &Node,
197) -> Result<Option<Expression>, FatalParseError> {
198    let Some(operator_node) = field!(recover, ctx, node, "operator") else {
199        return Ok(None);
200    };
201    let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
202
203    let Some(arg_node) = field!(recover, ctx, node, "arg") else {
204        return Ok(None);
205    };
206    // While parsing inner, the typechecking context is SetOrMatrix
207    // The inner context is either Boolean or Arithmetic so the elements of the set/matrix are typechecked correctly.
208    let Some(inner) = parse_atom(ctx, &arg_node)? else {
209        return Ok(None);
210    };
211
212    let expr = match operator_str {
213        "and" => Ok(Some(Expression::And(Metadata::new(), Moo::new(inner)))),
214        "or" => Ok(Some(Expression::Or(Metadata::new(), Moo::new(inner)))),
215        "sum" => Ok(Some(Expression::Sum(Metadata::new(), Moo::new(inner)))),
216        "product" => Ok(Some(Expression::Product(Metadata::new(), Moo::new(inner)))),
217        "min" => Ok(Some(Expression::Min(Metadata::new(), Moo::new(inner)))),
218        "max" => Ok(Some(Expression::Max(Metadata::new(), Moo::new(inner)))),
219        _ => {
220            ctx.record_error(RecoverableParseError::new(
221                format!("Invalid operator: '{operator_str}'"),
222                Some(operator_node.range()),
223            ));
224            Ok(None)
225        }
226    };
227
228    if expr.is_ok() {
229        ctx.add_span_and_doc_hover(
230            &operator_node,
231            operator_str,
232            SymbolKind::Function,
233            None,
234            None,
235        );
236    }
237
238    expr
239}
240
241fn parse_all_diff_comparison(
242    ctx: &mut ParseContext,
243    node: &Node,
244) -> Result<Option<Expression>, FatalParseError> {
245    let Some(arg_node) = field!(recover, ctx, node, "arg") else {
246        return Ok(None);
247    };
248    let Some(inner) = parse_expression(ctx, arg_node)? else {
249        return Ok(None);
250    };
251
252    let all_diff_keyword_node = child!(node, 0, "allDiff");
253    ctx.add_span_and_doc_hover(
254        &all_diff_keyword_node,
255        "allDiff",
256        SymbolKind::Function,
257        None,
258        None,
259    );
260    Ok(Some(Expression::AllDiff(Metadata::new(), Moo::new(inner))))
261}
262
263fn parse_unary_expression(
264    ctx: &mut ParseContext,
265    node: &Node,
266) -> Result<Option<Expression>, FatalParseError> {
267    let Some(expr_node) = field!(recover, ctx, node, "expression") else {
268        return Ok(None);
269    };
270    let Some(inner) = parse_expression(ctx, expr_node)? else {
271        return Ok(None);
272    };
273
274    match node.kind() {
275        "negative_expr" => Ok(Some(Expression::Neg(Metadata::new(), Moo::new(inner)))),
276        "abs_value" => Ok(Some(Expression::Abs(Metadata::new(), Moo::new(inner)))),
277        "not_expr" => Ok(Some(Expression::Not(Metadata::new(), Moo::new(inner)))),
278        "toInt_expr" => {
279            let to_int_keyword_node = child!(node, 0, "toInt");
280            ctx.add_span_and_doc_hover(
281                &to_int_keyword_node,
282                "toInt",
283                SymbolKind::Function,
284                None,
285                None,
286            );
287            Ok(Some(Expression::ToInt(Metadata::new(), Moo::new(inner))))
288        }
289        "factorial_expr" => {
290            // looking for the operator node (either '!' at the end or 'factorial' at the start) to add hover info
291            if let Some(op_node) = (0..node.child_count())
292                .filter_map(|i| node.child(i.try_into().unwrap()))
293                .find(|c| matches!(c.kind(), "!" | "factorial"))
294            {
295                ctx.add_span_and_doc_hover(
296                    &op_node,
297                    "post_factorial",
298                    SymbolKind::Function,
299                    None,
300                    None,
301                );
302            }
303
304            Ok(Some(Expression::Factorial(
305                Metadata::new(),
306                Moo::new(inner),
307            )))
308        }
309        "sub_bool_expr" | "sub_arith_expr" => Ok(Some(inner)),
310        _ => {
311            ctx.record_error(RecoverableParseError::new(
312                format!("Unrecognised unary operation: '{}'", node.kind()),
313                Some(node.range()),
314            ));
315            Ok(None)
316        }
317    }
318}
319
320pub fn parse_binary_expression(
321    ctx: &mut ParseContext,
322    node: &Node,
323) -> Result<Option<Expression>, FatalParseError> {
324    let Some(op_node) = field!(recover, ctx, node, "operator") else {
325        return Ok(None);
326    };
327    let op_str = &ctx.source_code[op_node.start_byte()..op_node.end_byte()];
328
329    let saved_ctx = ctx.typechecking_context;
330
331    // Special handling for 'in' operator, as the left operand doesn't have to be a set
332    if op_str == "in" {
333        ctx.typechecking_context = TypecheckingContext::Unknown
334    }
335
336    // parse left operand
337    let Some(left_node) = field!(recover, ctx, node, "left") else {
338        return Ok(None);
339    };
340    let Some(left) = parse_expression(ctx, left_node)? else {
341        return Ok(None);
342    };
343
344    // reset context, if needed
345    ctx.typechecking_context = saved_ctx;
346
347    // Equality/inequality: enforce right operand to match left operand type when inferable
348    if matches!(op_str, "=" | "!=") {
349        ctx.typechecking_context = inferred_context_from_expression(&left);
350    }
351
352    // parse right operand
353    let Some(right_node) = field!(recover, ctx, node, "right") else {
354        return Ok(None);
355    };
356    let Some(right) = parse_expression(ctx, right_node)? else {
357        return Ok(None);
358    };
359
360    // restore original contexts for parent expression parsing
361    ctx.typechecking_context = saved_ctx;
362
363    let mut doc_name = "";
364    let expr = match op_str {
365        // NB: We are deliberately setting the index domain to 1.., not 1..2.
366        // Semantically, this means "a list that can grow/shrink arbitrarily".
367        // This is expected by rules which will modify the terms of the sum expression
368        // (e.g. by partially evaluating them).
369        "+" => {
370            doc_name = "L_Plus";
371            Ok(Some(Expression::Sum(
372                Metadata::new(),
373                Moo::new(matrix_expr![left, right; domain_int!(1..)]),
374            )))
375        }
376        "-" => {
377            doc_name = "L_Minus";
378            Ok(Some(Expression::Minus(
379                Metadata::new(),
380                Moo::new(left),
381                Moo::new(right),
382            )))
383        }
384        "*" => {
385            doc_name = "L_Times";
386            Ok(Some(Expression::Product(
387                Metadata::new(),
388                Moo::new(matrix_expr![left, right; domain_int!(1..)]),
389            )))
390        }
391        "/\\" => {
392            doc_name = "and";
393            Ok(Some(Expression::And(
394                Metadata::new(),
395                Moo::new(matrix_expr![left, right; domain_int!(1..)]),
396            )))
397        }
398        "\\/" => {
399            // No documentation for or in Bits yet
400            doc_name = "or";
401            Ok(Some(Expression::Or(
402                Metadata::new(),
403                Moo::new(matrix_expr![left, right; domain_int!(1..)]),
404            )))
405        }
406        "**" => {
407            doc_name = "L_Pow";
408            Ok(Some(Expression::UnsafePow(
409                Metadata::new(),
410                Moo::new(left),
411                Moo::new(right),
412            )))
413        }
414        "/" => {
415            //TODO: add checks for if division is safe or not
416            doc_name = "L_Div";
417            Ok(Some(Expression::UnsafeDiv(
418                Metadata::new(),
419                Moo::new(left),
420                Moo::new(right),
421            )))
422        }
423        "%" => {
424            //TODO: add checks for if mod is safe or not
425            doc_name = "L_Mod";
426            Ok(Some(Expression::UnsafeMod(
427                Metadata::new(),
428                Moo::new(left),
429                Moo::new(right),
430            )))
431        }
432
433        "=" => {
434            doc_name = "L_Eq"; //no docs yet
435            Ok(Some(Expression::Eq(
436                Metadata::new(),
437                Moo::new(left),
438                Moo::new(right),
439            )))
440        }
441        "!=" => {
442            doc_name = "L_Neq"; //no docs yet
443            Ok(Some(Expression::Neq(
444                Metadata::new(),
445                Moo::new(left),
446                Moo::new(right),
447            )))
448        }
449        "<=" => {
450            doc_name = "L_Leq"; //no docs yet
451            Ok(Some(Expression::Leq(
452                Metadata::new(),
453                Moo::new(left),
454                Moo::new(right),
455            )))
456        }
457        ">=" => {
458            doc_name = "L_Geq"; //no docs yet
459            Ok(Some(Expression::Geq(
460                Metadata::new(),
461                Moo::new(left),
462                Moo::new(right),
463            )))
464        }
465        "<" => {
466            doc_name = "L_Lt"; //no docs yet
467            Ok(Some(Expression::Lt(
468                Metadata::new(),
469                Moo::new(left),
470                Moo::new(right),
471            )))
472        }
473        ">" => {
474            doc_name = "L_Gt"; //no docs yet
475            Ok(Some(Expression::Gt(
476                Metadata::new(),
477                Moo::new(left),
478                Moo::new(right),
479            )))
480        }
481
482        "->" => {
483            doc_name = "L_Imply"; //no docs yet
484            Ok(Some(Expression::Imply(
485                Metadata::new(),
486                Moo::new(left),
487                Moo::new(right),
488            )))
489        }
490        "<->" => {
491            doc_name = "L_Iff"; //no docs yet
492            Ok(Some(Expression::Iff(
493                Metadata::new(),
494                Moo::new(left),
495                Moo::new(right),
496            )))
497        }
498        "<lex" => {
499            doc_name = "L_LexLt"; //no docs yet
500            Ok(Some(Expression::LexLt(
501                Metadata::new(),
502                Moo::new(left),
503                Moo::new(right),
504            )))
505        }
506        ">lex" => {
507            doc_name = "L_LexGt"; //no docs yet
508            Ok(Some(Expression::LexGt(
509                Metadata::new(),
510                Moo::new(left),
511                Moo::new(right),
512            )))
513        }
514        "<=lex" => {
515            doc_name = "L_LexLeq"; //no docs yet
516            Ok(Some(Expression::LexLeq(
517                Metadata::new(),
518                Moo::new(left),
519                Moo::new(right),
520            )))
521        }
522        ">=lex" => {
523            doc_name = "L_LexGeq"; //no docs yet
524            Ok(Some(Expression::LexGeq(
525                Metadata::new(),
526                Moo::new(left),
527                Moo::new(right),
528            )))
529        }
530        "in" => {
531            doc_name = "L_in";
532            Ok(Some(Expression::In(
533                Metadata::new(),
534                Moo::new(left),
535                Moo::new(right),
536            )))
537        }
538        "subset" => {
539            doc_name = "L_subset";
540            Ok(Some(Expression::Subset(
541                Metadata::new(),
542                Moo::new(left),
543                Moo::new(right),
544            )))
545        }
546        "subsetEq" => {
547            doc_name = "L_subsetEq";
548            Ok(Some(Expression::SubsetEq(
549                Metadata::new(),
550                Moo::new(left),
551                Moo::new(right),
552            )))
553        }
554        "supset" => {
555            doc_name = "L_supset";
556            Ok(Some(Expression::Supset(
557                Metadata::new(),
558                Moo::new(left),
559                Moo::new(right),
560            )))
561        }
562        "supsetEq" => {
563            doc_name = "L_supsetEq";
564            Ok(Some(Expression::SupsetEq(
565                Metadata::new(),
566                Moo::new(left),
567                Moo::new(right),
568            )))
569        }
570        "union" => {
571            doc_name = "L_union";
572            Ok(Some(Expression::Union(
573                Metadata::new(),
574                Moo::new(left),
575                Moo::new(right),
576            )))
577        }
578        "intersect" => {
579            doc_name = "L_intersect";
580            Ok(Some(Expression::Intersect(
581                Metadata::new(),
582                Moo::new(left),
583                Moo::new(right),
584            )))
585        }
586        _ => {
587            ctx.record_error(RecoverableParseError::new(
588                format!("Invalid operator: '{op_str}'"),
589                Some(op_node.range()),
590            ));
591            Ok(None)
592        }
593    };
594
595    if expr.is_ok() {
596        ctx.add_span_and_doc_hover(&op_node, doc_name, SymbolKind::Function, None, None);
597    }
598
599    expr
600}
601
602fn inferred_context_from_expression(expr: &Expression) -> TypecheckingContext {
603    // TODO: typechecking for index/slice expressions
604    if matches!(
605        expr,
606        Expression::UnsafeIndex(_, _, _) | Expression::UnsafeSlice(_, _, _)
607    ) {
608        return TypecheckingContext::Unknown;
609    }
610
611    let Some(domain) = expr.domain_of() else {
612        return TypecheckingContext::Unknown;
613    };
614    let Some(ground) = domain.resolve() else {
615        return TypecheckingContext::Unknown;
616    };
617
618    match ground.as_ref() {
619        GroundDomain::Bool => TypecheckingContext::Boolean,
620        GroundDomain::Int(_) => TypecheckingContext::Arithmetic,
621        GroundDomain::Set(_, _) => TypecheckingContext::Set,
622        GroundDomain::MSet(_, _) => TypecheckingContext::MSet,
623        GroundDomain::Matrix(_, _) => TypecheckingContext::Matrix,
624        GroundDomain::Tuple(_) => TypecheckingContext::Tuple,
625        GroundDomain::Record(_) => TypecheckingContext::Record,
626        GroundDomain::Partition(_, _) => TypecheckingContext::Partition,
627        GroundDomain::Sequence(_, _) => TypecheckingContext::Sequence,
628        GroundDomain::Function(_, _, _)
629        | GroundDomain::Variant(_)
630        | GroundDomain::Relation(_, _)
631        | GroundDomain::Empty(_) => TypecheckingContext::Unknown,
632    }
633}