1use crate::errors::{FatalParseError, RecoverableParseError};
2use crate::parser::atom::parse_atom;
3use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
4use crate::{field, named_child};
5use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTablePtr};
6use conjure_cp_core::{domain_int, matrix_expr, range};
7use tree_sitter::Node;
8
9pub fn parse_expression(
11 node: Node,
12 source_code: &str,
13 root: &Node,
14 symbols_ptr: Option<SymbolTablePtr>,
15 errors: &mut Vec<RecoverableParseError>,
16) -> Result<Expression, FatalParseError> {
17 match node.kind() {
18 "atom" => parse_atom(&node, source_code, root, symbols_ptr, errors),
19 "bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr, errors),
20 "arithmetic_expr" => {
21 parse_arithmetic_expression(&node, source_code, root, symbols_ptr, errors)
22 }
23 "comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr, errors),
24 "dominance_relation" => {
25 parse_dominance_relation(&node, source_code, root, symbols_ptr, errors)
26 }
27 "ERROR" => {
28 errors.push(RecoverableParseError::new(
29 format!(
30 "'{}' is not a valid expression",
31 &source_code[node.start_byte()..node.end_byte()]
32 ),
33 Some(node.range()),
34 ));
35 Ok(Expression::Atomic(
38 Metadata::new(),
39 conjure_cp_core::ast::Atom::Literal(conjure_cp_core::ast::Literal::Bool(false)),
40 ))
41 }
42 _ => {
43 errors.push(RecoverableParseError::new(
44 format!("Unknown expression kind: '{}'", node.kind()),
45 Some(node.range()),
46 ));
47 Ok(Expression::Atomic(
49 Metadata::new(),
50 conjure_cp_core::ast::Atom::Literal(conjure_cp_core::ast::Literal::Bool(false)),
51 ))
52 }
53 }
54}
55
56fn parse_dominance_relation(
57 node: &Node,
58 source_code: &str,
59 root: &Node,
60 symbols_ptr: Option<SymbolTablePtr>,
61 errors: &mut Vec<RecoverableParseError>,
62) -> Result<Expression, FatalParseError> {
63 if root.kind() == "dominance_relation" {
64 return Err(FatalParseError::syntax_error(
65 "Nested dominance relations are not allowed".to_string(),
66 Some(node.range()),
67 ));
68 }
69
70 let inner = parse_expression(
74 field!(node, "expression"),
75 source_code,
76 node,
77 symbols_ptr,
78 errors,
79 )?;
80 Ok(Expression::DominanceRelation(
81 Metadata::new(),
82 Moo::new(inner),
83 ))
84}
85
86fn parse_arithmetic_expression(
87 node: &Node,
88 source_code: &str,
89 root: &Node,
90 symbols_ptr: Option<SymbolTablePtr>,
91 errors: &mut Vec<RecoverableParseError>,
92) -> Result<Expression, FatalParseError> {
93 let inner = named_child!(node);
94 match inner.kind() {
95 "atom" => parse_atom(&inner, source_code, root, symbols_ptr, errors),
96 "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
97 parse_unary_expression(&inner, source_code, root, symbols_ptr, errors)
98 }
99 "exponent" | "product_expr" | "sum_expr" => {
100 parse_binary_expression(&inner, source_code, root, symbols_ptr, errors)
101 }
102 "list_combining_expr_arith" => {
103 parse_list_combining_expression(&inner, source_code, root, symbols_ptr, errors)
104 }
105 "aggregate_expr" => {
106 parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr, errors)
107 }
108 _ => Err(FatalParseError::syntax_error(
109 format!("Expected arithmetic expression, found: {}", inner.kind()),
110 Some(inner.range()),
111 )),
112 }
113}
114
115fn parse_boolean_expression(
116 node: &Node,
117 source_code: &str,
118 root: &Node,
119 symbols_ptr: Option<SymbolTablePtr>,
120 errors: &mut Vec<RecoverableParseError>,
121) -> Result<Expression, FatalParseError> {
122 let inner = named_child!(node);
123 match inner.kind() {
124 "atom" => parse_atom(&inner, source_code, root, symbols_ptr, errors),
125 "not_expr" | "sub_bool_expr" => {
126 parse_unary_expression(&inner, source_code, root, symbols_ptr, errors)
127 }
128 "and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
129 parse_binary_expression(&inner, source_code, root, symbols_ptr, errors)
130 }
131 "list_combining_expr_bool" => {
132 parse_list_combining_expression(&inner, source_code, root, symbols_ptr, errors)
133 }
134 "quantifier_expr" => {
135 parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr, errors)
136 }
137 _ => Err(FatalParseError::syntax_error(
138 format!("Expected boolean expression, found '{}'", inner.kind()),
139 Some(inner.range()),
140 )),
141 }
142}
143
144fn parse_list_combining_expression(
145 node: &Node,
146 source_code: &str,
147 root: &Node,
148 symbols_ptr: Option<SymbolTablePtr>,
149 errors: &mut Vec<RecoverableParseError>,
150) -> Result<Expression, FatalParseError> {
151 let operator_node = field!(node, "operator");
152 let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
153
154 let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr, errors)?;
155
156 match operator_str {
157 "and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
158 "or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
159 "sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
160 "product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
161 "min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
162 "max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
163 "allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
164 _ => Err(FatalParseError::syntax_error(
165 format!("Invalid operator: '{operator_str}'"),
166 Some(operator_node.range()),
167 )),
168 }
169}
170
171fn parse_unary_expression(
172 node: &Node,
173 source_code: &str,
174 root: &Node,
175 symbols_ptr: Option<SymbolTablePtr>,
176 errors: &mut Vec<RecoverableParseError>,
177) -> Result<Expression, FatalParseError> {
178 let inner = parse_expression(
179 field!(node, "expression"),
180 source_code,
181 root,
182 symbols_ptr,
183 errors,
184 )?;
185 match node.kind() {
186 "negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
187 "abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
188 "not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
189 "toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
190 "sub_bool_expr" | "sub_arith_expr" => Ok(inner),
191 _ => Err(FatalParseError::syntax_error(
192 format!("Unrecognised unary operation: '{}'", node.kind()),
193 Some(node.range()),
194 )),
195 }
196}
197
198pub fn parse_binary_expression(
199 node: &Node,
200 source_code: &str,
201 root: &Node,
202 symbols_ptr: Option<SymbolTablePtr>,
203 errors: &mut Vec<RecoverableParseError>,
204) -> Result<Expression, FatalParseError> {
205 let mut parse_subexpr =
206 |expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone(), errors);
207
208 let left = parse_subexpr(field!(node, "left"))?;
209 let right = parse_subexpr(field!(node, "right"))?;
210
211 let op_node = field!(node, "operator");
212 let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
213
214 match op_str {
215 "+" => Ok(Expression::Sum(
220 Metadata::new(),
221 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
222 )),
223 "-" => Ok(Expression::Minus(
224 Metadata::new(),
225 Moo::new(left),
226 Moo::new(right),
227 )),
228 "*" => Ok(Expression::Product(
229 Metadata::new(),
230 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
231 )),
232 "/\\" => Ok(Expression::And(
233 Metadata::new(),
234 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
235 )),
236 "\\/" => Ok(Expression::Or(
237 Metadata::new(),
238 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
239 )),
240 "**" => Ok(Expression::UnsafePow(
241 Metadata::new(),
242 Moo::new(left),
243 Moo::new(right),
244 )),
245 "/" => {
246 Ok(Expression::UnsafeDiv(
248 Metadata::new(),
249 Moo::new(left),
250 Moo::new(right),
251 ))
252 }
253 "%" => {
254 Ok(Expression::UnsafeMod(
256 Metadata::new(),
257 Moo::new(left),
258 Moo::new(right),
259 ))
260 }
261 "=" => Ok(Expression::Eq(
262 Metadata::new(),
263 Moo::new(left),
264 Moo::new(right),
265 )),
266 "!=" => Ok(Expression::Neq(
267 Metadata::new(),
268 Moo::new(left),
269 Moo::new(right),
270 )),
271 "<=" => Ok(Expression::Leq(
272 Metadata::new(),
273 Moo::new(left),
274 Moo::new(right),
275 )),
276 ">=" => Ok(Expression::Geq(
277 Metadata::new(),
278 Moo::new(left),
279 Moo::new(right),
280 )),
281 "<" => Ok(Expression::Lt(
282 Metadata::new(),
283 Moo::new(left),
284 Moo::new(right),
285 )),
286 ">" => Ok(Expression::Gt(
287 Metadata::new(),
288 Moo::new(left),
289 Moo::new(right),
290 )),
291 "->" => Ok(Expression::Imply(
292 Metadata::new(),
293 Moo::new(left),
294 Moo::new(right),
295 )),
296 "<->" => Ok(Expression::Iff(
297 Metadata::new(),
298 Moo::new(left),
299 Moo::new(right),
300 )),
301 "<lex" => Ok(Expression::LexLt(
302 Metadata::new(),
303 Moo::new(left),
304 Moo::new(right),
305 )),
306 ">lex" => Ok(Expression::LexGt(
307 Metadata::new(),
308 Moo::new(left),
309 Moo::new(right),
310 )),
311 "<=lex" => Ok(Expression::LexLeq(
312 Metadata::new(),
313 Moo::new(left),
314 Moo::new(right),
315 )),
316 ">=lex" => Ok(Expression::LexGeq(
317 Metadata::new(),
318 Moo::new(left),
319 Moo::new(right),
320 )),
321 "in" => Ok(Expression::In(
322 Metadata::new(),
323 Moo::new(left),
324 Moo::new(right),
325 )),
326 "subset" => Ok(Expression::Subset(
327 Metadata::new(),
328 Moo::new(left),
329 Moo::new(right),
330 )),
331 "subsetEq" => Ok(Expression::SubsetEq(
332 Metadata::new(),
333 Moo::new(left),
334 Moo::new(right),
335 )),
336 "supset" => Ok(Expression::Supset(
337 Metadata::new(),
338 Moo::new(left),
339 Moo::new(right),
340 )),
341 "supsetEq" => Ok(Expression::SupsetEq(
342 Metadata::new(),
343 Moo::new(left),
344 Moo::new(right),
345 )),
346 "union" => Ok(Expression::Union(
347 Metadata::new(),
348 Moo::new(left),
349 Moo::new(right),
350 )),
351 "intersect" => Ok(Expression::Intersect(
352 Metadata::new(),
353 Moo::new(left),
354 Moo::new(right),
355 )),
356 _ => Err(FatalParseError::syntax_error(
357 format!("Invalid operator: '{op_str}'"),
358 Some(op_node.range()),
359 )),
360 }
361}