Skip to main content

conjure_cp_essence_parser/parser/
expression.rs

1use crate::diagnostics::diagnostics_api::SymbolKind;
2use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3use crate::errors::FatalParseError;
4use crate::parser::ParseContext;
5use crate::parser::atom::parse_atom;
6use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
7use crate::util::TypecheckingContext;
8use crate::{field, named_child};
9use conjure_cp_core::ast::{Expression, Metadata, Moo};
10use conjure_cp_core::{domain_int, matrix_expr, range};
11use tree_sitter::Node;
12
13pub fn parse_expression(
14    ctx: &mut ParseContext,
15    node: Node,
16) -> Result<Option<Expression>, FatalParseError> {
17    match node.kind() {
18        "atom" => parse_atom(ctx, &node),
19        "bool_expr" => parse_boolean_expression(ctx, &node),
20        "arithmetic_expr" => parse_arithmetic_expression(ctx, &node),
21        "comparison_expr" => parse_comparison_expression(ctx, &node),
22        "dominance_relation" => parse_dominance_relation(ctx, &node),
23        "all_diff_comparison" => parse_all_diff_comparison(ctx, &node),
24        _ => Err(FatalParseError::internal_error(
25            format!("Unexpected expression type: '{}'", node.kind()),
26            Some(node.range()),
27        )),
28    }
29}
30
31fn parse_dominance_relation(
32    ctx: &mut ParseContext,
33    node: &Node,
34) -> Result<Option<Expression>, FatalParseError> {
35    if ctx.root.kind() == "dominance_relation" {
36        return Err(FatalParseError::internal_error(
37            "Nested dominance relations are not allowed".to_string(),
38            Some(node.range()),
39        ));
40    }
41
42    // NB: In all other cases, we keep the root the same;
43    // However, here we create a new context with the new root so downstream functions
44    // know we are inside a dominance relation
45    let mut inner_ctx = ParseContext {
46        source_code: ctx.source_code,
47        root: node,
48        symbols: ctx.symbols.clone(),
49        errors: ctx.errors,
50        source_map: &mut *ctx.source_map,
51        typechecking_context: ctx.typechecking_context,
52    };
53
54    let Some(inner) = parse_expression(&mut inner_ctx, field!(node, "expression"))? else {
55        return Ok(None);
56    };
57
58    Ok(Some(Expression::DominanceRelation(
59        Metadata::new(),
60        Moo::new(inner),
61    )))
62}
63
64fn parse_arithmetic_expression(
65    ctx: &mut ParseContext,
66    node: &Node,
67) -> Result<Option<Expression>, FatalParseError> {
68    ctx.typechecking_context = TypecheckingContext::Arithmetic;
69    let inner = named_child!(node);
70    match inner.kind() {
71        "atom" => parse_atom(ctx, &inner),
72        "negative_expr" | "abs_value" | "sub_arith_expr" => parse_unary_expression(ctx, &inner),
73        "toInt_expr" => {
74            // add special handling for toInt, as it is arithmetic but takes a non-arithmetic operand
75            ctx.typechecking_context = TypecheckingContext::Unknown;
76            parse_unary_expression(ctx, &inner)
77        }
78        "exponent" | "product_expr" | "sum_expr" => parse_binary_expression(ctx, &inner),
79        "list_combining_expr_arith" => parse_list_combining_expression(ctx, &inner),
80        "aggregate_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
81        _ => Err(FatalParseError::internal_error(
82            format!("Expected arithmetic expression, found: {}", inner.kind()),
83            Some(inner.range()),
84        )),
85    }
86}
87
88fn parse_comparison_expression(
89    ctx: &mut ParseContext,
90    node: &Node,
91) -> Result<Option<Expression>, FatalParseError> {
92    let inner = named_child!(node);
93    match inner.kind() {
94        "arithmetic_comparison" => {
95            // Arithmetic comparisons require arithmetic operands
96            ctx.typechecking_context = TypecheckingContext::Arithmetic;
97            parse_binary_expression(ctx, &inner)
98        }
99        "lex_comparison" => {
100            // TODO: check that both operands are comparable collections.
101            ctx.typechecking_context = TypecheckingContext::Unknown;
102            parse_binary_expression(ctx, &inner)
103        }
104        "equality_comparison" => {
105            // Equality works on any type
106            // TODO: add type checking to ensure both sides have the same type
107            ctx.typechecking_context = TypecheckingContext::Unknown;
108            parse_binary_expression(ctx, &inner)
109        }
110        "set_comparison" => {
111            // Set comparisons require set operands (no specific type checking for now)
112            // TODO: add typechecking for sets
113            ctx.typechecking_context = TypecheckingContext::Unknown;
114            parse_binary_expression(ctx, &inner)
115        }
116        "all_diff_comparison" => {
117            // TODO: check that operand is a collection with compatible element type.
118            ctx.typechecking_context = TypecheckingContext::Unknown;
119            parse_all_diff_comparison(ctx, &inner)
120        }
121        _ => Err(FatalParseError::internal_error(
122            format!("Expected comparison expression, found '{}'", inner.kind()),
123            Some(inner.range()),
124        )),
125    }
126}
127
128fn parse_boolean_expression(
129    ctx: &mut ParseContext,
130    node: &Node,
131) -> Result<Option<Expression>, FatalParseError> {
132    ctx.typechecking_context = TypecheckingContext::Boolean;
133    let inner = named_child!(node);
134    match inner.kind() {
135        "atom" => parse_atom(ctx, &inner),
136        "not_expr" | "sub_bool_expr" => parse_unary_expression(ctx, &inner),
137        "and_expr" | "or_expr" | "implication" | "iff_expr" => parse_binary_expression(ctx, &inner),
138        "list_combining_expr_bool" => parse_list_combining_expression(ctx, &inner),
139        "quantifier_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
140        _ => Err(FatalParseError::internal_error(
141            format!("Expected boolean expression, found '{}'", inner.kind()),
142            Some(inner.range()),
143        )),
144    }
145}
146
147fn parse_list_combining_expression(
148    ctx: &mut ParseContext,
149    node: &Node,
150) -> Result<Option<Expression>, FatalParseError> {
151    let operator_node = field!(node, "operator");
152    let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
153
154    let Some(inner) = parse_atom(ctx, &field!(node, "arg"))? else {
155        return Ok(None);
156    };
157
158    match operator_str {
159        "and" => Ok(Some(Expression::And(Metadata::new(), Moo::new(inner)))),
160        "or" => Ok(Some(Expression::Or(Metadata::new(), Moo::new(inner)))),
161        "sum" => Ok(Some(Expression::Sum(Metadata::new(), Moo::new(inner)))),
162        "product" => Ok(Some(Expression::Product(Metadata::new(), Moo::new(inner)))),
163        "min" => Ok(Some(Expression::Min(Metadata::new(), Moo::new(inner)))),
164        "max" => Ok(Some(Expression::Max(Metadata::new(), Moo::new(inner)))),
165        _ => Err(FatalParseError::internal_error(
166            format!("Invalid operator: '{operator_str}'"),
167            Some(operator_node.range()),
168        )),
169    }
170}
171
172fn parse_all_diff_comparison(
173    ctx: &mut ParseContext,
174    node: &Node,
175) -> Result<Option<Expression>, FatalParseError> {
176    let Some(inner) = parse_expression(ctx, field!(node, "arg"))? else {
177        return Ok(None);
178    };
179
180    Ok(Some(Expression::AllDiff(Metadata::new(), Moo::new(inner))))
181}
182
183fn parse_unary_expression(
184    ctx: &mut ParseContext,
185    node: &Node,
186) -> Result<Option<Expression>, FatalParseError> {
187    let Some(inner) = parse_expression(ctx, field!(node, "expression"))? else {
188        return Ok(None);
189    };
190    match node.kind() {
191        "negative_expr" => Ok(Some(Expression::Neg(Metadata::new(), Moo::new(inner)))),
192        "abs_value" => Ok(Some(Expression::Abs(Metadata::new(), Moo::new(inner)))),
193        "not_expr" => Ok(Some(Expression::Not(Metadata::new(), Moo::new(inner)))),
194        "toInt_expr" => Ok(Some(Expression::ToInt(Metadata::new(), Moo::new(inner)))),
195        "sub_bool_expr" | "sub_arith_expr" => Ok(Some(inner)),
196        _ => Err(FatalParseError::internal_error(
197            format!("Unrecognised unary operation: '{}'", node.kind()),
198            Some(node.range()),
199        )),
200    }
201}
202
203pub fn parse_binary_expression(
204    ctx: &mut ParseContext,
205    node: &Node,
206) -> Result<Option<Expression>, FatalParseError> {
207    let mut parse_subexpr = |expr: Node| parse_expression(ctx, expr);
208
209    let Some(left) = parse_subexpr(field!(node, "left"))? else {
210        return Ok(None);
211    };
212    let Some(right) = parse_subexpr(field!(node, "right"))? else {
213        return Ok(None);
214    };
215
216    let op_node = field!(node, "operator");
217    let op_str = &ctx.source_code[op_node.start_byte()..op_node.end_byte()];
218
219    let mut description = format!("Operator '{op_str}'");
220    let expr = match op_str {
221        // NB: We are deliberately setting the index domain to 1.., not 1..2.
222        // Semantically, this means "a list that can grow/shrink arbitrarily".
223        // This is expected by rules which will modify the terms of the sum expression
224        // (e.g. by partially evaluating them).
225        "+" => Ok(Some(Expression::Sum(
226            Metadata::new(),
227            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
228        ))),
229        "-" => Ok(Some(Expression::Minus(
230            Metadata::new(),
231            Moo::new(left),
232            Moo::new(right),
233        ))),
234        "*" => Ok(Some(Expression::Product(
235            Metadata::new(),
236            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
237        ))),
238        "/\\" => Ok(Some(Expression::And(
239            Metadata::new(),
240            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
241        ))),
242        "\\/" => Ok(Some(Expression::Or(
243            Metadata::new(),
244            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
245        ))),
246        "**" => Ok(Some(Expression::UnsafePow(
247            Metadata::new(),
248            Moo::new(left),
249            Moo::new(right),
250        ))),
251        "/" => {
252            //TODO: add checks for if division is safe or not
253            Ok(Some(Expression::UnsafeDiv(
254                Metadata::new(),
255                Moo::new(left),
256                Moo::new(right),
257            )))
258        }
259        "%" => {
260            //TODO: add checks for if mod is safe or not
261            Ok(Some(Expression::UnsafeMod(
262                Metadata::new(),
263                Moo::new(left),
264                Moo::new(right),
265            )))
266        }
267        "=" => Ok(Some(Expression::Eq(
268            Metadata::new(),
269            Moo::new(left),
270            Moo::new(right),
271        ))),
272        "!=" => Ok(Some(Expression::Neq(
273            Metadata::new(),
274            Moo::new(left),
275            Moo::new(right),
276        ))),
277        "<=" => Ok(Some(Expression::Leq(
278            Metadata::new(),
279            Moo::new(left),
280            Moo::new(right),
281        ))),
282        ">=" => Ok(Some(Expression::Geq(
283            Metadata::new(),
284            Moo::new(left),
285            Moo::new(right),
286        ))),
287        "<" => Ok(Some(Expression::Lt(
288            Metadata::new(),
289            Moo::new(left),
290            Moo::new(right),
291        ))),
292        ">" => Ok(Some(Expression::Gt(
293            Metadata::new(),
294            Moo::new(left),
295            Moo::new(right),
296        ))),
297        "->" => Ok(Some(Expression::Imply(
298            Metadata::new(),
299            Moo::new(left),
300            Moo::new(right),
301        ))),
302        "<->" => Ok(Some(Expression::Iff(
303            Metadata::new(),
304            Moo::new(left),
305            Moo::new(right),
306        ))),
307        "<lex" => Ok(Some(Expression::LexLt(
308            Metadata::new(),
309            Moo::new(left),
310            Moo::new(right),
311        ))),
312        ">lex" => Ok(Some(Expression::LexGt(
313            Metadata::new(),
314            Moo::new(left),
315            Moo::new(right),
316        ))),
317        "<=lex" => Ok(Some(Expression::LexLeq(
318            Metadata::new(),
319            Moo::new(left),
320            Moo::new(right),
321        ))),
322        ">=lex" => Ok(Some(Expression::LexGeq(
323            Metadata::new(),
324            Moo::new(left),
325            Moo::new(right),
326        ))),
327        "in" => Ok(Some(Expression::In(
328            Metadata::new(),
329            Moo::new(left),
330            Moo::new(right),
331        ))),
332        "subset" => Ok(Some(Expression::Subset(
333            Metadata::new(),
334            Moo::new(left),
335            Moo::new(right),
336        ))),
337        "subsetEq" => Ok(Some(Expression::SubsetEq(
338            Metadata::new(),
339            Moo::new(left),
340            Moo::new(right),
341        ))),
342        "supset" => Ok(Some(Expression::Supset(
343            Metadata::new(),
344            Moo::new(left),
345            Moo::new(right),
346        ))),
347        "supsetEq" => Ok(Some(Expression::SupsetEq(
348            Metadata::new(),
349            Moo::new(left),
350            Moo::new(right),
351        ))),
352        "union" => {
353            description = "set union: combines the elements from both operands".to_string();
354            Ok(Some(Expression::Union(
355                Metadata::new(),
356                Moo::new(left),
357                Moo::new(right),
358            )))
359        }
360        "intersect" => {
361            description =
362                "set intersection: keeps only elements common to both operands".to_string();
363            Ok(Some(Expression::Intersect(
364                Metadata::new(),
365                Moo::new(left),
366                Moo::new(right),
367            )))
368        }
369        _ => Err(FatalParseError::internal_error(
370            format!("Invalid operator: '{op_str}'"),
371            Some(op_node.range()),
372        )),
373    };
374
375    if expr.is_ok() {
376        let hover = HoverInfo {
377            description,
378            kind: Some(SymbolKind::Function),
379            ty: None,
380            decl_span: None,
381        };
382        span_with_hover(&op_node, ctx.source_code, ctx.source_map, hover);
383    }
384
385    expr
386}