1use crate::errors::EssenceParseError;
2use crate::parser::atom::parse_atom;
3use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
4use crate::{field, named_child};
5use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTable};
6use conjure_cp_core::{domain_int, matrix_expr, range};
7use std::cell::RefCell;
8use std::rc::Rc;
9use tree_sitter::Node;
10
11pub fn parse_expression(
13 node: Node,
14 source_code: &str,
15 root: &Node,
16 symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
17) -> Result<Expression, EssenceParseError> {
18 match node.kind() {
19 "atom" => parse_atom(&node, source_code, root, symbols_ptr),
20 "bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr),
21 "arithmetic_expr" => parse_arithmetic_expression(&node, source_code, root, symbols_ptr),
22 "comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr),
23 "dominance_relation" => parse_dominance_relation(&node, source_code, root, symbols_ptr),
24 "ERROR" => Err(EssenceParseError::syntax_error(
25 format!(
26 "'{}' is not a valid expression",
27 &source_code[node.start_byte()..node.end_byte()]
28 ),
29 Some(node.range()),
30 )),
31 _ => Err(EssenceParseError::syntax_error(
32 format!("Unknown expression kind: '{}'", node.kind()),
33 Some(node.range()),
34 )),
35 }
36}
37
38fn parse_dominance_relation(
39 node: &Node,
40 source_code: &str,
41 root: &Node,
42 symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
43) -> Result<Expression, EssenceParseError> {
44 if root.kind() == "dominance_relation" {
45 return Err(EssenceParseError::syntax_error(
46 "Nested dominance relations are not allowed".to_string(),
47 Some(node.range()),
48 ));
49 }
50
51 let inner = parse_expression(field!(node, "expression"), source_code, node, symbols_ptr)?;
55 Ok(Expression::DominanceRelation(
56 Metadata::new(),
57 Moo::new(inner),
58 ))
59}
60
61fn parse_arithmetic_expression(
62 node: &Node,
63 source_code: &str,
64 root: &Node,
65 symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
66) -> Result<Expression, EssenceParseError> {
67 let inner = named_child!(node);
68 match inner.kind() {
69 "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
70 "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
71 parse_unary_expression(&inner, source_code, root, symbols_ptr)
72 }
73 "exponent" | "product_expr" | "sum_expr" => {
74 parse_binary_expression(&inner, source_code, root, symbols_ptr)
75 }
76 "list_combining_expr_arith" => {
77 parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
78 }
79 "aggregate_expr" => {
80 parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
81 }
82 _ => Err(EssenceParseError::syntax_error(
83 format!("Expected arithmetic expression, found: {}", inner.kind()),
84 Some(inner.range()),
85 )),
86 }
87}
88
89fn parse_boolean_expression(
90 node: &Node,
91 source_code: &str,
92 root: &Node,
93 symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
94) -> Result<Expression, EssenceParseError> {
95 let inner = named_child!(node);
96 match inner.kind() {
97 "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
98 "not_expr" | "sub_bool_expr" => {
99 parse_unary_expression(&inner, source_code, root, symbols_ptr)
100 }
101 "and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
102 parse_binary_expression(&inner, source_code, root, symbols_ptr)
103 }
104 "list_combining_expr_bool" => {
105 parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
106 }
107 "quantifier_expr" => {
108 parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
109 }
110 _ => Err(EssenceParseError::syntax_error(
111 format!("Expected boolean expression, found '{}'", inner.kind()),
112 Some(inner.range()),
113 )),
114 }
115}
116
117fn parse_list_combining_expression(
118 node: &Node,
119 source_code: &str,
120 root: &Node,
121 symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
122) -> Result<Expression, EssenceParseError> {
123 let operator_node = field!(node, "operator");
124 let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
125
126 let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr)?;
127
128 match operator_str {
129 "and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
130 "or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
131 "sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
132 "product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
133 "min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
134 "max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
135 "allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
136 _ => Err(EssenceParseError::syntax_error(
137 format!("Invalid operator: '{operator_str}'"),
138 Some(operator_node.range()),
139 )),
140 }
141}
142
143fn parse_unary_expression(
144 node: &Node,
145 source_code: &str,
146 root: &Node,
147 symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
148) -> Result<Expression, EssenceParseError> {
149 let inner = parse_expression(field!(node, "expression"), source_code, root, symbols_ptr)?;
150 match node.kind() {
151 "negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
152 "abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
153 "not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
154 "toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
155 "sub_bool_expr" | "sub_arith_expr" => Ok(inner),
156 _ => Err(EssenceParseError::syntax_error(
157 format!("Unrecognised unary operation: '{}'", node.kind()),
158 Some(node.range()),
159 )),
160 }
161}
162
163pub fn parse_binary_expression(
164 node: &Node,
165 source_code: &str,
166 root: &Node,
167 symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
168) -> Result<Expression, EssenceParseError> {
169 let parse_subexpr = |expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone());
170
171 let left = parse_subexpr(field!(node, "left"))?;
172 let right = parse_subexpr(field!(node, "right"))?;
173
174 let op_node = field!(node, "operator");
175 let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
176
177 match op_str {
178 "+" => Ok(Expression::Sum(
183 Metadata::new(),
184 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
185 )),
186 "-" => Ok(Expression::Minus(
187 Metadata::new(),
188 Moo::new(left),
189 Moo::new(right),
190 )),
191 "*" => Ok(Expression::Product(
192 Metadata::new(),
193 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
194 )),
195 "/\\" => Ok(Expression::And(
196 Metadata::new(),
197 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
198 )),
199 "\\/" => Ok(Expression::Or(
200 Metadata::new(),
201 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
202 )),
203 "**" => Ok(Expression::UnsafePow(
204 Metadata::new(),
205 Moo::new(left),
206 Moo::new(right),
207 )),
208 "/" => {
209 Ok(Expression::UnsafeDiv(
211 Metadata::new(),
212 Moo::new(left),
213 Moo::new(right),
214 ))
215 }
216 "%" => {
217 Ok(Expression::UnsafeMod(
219 Metadata::new(),
220 Moo::new(left),
221 Moo::new(right),
222 ))
223 }
224 "=" => Ok(Expression::Eq(
225 Metadata::new(),
226 Moo::new(left),
227 Moo::new(right),
228 )),
229 "!=" => Ok(Expression::Neq(
230 Metadata::new(),
231 Moo::new(left),
232 Moo::new(right),
233 )),
234 "<=" => Ok(Expression::Leq(
235 Metadata::new(),
236 Moo::new(left),
237 Moo::new(right),
238 )),
239 ">=" => Ok(Expression::Geq(
240 Metadata::new(),
241 Moo::new(left),
242 Moo::new(right),
243 )),
244 "<" => Ok(Expression::Lt(
245 Metadata::new(),
246 Moo::new(left),
247 Moo::new(right),
248 )),
249 ">" => Ok(Expression::Gt(
250 Metadata::new(),
251 Moo::new(left),
252 Moo::new(right),
253 )),
254 "->" => Ok(Expression::Imply(
255 Metadata::new(),
256 Moo::new(left),
257 Moo::new(right),
258 )),
259 "<->" => Ok(Expression::Iff(
260 Metadata::new(),
261 Moo::new(left),
262 Moo::new(right),
263 )),
264 "<lex" => Ok(Expression::LexLt(
265 Metadata::new(),
266 Moo::new(left),
267 Moo::new(right),
268 )),
269 ">lex" => Ok(Expression::LexGt(
270 Metadata::new(),
271 Moo::new(left),
272 Moo::new(right),
273 )),
274 "<=lex" => Ok(Expression::LexLeq(
275 Metadata::new(),
276 Moo::new(left),
277 Moo::new(right),
278 )),
279 ">=lex" => Ok(Expression::LexGeq(
280 Metadata::new(),
281 Moo::new(left),
282 Moo::new(right),
283 )),
284 "in" => Ok(Expression::In(
285 Metadata::new(),
286 Moo::new(left),
287 Moo::new(right),
288 )),
289 "subset" => Ok(Expression::Subset(
290 Metadata::new(),
291 Moo::new(left),
292 Moo::new(right),
293 )),
294 "subsetEq" => Ok(Expression::SubsetEq(
295 Metadata::new(),
296 Moo::new(left),
297 Moo::new(right),
298 )),
299 "supset" => Ok(Expression::Supset(
300 Metadata::new(),
301 Moo::new(left),
302 Moo::new(right),
303 )),
304 "supsetEq" => Ok(Expression::SupsetEq(
305 Metadata::new(),
306 Moo::new(left),
307 Moo::new(right),
308 )),
309 "union" => Ok(Expression::Union(
310 Metadata::new(),
311 Moo::new(left),
312 Moo::new(right),
313 )),
314 "intersect" => Ok(Expression::Intersect(
315 Metadata::new(),
316 Moo::new(left),
317 Moo::new(right),
318 )),
319 _ => Err(EssenceParseError::syntax_error(
320 format!("Invalid operator: '{op_str}'"),
321 Some(op_node.range()),
322 )),
323 }
324}