conjure_cp_essence_parser/parser/
expression.rs

1use crate::errors::EssenceParseError;
2use crate::parser::atom::parse_atom;
3use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
4use crate::{field, named_child};
5use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTable};
6use conjure_cp_core::{domain_int, matrix_expr, range};
7use std::cell::RefCell;
8use std::rc::Rc;
9use tree_sitter::Node;
10
11/// Parse an Essence expression into its Conjure AST representation.
12pub fn parse_expression(
13    node: Node,
14    source_code: &str,
15    root: &Node,
16    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
17) -> Result<Expression, EssenceParseError> {
18    match node.kind() {
19        "atom" => parse_atom(&node, source_code, root, symbols_ptr),
20        "bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr),
21        "arithmetic_expr" => parse_arithmetic_expression(&node, source_code, root, symbols_ptr),
22        "comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr),
23        "dominance_relation" => parse_dominance_relation(&node, source_code, root, symbols_ptr),
24        "ERROR" => Err(EssenceParseError::syntax_error(
25            format!(
26                "'{}' is not a valid expression",
27                &source_code[node.start_byte()..node.end_byte()]
28            ),
29            Some(node.range()),
30        )),
31        _ => Err(EssenceParseError::syntax_error(
32            format!("Unknown expression kind: '{}'", node.kind()),
33            Some(node.range()),
34        )),
35    }
36}
37
38fn parse_dominance_relation(
39    node: &Node,
40    source_code: &str,
41    root: &Node,
42    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
43) -> Result<Expression, EssenceParseError> {
44    if root.kind() == "dominance_relation" {
45        return Err(EssenceParseError::syntax_error(
46            "Nested dominance relations are not allowed".to_string(),
47            Some(node.range()),
48        ));
49    }
50
51    // NB: In all other cases, we keep the root the same;
52    // However, here we set the new root to `node` so downstream functions
53    // know we are inside a dominance relation
54    let inner = parse_expression(field!(node, "expression"), source_code, node, symbols_ptr)?;
55    Ok(Expression::DominanceRelation(
56        Metadata::new(),
57        Moo::new(inner),
58    ))
59}
60
61fn parse_arithmetic_expression(
62    node: &Node,
63    source_code: &str,
64    root: &Node,
65    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
66) -> Result<Expression, EssenceParseError> {
67    let inner = named_child!(node);
68    match inner.kind() {
69        "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
70        "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
71            parse_unary_expression(&inner, source_code, root, symbols_ptr)
72        }
73        "exponent" | "product_expr" | "sum_expr" => {
74            parse_binary_expression(&inner, source_code, root, symbols_ptr)
75        }
76        "list_combining_expr_arith" => {
77            parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
78        }
79        "aggregate_expr" => {
80            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
81        }
82        _ => Err(EssenceParseError::syntax_error(
83            format!("Expected arithmetic expression, found: {}", inner.kind()),
84            Some(inner.range()),
85        )),
86    }
87}
88
89fn parse_boolean_expression(
90    node: &Node,
91    source_code: &str,
92    root: &Node,
93    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
94) -> Result<Expression, EssenceParseError> {
95    let inner = named_child!(node);
96    match inner.kind() {
97        "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
98        "not_expr" | "sub_bool_expr" => {
99            parse_unary_expression(&inner, source_code, root, symbols_ptr)
100        }
101        "and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
102            parse_binary_expression(&inner, source_code, root, symbols_ptr)
103        }
104        "list_combining_expr_bool" => {
105            parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
106        }
107        "quantifier_expr" => {
108            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
109        }
110        _ => Err(EssenceParseError::syntax_error(
111            format!("Expected boolean expression, found '{}'", inner.kind()),
112            Some(inner.range()),
113        )),
114    }
115}
116
117fn parse_list_combining_expression(
118    node: &Node,
119    source_code: &str,
120    root: &Node,
121    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
122) -> Result<Expression, EssenceParseError> {
123    let operator_node = field!(node, "operator");
124    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
125
126    let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr)?;
127
128    match operator_str {
129        "and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
130        "or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
131        "sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
132        "product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
133        "min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
134        "max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
135        "allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
136        _ => Err(EssenceParseError::syntax_error(
137            format!("Invalid operator: '{operator_str}'"),
138            Some(operator_node.range()),
139        )),
140    }
141}
142
143fn parse_unary_expression(
144    node: &Node,
145    source_code: &str,
146    root: &Node,
147    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
148) -> Result<Expression, EssenceParseError> {
149    let inner = parse_expression(field!(node, "expression"), source_code, root, symbols_ptr)?;
150    match node.kind() {
151        "negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
152        "abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
153        "not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
154        "toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
155        "sub_bool_expr" | "sub_arith_expr" => Ok(inner),
156        _ => Err(EssenceParseError::syntax_error(
157            format!("Unrecognised unary operation: '{}'", node.kind()),
158            Some(node.range()),
159        )),
160    }
161}
162
163pub fn parse_binary_expression(
164    node: &Node,
165    source_code: &str,
166    root: &Node,
167    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
168) -> Result<Expression, EssenceParseError> {
169    let parse_subexpr = |expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone());
170
171    let left = parse_subexpr(field!(node, "left"))?;
172    let right = parse_subexpr(field!(node, "right"))?;
173
174    let op_node = field!(node, "operator");
175    let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
176
177    match op_str {
178        // NB: We are deliberately setting the index domain to 1.., not 1..2.
179        // Semantically, this means "a list that can grow/shrink arbitrarily".
180        // This is expected by rules which will modify the terms of the sum expression
181        // (e.g. by partially evaluating them).
182        "+" => Ok(Expression::Sum(
183            Metadata::new(),
184            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
185        )),
186        "-" => Ok(Expression::Minus(
187            Metadata::new(),
188            Moo::new(left),
189            Moo::new(right),
190        )),
191        "*" => Ok(Expression::Product(
192            Metadata::new(),
193            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
194        )),
195        "/\\" => Ok(Expression::And(
196            Metadata::new(),
197            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
198        )),
199        "\\/" => Ok(Expression::Or(
200            Metadata::new(),
201            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
202        )),
203        "**" => Ok(Expression::UnsafePow(
204            Metadata::new(),
205            Moo::new(left),
206            Moo::new(right),
207        )),
208        "/" => {
209            //TODO: add checks for if division is safe or not
210            Ok(Expression::UnsafeDiv(
211                Metadata::new(),
212                Moo::new(left),
213                Moo::new(right),
214            ))
215        }
216        "%" => {
217            //TODO: add checks for if mod is safe or not
218            Ok(Expression::UnsafeMod(
219                Metadata::new(),
220                Moo::new(left),
221                Moo::new(right),
222            ))
223        }
224        "=" => Ok(Expression::Eq(
225            Metadata::new(),
226            Moo::new(left),
227            Moo::new(right),
228        )),
229        "!=" => Ok(Expression::Neq(
230            Metadata::new(),
231            Moo::new(left),
232            Moo::new(right),
233        )),
234        "<=" => Ok(Expression::Leq(
235            Metadata::new(),
236            Moo::new(left),
237            Moo::new(right),
238        )),
239        ">=" => Ok(Expression::Geq(
240            Metadata::new(),
241            Moo::new(left),
242            Moo::new(right),
243        )),
244        "<" => Ok(Expression::Lt(
245            Metadata::new(),
246            Moo::new(left),
247            Moo::new(right),
248        )),
249        ">" => Ok(Expression::Gt(
250            Metadata::new(),
251            Moo::new(left),
252            Moo::new(right),
253        )),
254        "->" => Ok(Expression::Imply(
255            Metadata::new(),
256            Moo::new(left),
257            Moo::new(right),
258        )),
259        "<->" => Ok(Expression::Iff(
260            Metadata::new(),
261            Moo::new(left),
262            Moo::new(right),
263        )),
264        "<lex" => Ok(Expression::LexLt(
265            Metadata::new(),
266            Moo::new(left),
267            Moo::new(right),
268        )),
269        ">lex" => Ok(Expression::LexGt(
270            Metadata::new(),
271            Moo::new(left),
272            Moo::new(right),
273        )),
274        "<=lex" => Ok(Expression::LexLeq(
275            Metadata::new(),
276            Moo::new(left),
277            Moo::new(right),
278        )),
279        ">=lex" => Ok(Expression::LexGeq(
280            Metadata::new(),
281            Moo::new(left),
282            Moo::new(right),
283        )),
284        "in" => Ok(Expression::In(
285            Metadata::new(),
286            Moo::new(left),
287            Moo::new(right),
288        )),
289        "subset" => Ok(Expression::Subset(
290            Metadata::new(),
291            Moo::new(left),
292            Moo::new(right),
293        )),
294        "subsetEq" => Ok(Expression::SubsetEq(
295            Metadata::new(),
296            Moo::new(left),
297            Moo::new(right),
298        )),
299        "supset" => Ok(Expression::Supset(
300            Metadata::new(),
301            Moo::new(left),
302            Moo::new(right),
303        )),
304        "supsetEq" => Ok(Expression::SupsetEq(
305            Metadata::new(),
306            Moo::new(left),
307            Moo::new(right),
308        )),
309        "union" => Ok(Expression::Union(
310            Metadata::new(),
311            Moo::new(left),
312            Moo::new(right),
313        )),
314        "intersect" => Ok(Expression::Intersect(
315            Metadata::new(),
316            Moo::new(left),
317            Moo::new(right),
318        )),
319        _ => Err(EssenceParseError::syntax_error(
320            format!("Invalid operator: '{op_str}'"),
321            Some(op_node.range()),
322        )),
323    }
324}