Lines
60.37 %
Functions
43.33 %
use crate::errors::{FatalParseError, RecoverableParseError};
use crate::parser::atom::parse_atom;
use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
use crate::{field, named_child};
use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTablePtr};
use conjure_cp_core::{domain_int, matrix_expr, range};
use tree_sitter::Node;
/// Parse an Essence expression into its Conjure AST representation.
pub fn parse_expression(
node: Node,
source_code: &str,
root: &Node,
symbols_ptr: Option<SymbolTablePtr>,
errors: &mut Vec<RecoverableParseError>,
) -> Result<Expression, FatalParseError> {
match node.kind() {
"atom" => parse_atom(&node, source_code, root, symbols_ptr, errors),
"bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr, errors),
"arithmetic_expr" => {
parse_arithmetic_expression(&node, source_code, root, symbols_ptr, errors)
}
"comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr, errors),
"dominance_relation" => {
parse_dominance_relation(&node, source_code, root, symbols_ptr, errors)
"ERROR" => {
errors.push(RecoverableParseError::new(
format!(
"'{}' is not a valid expression",
&source_code[node.start_byte()..node.end_byte()]
),
Some(node.range()),
));
// Return a placeholder - actual error is in the errors vector
// TODO: figure out how to return when recoverable error is found
Ok(Expression::Atomic(
Metadata::new(),
conjure_cp_core::ast::Atom::Literal(conjure_cp_core::ast::Literal::Bool(false)),
))
_ => {
format!("Unknown expression kind: '{}'", node.kind()),
// Return a placeholder
fn parse_dominance_relation(
node: &Node,
if root.kind() == "dominance_relation" {
return Err(FatalParseError::syntax_error(
"Nested dominance relations are not allowed".to_string(),
// NB: In all other cases, we keep the root the same;
// However, here we set the new root to `node` so downstream functions
// know we are inside a dominance relation
let inner = parse_expression(
field!(node, "expression"),
source_code,
node,
symbols_ptr,
errors,
)?;
Ok(Expression::DominanceRelation(
Moo::new(inner),
fn parse_arithmetic_expression(
let inner = named_child!(node);
match inner.kind() {
"atom" => parse_atom(&inner, source_code, root, symbols_ptr, errors),
"negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
parse_unary_expression(&inner, source_code, root, symbols_ptr, errors)
"exponent" | "product_expr" | "sum_expr" => {
parse_binary_expression(&inner, source_code, root, symbols_ptr, errors)
"list_combining_expr_arith" => {
parse_list_combining_expression(&inner, source_code, root, symbols_ptr, errors)
"aggregate_expr" => {
parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr, errors)
_ => Err(FatalParseError::syntax_error(
format!("Expected arithmetic expression, found: {}", inner.kind()),
Some(inner.range()),
)),
fn parse_boolean_expression(
"not_expr" | "sub_bool_expr" => {
"and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
"list_combining_expr_bool" => {
"quantifier_expr" => {
format!("Expected boolean expression, found '{}'", inner.kind()),
fn parse_list_combining_expression(
let operator_node = field!(node, "operator");
let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr, errors)?;
match operator_str {
"and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
"or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
"sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
"product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
"min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
"max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
"allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
format!("Invalid operator: '{operator_str}'"),
Some(operator_node.range()),
fn parse_unary_expression(
root,
"negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
"abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
"not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
"toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
"sub_bool_expr" | "sub_arith_expr" => Ok(inner),
format!("Unrecognised unary operation: '{}'", node.kind()),
pub fn parse_binary_expression(
let mut parse_subexpr =
|expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone(), errors);
let left = parse_subexpr(field!(node, "left"))?;
let right = parse_subexpr(field!(node, "right"))?;
let op_node = field!(node, "operator");
let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
match op_str {
// NB: We are deliberately setting the index domain to 1.., not 1..2.
// Semantically, this means "a list that can grow/shrink arbitrarily".
// This is expected by rules which will modify the terms of the sum expression
// (e.g. by partially evaluating them).
"+" => Ok(Expression::Sum(
Moo::new(matrix_expr![left, right; domain_int!(1..)]),
"-" => Ok(Expression::Minus(
Moo::new(left),
Moo::new(right),
"*" => Ok(Expression::Product(
"/\\" => Ok(Expression::And(
"\\/" => Ok(Expression::Or(
"**" => Ok(Expression::UnsafePow(
"/" => {
//TODO: add checks for if division is safe or not
Ok(Expression::UnsafeDiv(
"%" => {
//TODO: add checks for if mod is safe or not
Ok(Expression::UnsafeMod(
"=" => Ok(Expression::Eq(
"!=" => Ok(Expression::Neq(
"<=" => Ok(Expression::Leq(
">=" => Ok(Expression::Geq(
"<" => Ok(Expression::Lt(
">" => Ok(Expression::Gt(
"->" => Ok(Expression::Imply(
"<->" => Ok(Expression::Iff(
"<lex" => Ok(Expression::LexLt(
">lex" => Ok(Expression::LexGt(
"<=lex" => Ok(Expression::LexLeq(
">=lex" => Ok(Expression::LexGeq(
"in" => Ok(Expression::In(
"subset" => Ok(Expression::Subset(
"subsetEq" => Ok(Expression::SubsetEq(
"supset" => Ok(Expression::Supset(
"supsetEq" => Ok(Expression::SupsetEq(
"union" => Ok(Expression::Union(
"intersect" => Ok(Expression::Intersect(
format!("Invalid operator: '{op_str}'"),
Some(op_node.range()),