1use crate::diagnostics::diagnostics_api::SymbolKind;
2use crate::errors::{FatalParseError, RecoverableParseError};
3use crate::parser::ParseContext;
4use crate::parser::atom::parse_atom;
5use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
6use crate::util::TypecheckingContext;
7use crate::{child, field, named_child};
8use conjure_cp_core::ast::{Expression, GroundDomain, Metadata, Moo};
9use conjure_cp_core::{domain_int, matrix_expr, range};
10use tree_sitter::Node;
11
12pub fn parse_expression(
13 ctx: &mut ParseContext,
14 node: Node,
15) -> Result<Option<Expression>, FatalParseError> {
16 match node.kind() {
17 "atom" => parse_atom(ctx, &node),
18 "bool_expr" => {
19 if ctx.typechecking_context == TypecheckingContext::Arithmetic {
20 ctx.record_error(RecoverableParseError::new(
21 format!(
22 "Type error: {}\n\tExepected: int\n\tGot: boolean expression",
23 &ctx.source_code[node.start_byte()..node.end_byte()]
24 ),
25 Some(node.range()),
26 ));
27 return Ok(None);
28 }
29 parse_boolean_expression(ctx, &node)
30 }
31 "arithmetic_expr" => {
32 if ctx.typechecking_context == TypecheckingContext::Boolean {
33 ctx.record_error(RecoverableParseError::new(
34 format!(
35 "Type error: {}\n\tExepected: bool\n\tGot: arithmetic expression",
36 &ctx.source_code[node.start_byte()..node.end_byte()]
37 ),
38 Some(node.range()),
39 ));
40 return Ok(None);
41 }
42 parse_arithmetic_expression(ctx, &node)
43 }
44 "comparison_expr" => {
45 if ctx.typechecking_context == TypecheckingContext::Arithmetic {
46 ctx.record_error(RecoverableParseError::new(
47 format!(
48 "Type error: {}\n\tExepected: int\n\tGot: comparison expression",
49 &ctx.source_code[node.start_byte()..node.end_byte()]
50 ),
51 Some(node.range()),
52 ));
53 return Ok(None);
54 }
55 parse_comparison_expression(ctx, &node)
56 }
57 "all_diff_comparison" => {
58 if ctx.typechecking_context == TypecheckingContext::Arithmetic {
59 ctx.record_error(RecoverableParseError::new(
60 format!("Type error: {}\n\tExepected: arithmetic expression\n\tFound: comparison expression", &ctx.source_code[node.start_byte()..node.end_byte()]),
61 Some(node.range()),
62 ));
63 return Ok(None);
64 }
65 ctx.typechecking_context = TypecheckingContext::Matrix;
66 parse_all_diff_comparison(ctx, &node)
67 }
68 _ => {
69 ctx.record_error(RecoverableParseError::new(
70 format!("Unexpected expression type: '{}'", node.kind()),
71 Some(node.range()),
72 ));
73 Ok(None)
74 }
75 }
76}
77
78fn parse_arithmetic_expression(
79 ctx: &mut ParseContext,
80 node: &Node,
81) -> Result<Option<Expression>, FatalParseError> {
82 ctx.typechecking_context = TypecheckingContext::Arithmetic;
83 ctx.inner_typechecking_context = TypecheckingContext::Unknown;
84 let Some(inner) = named_child!(recover, ctx, node) else {
85 return Ok(None);
86 };
87 match inner.kind() {
88 "atom" => parse_atom(ctx, &inner),
89 "negative_expr" | "abs_value" | "sub_arith_expr" | "factorial_expr" => {
90 parse_unary_expression(ctx, &inner)
91 }
92 "toInt_expr" => {
93 ctx.typechecking_context = TypecheckingContext::Unknown;
95 parse_unary_expression(ctx, &inner)
96 }
97 "exponent" | "product_expr" | "sum_expr" => parse_binary_expression(ctx, &inner),
98 "list_combining_expr_arith" => {
99 ctx.typechecking_context = TypecheckingContext::SetOrMatrix;
101
102 ctx.inner_typechecking_context = TypecheckingContext::Arithmetic;
104 parse_list_combining_expression(ctx, &inner)
105 }
106 "aggregate_expr" => {
107 ctx.inner_typechecking_context = TypecheckingContext::Arithmetic;
108 parse_quantifier_or_aggregate_expr(ctx, &inner)
109 }
110 _ => {
111 ctx.record_error(RecoverableParseError::new(
112 format!("Expected arithmetic expression, found: {}", inner.kind()),
113 Some(inner.range()),
114 ));
115 Ok(None)
116 }
117 }
118}
119
120fn parse_comparison_expression(
121 ctx: &mut ParseContext,
122 node: &Node,
123) -> Result<Option<Expression>, FatalParseError> {
124 let Some(inner) = named_child!(recover, ctx, node) else {
125 return Ok(None);
126 };
127 match inner.kind() {
128 "arithmetic_comparison" => {
129 ctx.typechecking_context = TypecheckingContext::Arithmetic;
131 parse_binary_expression(ctx, &inner)
132 }
133 "lex_comparison" => {
134 ctx.typechecking_context = TypecheckingContext::Unknown;
136 parse_binary_expression(ctx, &inner)
137 }
138 "equality_comparison" => {
139 ctx.typechecking_context = TypecheckingContext::Unknown;
141 parse_binary_expression(ctx, &inner)
142 }
143 "set_comparison" => {
144 ctx.typechecking_context = TypecheckingContext::Set;
146 parse_binary_expression(ctx, &inner)
147 }
148 "all_diff_comparison" => {
149 ctx.typechecking_context = TypecheckingContext::Matrix;
150 parse_all_diff_comparison(ctx, &inner)
151 }
152 _ => {
153 ctx.record_error(RecoverableParseError::new(
154 format!("Expected comparison expression, found '{}'", inner.kind()),
155 Some(inner.range()),
156 ));
157 Ok(None)
158 }
159 }
160}
161
162fn parse_boolean_expression(
163 ctx: &mut ParseContext,
164 node: &Node,
165) -> Result<Option<Expression>, FatalParseError> {
166 ctx.typechecking_context = TypecheckingContext::Boolean;
167 ctx.inner_typechecking_context = TypecheckingContext::Unknown;
168 let Some(inner) = named_child!(recover, ctx, node) else {
169 return Ok(None);
170 };
171 match inner.kind() {
172 "atom" => parse_atom(ctx, &inner),
173 "not_expr" | "sub_bool_expr" => parse_unary_expression(ctx, &inner),
174 "and_expr" | "or_expr" | "implication" | "iff_expr" => parse_binary_expression(ctx, &inner),
175 "list_combining_expr_bool" => {
176 ctx.typechecking_context = TypecheckingContext::SetOrMatrix;
178
179 ctx.inner_typechecking_context = TypecheckingContext::Boolean;
181 parse_list_combining_expression(ctx, &inner)
182 }
183 "quantifier_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
184 _ => {
185 ctx.record_error(RecoverableParseError::new(
186 format!("Expected boolean expression, found '{}'", inner.kind()),
187 Some(inner.range()),
188 ));
189 Ok(None)
190 }
191 }
192}
193
194fn parse_list_combining_expression(
195 ctx: &mut ParseContext,
196 node: &Node,
197) -> Result<Option<Expression>, FatalParseError> {
198 let Some(operator_node) = field!(recover, ctx, node, "operator") else {
199 return Ok(None);
200 };
201 let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
202
203 let Some(arg_node) = field!(recover, ctx, node, "arg") else {
204 return Ok(None);
205 };
206 let Some(inner) = parse_atom(ctx, &arg_node)? else {
209 return Ok(None);
210 };
211
212 let expr = match operator_str {
213 "and" => Ok(Some(Expression::And(Metadata::new(), Moo::new(inner)))),
214 "or" => Ok(Some(Expression::Or(Metadata::new(), Moo::new(inner)))),
215 "sum" => Ok(Some(Expression::Sum(Metadata::new(), Moo::new(inner)))),
216 "product" => Ok(Some(Expression::Product(Metadata::new(), Moo::new(inner)))),
217 "min" => Ok(Some(Expression::Min(Metadata::new(), Moo::new(inner)))),
218 "max" => Ok(Some(Expression::Max(Metadata::new(), Moo::new(inner)))),
219 _ => {
220 ctx.record_error(RecoverableParseError::new(
221 format!("Invalid operator: '{operator_str}'"),
222 Some(operator_node.range()),
223 ));
224 Ok(None)
225 }
226 };
227
228 if expr.is_ok() {
229 ctx.add_span_and_doc_hover(
230 &operator_node,
231 operator_str,
232 SymbolKind::Function,
233 None,
234 None,
235 );
236 }
237
238 expr
239}
240
241fn parse_all_diff_comparison(
242 ctx: &mut ParseContext,
243 node: &Node,
244) -> Result<Option<Expression>, FatalParseError> {
245 let Some(arg_node) = field!(recover, ctx, node, "arg") else {
246 return Ok(None);
247 };
248 let Some(inner) = parse_expression(ctx, arg_node)? else {
249 return Ok(None);
250 };
251
252 let all_diff_keyword_node = child!(node, 0, "allDiff");
253 ctx.add_span_and_doc_hover(
254 &all_diff_keyword_node,
255 "allDiff",
256 SymbolKind::Function,
257 None,
258 None,
259 );
260 Ok(Some(Expression::AllDiff(Metadata::new(), Moo::new(inner))))
261}
262
263fn parse_unary_expression(
264 ctx: &mut ParseContext,
265 node: &Node,
266) -> Result<Option<Expression>, FatalParseError> {
267 let Some(expr_node) = field!(recover, ctx, node, "expression") else {
268 return Ok(None);
269 };
270 let Some(inner) = parse_expression(ctx, expr_node)? else {
271 return Ok(None);
272 };
273
274 match node.kind() {
275 "negative_expr" => Ok(Some(Expression::Neg(Metadata::new(), Moo::new(inner)))),
276 "abs_value" => Ok(Some(Expression::Abs(Metadata::new(), Moo::new(inner)))),
277 "not_expr" => Ok(Some(Expression::Not(Metadata::new(), Moo::new(inner)))),
278 "toInt_expr" => {
279 let to_int_keyword_node = child!(node, 0, "toInt");
280 ctx.add_span_and_doc_hover(
281 &to_int_keyword_node,
282 "toInt",
283 SymbolKind::Function,
284 None,
285 None,
286 );
287 Ok(Some(Expression::ToInt(Metadata::new(), Moo::new(inner))))
288 }
289 "factorial_expr" => {
290 if let Some(op_node) = (0..node.child_count())
292 .filter_map(|i| node.child(i.try_into().unwrap()))
293 .find(|c| matches!(c.kind(), "!" | "factorial"))
294 {
295 ctx.add_span_and_doc_hover(
296 &op_node,
297 "post_factorial",
298 SymbolKind::Function,
299 None,
300 None,
301 );
302 }
303
304 Ok(Some(Expression::Factorial(
305 Metadata::new(),
306 Moo::new(inner),
307 )))
308 }
309 "sub_bool_expr" | "sub_arith_expr" => Ok(Some(inner)),
310 _ => {
311 ctx.record_error(RecoverableParseError::new(
312 format!("Unrecognised unary operation: '{}'", node.kind()),
313 Some(node.range()),
314 ));
315 Ok(None)
316 }
317 }
318}
319
320pub fn parse_binary_expression(
321 ctx: &mut ParseContext,
322 node: &Node,
323) -> Result<Option<Expression>, FatalParseError> {
324 let Some(op_node) = field!(recover, ctx, node, "operator") else {
325 return Ok(None);
326 };
327 let op_str = &ctx.source_code[op_node.start_byte()..op_node.end_byte()];
328
329 let saved_ctx = ctx.typechecking_context;
330
331 if op_str == "in" {
333 ctx.typechecking_context = TypecheckingContext::Unknown
334 }
335
336 let Some(left_node) = field!(recover, ctx, node, "left") else {
338 return Ok(None);
339 };
340 let Some(left) = parse_expression(ctx, left_node)? else {
341 return Ok(None);
342 };
343
344 ctx.typechecking_context = saved_ctx;
346
347 if matches!(op_str, "=" | "!=") {
349 ctx.typechecking_context = inferred_context_from_expression(&left);
350 }
351
352 let Some(right_node) = field!(recover, ctx, node, "right") else {
354 return Ok(None);
355 };
356 let Some(right) = parse_expression(ctx, right_node)? else {
357 return Ok(None);
358 };
359
360 ctx.typechecking_context = saved_ctx;
362
363 let mut doc_name = "";
364 let expr = match op_str {
365 "+" => {
370 doc_name = "L_Plus";
371 Ok(Some(Expression::Sum(
372 Metadata::new(),
373 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
374 )))
375 }
376 "-" => {
377 doc_name = "L_Minus";
378 Ok(Some(Expression::Minus(
379 Metadata::new(),
380 Moo::new(left),
381 Moo::new(right),
382 )))
383 }
384 "*" => {
385 doc_name = "L_Times";
386 Ok(Some(Expression::Product(
387 Metadata::new(),
388 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
389 )))
390 }
391 "/\\" => {
392 doc_name = "and";
393 Ok(Some(Expression::And(
394 Metadata::new(),
395 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
396 )))
397 }
398 "\\/" => {
399 doc_name = "or";
401 Ok(Some(Expression::Or(
402 Metadata::new(),
403 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
404 )))
405 }
406 "**" => {
407 doc_name = "L_Pow";
408 Ok(Some(Expression::UnsafePow(
409 Metadata::new(),
410 Moo::new(left),
411 Moo::new(right),
412 )))
413 }
414 "/" => {
415 doc_name = "L_Div";
417 Ok(Some(Expression::UnsafeDiv(
418 Metadata::new(),
419 Moo::new(left),
420 Moo::new(right),
421 )))
422 }
423 "%" => {
424 doc_name = "L_Mod";
426 Ok(Some(Expression::UnsafeMod(
427 Metadata::new(),
428 Moo::new(left),
429 Moo::new(right),
430 )))
431 }
432
433 "=" => {
434 doc_name = "L_Eq"; Ok(Some(Expression::Eq(
436 Metadata::new(),
437 Moo::new(left),
438 Moo::new(right),
439 )))
440 }
441 "!=" => {
442 doc_name = "L_Neq"; Ok(Some(Expression::Neq(
444 Metadata::new(),
445 Moo::new(left),
446 Moo::new(right),
447 )))
448 }
449 "<=" => {
450 doc_name = "L_Leq"; Ok(Some(Expression::Leq(
452 Metadata::new(),
453 Moo::new(left),
454 Moo::new(right),
455 )))
456 }
457 ">=" => {
458 doc_name = "L_Geq"; Ok(Some(Expression::Geq(
460 Metadata::new(),
461 Moo::new(left),
462 Moo::new(right),
463 )))
464 }
465 "<" => {
466 doc_name = "L_Lt"; Ok(Some(Expression::Lt(
468 Metadata::new(),
469 Moo::new(left),
470 Moo::new(right),
471 )))
472 }
473 ">" => {
474 doc_name = "L_Gt"; Ok(Some(Expression::Gt(
476 Metadata::new(),
477 Moo::new(left),
478 Moo::new(right),
479 )))
480 }
481
482 "->" => {
483 doc_name = "L_Imply"; Ok(Some(Expression::Imply(
485 Metadata::new(),
486 Moo::new(left),
487 Moo::new(right),
488 )))
489 }
490 "<->" => {
491 doc_name = "L_Iff"; Ok(Some(Expression::Iff(
493 Metadata::new(),
494 Moo::new(left),
495 Moo::new(right),
496 )))
497 }
498 "<lex" => {
499 doc_name = "L_LexLt"; Ok(Some(Expression::LexLt(
501 Metadata::new(),
502 Moo::new(left),
503 Moo::new(right),
504 )))
505 }
506 ">lex" => {
507 doc_name = "L_LexGt"; Ok(Some(Expression::LexGt(
509 Metadata::new(),
510 Moo::new(left),
511 Moo::new(right),
512 )))
513 }
514 "<=lex" => {
515 doc_name = "L_LexLeq"; Ok(Some(Expression::LexLeq(
517 Metadata::new(),
518 Moo::new(left),
519 Moo::new(right),
520 )))
521 }
522 ">=lex" => {
523 doc_name = "L_LexGeq"; Ok(Some(Expression::LexGeq(
525 Metadata::new(),
526 Moo::new(left),
527 Moo::new(right),
528 )))
529 }
530 "in" => {
531 doc_name = "L_in";
532 Ok(Some(Expression::In(
533 Metadata::new(),
534 Moo::new(left),
535 Moo::new(right),
536 )))
537 }
538 "subset" => {
539 doc_name = "L_subset";
540 Ok(Some(Expression::Subset(
541 Metadata::new(),
542 Moo::new(left),
543 Moo::new(right),
544 )))
545 }
546 "subsetEq" => {
547 doc_name = "L_subsetEq";
548 Ok(Some(Expression::SubsetEq(
549 Metadata::new(),
550 Moo::new(left),
551 Moo::new(right),
552 )))
553 }
554 "supset" => {
555 doc_name = "L_supset";
556 Ok(Some(Expression::Supset(
557 Metadata::new(),
558 Moo::new(left),
559 Moo::new(right),
560 )))
561 }
562 "supsetEq" => {
563 doc_name = "L_supsetEq";
564 Ok(Some(Expression::SupsetEq(
565 Metadata::new(),
566 Moo::new(left),
567 Moo::new(right),
568 )))
569 }
570 "union" => {
571 doc_name = "L_union";
572 Ok(Some(Expression::Union(
573 Metadata::new(),
574 Moo::new(left),
575 Moo::new(right),
576 )))
577 }
578 "intersect" => {
579 doc_name = "L_intersect";
580 Ok(Some(Expression::Intersect(
581 Metadata::new(),
582 Moo::new(left),
583 Moo::new(right),
584 )))
585 }
586 _ => {
587 ctx.record_error(RecoverableParseError::new(
588 format!("Invalid operator: '{op_str}'"),
589 Some(op_node.range()),
590 ));
591 Ok(None)
592 }
593 };
594
595 if expr.is_ok() {
596 ctx.add_span_and_doc_hover(&op_node, doc_name, SymbolKind::Function, None, None);
597 }
598
599 expr
600}
601
602fn inferred_context_from_expression(expr: &Expression) -> TypecheckingContext {
603 if matches!(
605 expr,
606 Expression::UnsafeIndex(_, _, _) | Expression::UnsafeSlice(_, _, _)
607 ) {
608 return TypecheckingContext::Unknown;
609 }
610
611 let Some(domain) = expr.domain_of() else {
612 return TypecheckingContext::Unknown;
613 };
614 let Some(ground) = domain.resolve() else {
615 return TypecheckingContext::Unknown;
616 };
617
618 match ground.as_ref() {
619 GroundDomain::Bool => TypecheckingContext::Boolean,
620 GroundDomain::Int(_) => TypecheckingContext::Arithmetic,
621 GroundDomain::Set(_, _) => TypecheckingContext::Set,
622 GroundDomain::MSet(_, _) => TypecheckingContext::MSet,
623 GroundDomain::Matrix(_, _) => TypecheckingContext::Matrix,
624 GroundDomain::Tuple(_) => TypecheckingContext::Tuple,
625 GroundDomain::Record(_) => TypecheckingContext::Record,
626 GroundDomain::Partition(_, _) => TypecheckingContext::Partition,
627 GroundDomain::Sequence(_, _) => TypecheckingContext::Sequence,
628 GroundDomain::Function(_, _, _)
629 | GroundDomain::Variant(_)
630 | GroundDomain::Relation(_, _)
631 | GroundDomain::Empty(_) => TypecheckingContext::Unknown,
632 }
633}