Skip to main content

conjure_cp_essence_parser/parser/
expression.rs

1use crate::errors::{FatalParseError, RecoverableParseError};
2use crate::parser::atom::parse_atom;
3use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
4use crate::{field, named_child};
5use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTablePtr};
6use conjure_cp_core::{domain_int, matrix_expr, range};
7use tree_sitter::Node;
8
9/// Parse an Essence expression into its Conjure AST representation.
10pub fn parse_expression(
11    node: Node,
12    source_code: &str,
13    root: &Node,
14    symbols_ptr: Option<SymbolTablePtr>,
15    errors: &mut Vec<RecoverableParseError>,
16) -> Result<Expression, FatalParseError> {
17    match node.kind() {
18        "atom" => parse_atom(&node, source_code, root, symbols_ptr, errors),
19        "bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr, errors),
20        "arithmetic_expr" => {
21            parse_arithmetic_expression(&node, source_code, root, symbols_ptr, errors)
22        }
23        "comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr, errors),
24        "dominance_relation" => {
25            parse_dominance_relation(&node, source_code, root, symbols_ptr, errors)
26        }
27        "ERROR" => {
28            errors.push(RecoverableParseError::new(
29                format!(
30                    "'{}' is not a valid expression",
31                    &source_code[node.start_byte()..node.end_byte()]
32                ),
33                Some(node.range()),
34            ));
35            // Return a placeholder - actual error is in the errors vector
36            // TODO: figure out how to return when recoverable error is found
37            Ok(Expression::Atomic(
38                Metadata::new(),
39                conjure_cp_core::ast::Atom::Literal(conjure_cp_core::ast::Literal::Bool(false)),
40            ))
41        }
42        _ => {
43            errors.push(RecoverableParseError::new(
44                format!("Unknown expression kind: '{}'", node.kind()),
45                Some(node.range()),
46            ));
47            // Return a placeholder
48            Ok(Expression::Atomic(
49                Metadata::new(),
50                conjure_cp_core::ast::Atom::Literal(conjure_cp_core::ast::Literal::Bool(false)),
51            ))
52        }
53    }
54}
55
56fn parse_dominance_relation(
57    node: &Node,
58    source_code: &str,
59    root: &Node,
60    symbols_ptr: Option<SymbolTablePtr>,
61    errors: &mut Vec<RecoverableParseError>,
62) -> Result<Expression, FatalParseError> {
63    if root.kind() == "dominance_relation" {
64        return Err(FatalParseError::syntax_error(
65            "Nested dominance relations are not allowed".to_string(),
66            Some(node.range()),
67        ));
68    }
69
70    // NB: In all other cases, we keep the root the same;
71    // However, here we set the new root to `node` so downstream functions
72    // know we are inside a dominance relation
73    let inner = parse_expression(
74        field!(node, "expression"),
75        source_code,
76        node,
77        symbols_ptr,
78        errors,
79    )?;
80    Ok(Expression::DominanceRelation(
81        Metadata::new(),
82        Moo::new(inner),
83    ))
84}
85
86fn parse_arithmetic_expression(
87    node: &Node,
88    source_code: &str,
89    root: &Node,
90    symbols_ptr: Option<SymbolTablePtr>,
91    errors: &mut Vec<RecoverableParseError>,
92) -> Result<Expression, FatalParseError> {
93    let inner = named_child!(node);
94    match inner.kind() {
95        "atom" => parse_atom(&inner, source_code, root, symbols_ptr, errors),
96        "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
97            parse_unary_expression(&inner, source_code, root, symbols_ptr, errors)
98        }
99        "exponent" | "product_expr" | "sum_expr" => {
100            parse_binary_expression(&inner, source_code, root, symbols_ptr, errors)
101        }
102        "list_combining_expr_arith" => {
103            parse_list_combining_expression(&inner, source_code, root, symbols_ptr, errors)
104        }
105        "aggregate_expr" => {
106            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr, errors)
107        }
108        _ => Err(FatalParseError::syntax_error(
109            format!("Expected arithmetic expression, found: {}", inner.kind()),
110            Some(inner.range()),
111        )),
112    }
113}
114
115fn parse_boolean_expression(
116    node: &Node,
117    source_code: &str,
118    root: &Node,
119    symbols_ptr: Option<SymbolTablePtr>,
120    errors: &mut Vec<RecoverableParseError>,
121) -> Result<Expression, FatalParseError> {
122    let inner = named_child!(node);
123    match inner.kind() {
124        "atom" => parse_atom(&inner, source_code, root, symbols_ptr, errors),
125        "not_expr" | "sub_bool_expr" => {
126            parse_unary_expression(&inner, source_code, root, symbols_ptr, errors)
127        }
128        "and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
129            parse_binary_expression(&inner, source_code, root, symbols_ptr, errors)
130        }
131        "list_combining_expr_bool" => {
132            parse_list_combining_expression(&inner, source_code, root, symbols_ptr, errors)
133        }
134        "quantifier_expr" => {
135            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr, errors)
136        }
137        _ => Err(FatalParseError::syntax_error(
138            format!("Expected boolean expression, found '{}'", inner.kind()),
139            Some(inner.range()),
140        )),
141    }
142}
143
144fn parse_list_combining_expression(
145    node: &Node,
146    source_code: &str,
147    root: &Node,
148    symbols_ptr: Option<SymbolTablePtr>,
149    errors: &mut Vec<RecoverableParseError>,
150) -> Result<Expression, FatalParseError> {
151    let operator_node = field!(node, "operator");
152    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
153
154    let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr, errors)?;
155
156    match operator_str {
157        "and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
158        "or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
159        "sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
160        "product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
161        "min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
162        "max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
163        "allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
164        _ => Err(FatalParseError::syntax_error(
165            format!("Invalid operator: '{operator_str}'"),
166            Some(operator_node.range()),
167        )),
168    }
169}
170
171fn parse_unary_expression(
172    node: &Node,
173    source_code: &str,
174    root: &Node,
175    symbols_ptr: Option<SymbolTablePtr>,
176    errors: &mut Vec<RecoverableParseError>,
177) -> Result<Expression, FatalParseError> {
178    let inner = parse_expression(
179        field!(node, "expression"),
180        source_code,
181        root,
182        symbols_ptr,
183        errors,
184    )?;
185    match node.kind() {
186        "negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
187        "abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
188        "not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
189        "toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
190        "sub_bool_expr" | "sub_arith_expr" => Ok(inner),
191        _ => Err(FatalParseError::syntax_error(
192            format!("Unrecognised unary operation: '{}'", node.kind()),
193            Some(node.range()),
194        )),
195    }
196}
197
198pub fn parse_binary_expression(
199    node: &Node,
200    source_code: &str,
201    root: &Node,
202    symbols_ptr: Option<SymbolTablePtr>,
203    errors: &mut Vec<RecoverableParseError>,
204) -> Result<Expression, FatalParseError> {
205    let mut parse_subexpr =
206        |expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone(), errors);
207
208    let left = parse_subexpr(field!(node, "left"))?;
209    let right = parse_subexpr(field!(node, "right"))?;
210
211    let op_node = field!(node, "operator");
212    let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
213
214    match op_str {
215        // NB: We are deliberately setting the index domain to 1.., not 1..2.
216        // Semantically, this means "a list that can grow/shrink arbitrarily".
217        // This is expected by rules which will modify the terms of the sum expression
218        // (e.g. by partially evaluating them).
219        "+" => Ok(Expression::Sum(
220            Metadata::new(),
221            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
222        )),
223        "-" => Ok(Expression::Minus(
224            Metadata::new(),
225            Moo::new(left),
226            Moo::new(right),
227        )),
228        "*" => Ok(Expression::Product(
229            Metadata::new(),
230            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
231        )),
232        "/\\" => Ok(Expression::And(
233            Metadata::new(),
234            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
235        )),
236        "\\/" => Ok(Expression::Or(
237            Metadata::new(),
238            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
239        )),
240        "**" => Ok(Expression::UnsafePow(
241            Metadata::new(),
242            Moo::new(left),
243            Moo::new(right),
244        )),
245        "/" => {
246            //TODO: add checks for if division is safe or not
247            Ok(Expression::UnsafeDiv(
248                Metadata::new(),
249                Moo::new(left),
250                Moo::new(right),
251            ))
252        }
253        "%" => {
254            //TODO: add checks for if mod is safe or not
255            Ok(Expression::UnsafeMod(
256                Metadata::new(),
257                Moo::new(left),
258                Moo::new(right),
259            ))
260        }
261        "=" => Ok(Expression::Eq(
262            Metadata::new(),
263            Moo::new(left),
264            Moo::new(right),
265        )),
266        "!=" => Ok(Expression::Neq(
267            Metadata::new(),
268            Moo::new(left),
269            Moo::new(right),
270        )),
271        "<=" => Ok(Expression::Leq(
272            Metadata::new(),
273            Moo::new(left),
274            Moo::new(right),
275        )),
276        ">=" => Ok(Expression::Geq(
277            Metadata::new(),
278            Moo::new(left),
279            Moo::new(right),
280        )),
281        "<" => Ok(Expression::Lt(
282            Metadata::new(),
283            Moo::new(left),
284            Moo::new(right),
285        )),
286        ">" => Ok(Expression::Gt(
287            Metadata::new(),
288            Moo::new(left),
289            Moo::new(right),
290        )),
291        "->" => Ok(Expression::Imply(
292            Metadata::new(),
293            Moo::new(left),
294            Moo::new(right),
295        )),
296        "<->" => Ok(Expression::Iff(
297            Metadata::new(),
298            Moo::new(left),
299            Moo::new(right),
300        )),
301        "<lex" => Ok(Expression::LexLt(
302            Metadata::new(),
303            Moo::new(left),
304            Moo::new(right),
305        )),
306        ">lex" => Ok(Expression::LexGt(
307            Metadata::new(),
308            Moo::new(left),
309            Moo::new(right),
310        )),
311        "<=lex" => Ok(Expression::LexLeq(
312            Metadata::new(),
313            Moo::new(left),
314            Moo::new(right),
315        )),
316        ">=lex" => Ok(Expression::LexGeq(
317            Metadata::new(),
318            Moo::new(left),
319            Moo::new(right),
320        )),
321        "in" => Ok(Expression::In(
322            Metadata::new(),
323            Moo::new(left),
324            Moo::new(right),
325        )),
326        "subset" => Ok(Expression::Subset(
327            Metadata::new(),
328            Moo::new(left),
329            Moo::new(right),
330        )),
331        "subsetEq" => Ok(Expression::SubsetEq(
332            Metadata::new(),
333            Moo::new(left),
334            Moo::new(right),
335        )),
336        "supset" => Ok(Expression::Supset(
337            Metadata::new(),
338            Moo::new(left),
339            Moo::new(right),
340        )),
341        "supsetEq" => Ok(Expression::SupsetEq(
342            Metadata::new(),
343            Moo::new(left),
344            Moo::new(right),
345        )),
346        "union" => Ok(Expression::Union(
347            Metadata::new(),
348            Moo::new(left),
349            Moo::new(right),
350        )),
351        "intersect" => Ok(Expression::Intersect(
352            Metadata::new(),
353            Moo::new(left),
354            Moo::new(right),
355        )),
356        _ => Err(FatalParseError::syntax_error(
357            format!("Invalid operator: '{op_str}'"),
358            Some(op_node.range()),
359        )),
360    }
361}