1use crate::diagnostics::diagnostics_api::SymbolKind;
2use crate::diagnostics::source_map::{HoverInfo, span_with_hover};
3use crate::errors::FatalParseError;
4use crate::parser::ParseContext;
5use crate::parser::atom::parse_atom;
6use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
7use crate::util::TypecheckingContext;
8use crate::{field, named_child};
9use conjure_cp_core::ast::{Expression, Metadata, Moo};
10use conjure_cp_core::{domain_int, matrix_expr, range};
11use tree_sitter::Node;
12
13pub fn parse_expression(
14 ctx: &mut ParseContext,
15 node: Node,
16) -> Result<Option<Expression>, FatalParseError> {
17 match node.kind() {
18 "atom" => parse_atom(ctx, &node),
19 "bool_expr" => parse_boolean_expression(ctx, &node),
20 "arithmetic_expr" => parse_arithmetic_expression(ctx, &node),
21 "comparison_expr" => parse_comparison_expression(ctx, &node),
22 "dominance_relation" => parse_dominance_relation(ctx, &node),
23 "all_diff_comparison" => parse_all_diff_comparison(ctx, &node),
24 _ => Err(FatalParseError::internal_error(
25 format!("Unexpected expression type: '{}'", node.kind()),
26 Some(node.range()),
27 )),
28 }
29}
30
31fn parse_dominance_relation(
32 ctx: &mut ParseContext,
33 node: &Node,
34) -> Result<Option<Expression>, FatalParseError> {
35 if ctx.root.kind() == "dominance_relation" {
36 return Err(FatalParseError::internal_error(
37 "Nested dominance relations are not allowed".to_string(),
38 Some(node.range()),
39 ));
40 }
41
42 let mut inner_ctx = ParseContext {
46 source_code: ctx.source_code,
47 root: node,
48 symbols: ctx.symbols.clone(),
49 errors: ctx.errors,
50 source_map: &mut *ctx.source_map,
51 typechecking_context: ctx.typechecking_context,
52 };
53
54 let Some(inner) = parse_expression(&mut inner_ctx, field!(node, "expression"))? else {
55 return Ok(None);
56 };
57
58 Ok(Some(Expression::DominanceRelation(
59 Metadata::new(),
60 Moo::new(inner),
61 )))
62}
63
64fn parse_arithmetic_expression(
65 ctx: &mut ParseContext,
66 node: &Node,
67) -> Result<Option<Expression>, FatalParseError> {
68 ctx.typechecking_context = TypecheckingContext::Arithmetic;
69 let inner = named_child!(node);
70 match inner.kind() {
71 "atom" => parse_atom(ctx, &inner),
72 "negative_expr" | "abs_value" | "sub_arith_expr" => parse_unary_expression(ctx, &inner),
73 "toInt_expr" => {
74 ctx.typechecking_context = TypecheckingContext::Unknown;
76 parse_unary_expression(ctx, &inner)
77 }
78 "exponent" | "product_expr" | "sum_expr" => parse_binary_expression(ctx, &inner),
79 "list_combining_expr_arith" => parse_list_combining_expression(ctx, &inner),
80 "aggregate_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
81 _ => Err(FatalParseError::internal_error(
82 format!("Expected arithmetic expression, found: {}", inner.kind()),
83 Some(inner.range()),
84 )),
85 }
86}
87
88fn parse_comparison_expression(
89 ctx: &mut ParseContext,
90 node: &Node,
91) -> Result<Option<Expression>, FatalParseError> {
92 let inner = named_child!(node);
93 match inner.kind() {
94 "arithmetic_comparison" => {
95 ctx.typechecking_context = TypecheckingContext::Arithmetic;
97 parse_binary_expression(ctx, &inner)
98 }
99 "lex_comparison" => {
100 ctx.typechecking_context = TypecheckingContext::Unknown;
102 parse_binary_expression(ctx, &inner)
103 }
104 "equality_comparison" => {
105 ctx.typechecking_context = TypecheckingContext::Unknown;
108 parse_binary_expression(ctx, &inner)
109 }
110 "set_comparison" => {
111 ctx.typechecking_context = TypecheckingContext::Unknown;
114 parse_binary_expression(ctx, &inner)
115 }
116 "all_diff_comparison" => {
117 ctx.typechecking_context = TypecheckingContext::Unknown;
119 parse_all_diff_comparison(ctx, &inner)
120 }
121 _ => Err(FatalParseError::internal_error(
122 format!("Expected comparison expression, found '{}'", inner.kind()),
123 Some(inner.range()),
124 )),
125 }
126}
127
128fn parse_boolean_expression(
129 ctx: &mut ParseContext,
130 node: &Node,
131) -> Result<Option<Expression>, FatalParseError> {
132 ctx.typechecking_context = TypecheckingContext::Boolean;
133 let inner = named_child!(node);
134 match inner.kind() {
135 "atom" => parse_atom(ctx, &inner),
136 "not_expr" | "sub_bool_expr" => parse_unary_expression(ctx, &inner),
137 "and_expr" | "or_expr" | "implication" | "iff_expr" => parse_binary_expression(ctx, &inner),
138 "list_combining_expr_bool" => parse_list_combining_expression(ctx, &inner),
139 "quantifier_expr" => parse_quantifier_or_aggregate_expr(ctx, &inner),
140 _ => Err(FatalParseError::internal_error(
141 format!("Expected boolean expression, found '{}'", inner.kind()),
142 Some(inner.range()),
143 )),
144 }
145}
146
147fn parse_list_combining_expression(
148 ctx: &mut ParseContext,
149 node: &Node,
150) -> Result<Option<Expression>, FatalParseError> {
151 let operator_node = field!(node, "operator");
152 let operator_str = &ctx.source_code[operator_node.start_byte()..operator_node.end_byte()];
153
154 let Some(inner) = parse_atom(ctx, &field!(node, "arg"))? else {
155 return Ok(None);
156 };
157
158 match operator_str {
159 "and" => Ok(Some(Expression::And(Metadata::new(), Moo::new(inner)))),
160 "or" => Ok(Some(Expression::Or(Metadata::new(), Moo::new(inner)))),
161 "sum" => Ok(Some(Expression::Sum(Metadata::new(), Moo::new(inner)))),
162 "product" => Ok(Some(Expression::Product(Metadata::new(), Moo::new(inner)))),
163 "min" => Ok(Some(Expression::Min(Metadata::new(), Moo::new(inner)))),
164 "max" => Ok(Some(Expression::Max(Metadata::new(), Moo::new(inner)))),
165 _ => Err(FatalParseError::internal_error(
166 format!("Invalid operator: '{operator_str}'"),
167 Some(operator_node.range()),
168 )),
169 }
170}
171
172fn parse_all_diff_comparison(
173 ctx: &mut ParseContext,
174 node: &Node,
175) -> Result<Option<Expression>, FatalParseError> {
176 let Some(inner) = parse_expression(ctx, field!(node, "arg"))? else {
177 return Ok(None);
178 };
179
180 Ok(Some(Expression::AllDiff(Metadata::new(), Moo::new(inner))))
181}
182
183fn parse_unary_expression(
184 ctx: &mut ParseContext,
185 node: &Node,
186) -> Result<Option<Expression>, FatalParseError> {
187 let Some(inner) = parse_expression(ctx, field!(node, "expression"))? else {
188 return Ok(None);
189 };
190 match node.kind() {
191 "negative_expr" => Ok(Some(Expression::Neg(Metadata::new(), Moo::new(inner)))),
192 "abs_value" => Ok(Some(Expression::Abs(Metadata::new(), Moo::new(inner)))),
193 "not_expr" => Ok(Some(Expression::Not(Metadata::new(), Moo::new(inner)))),
194 "toInt_expr" => Ok(Some(Expression::ToInt(Metadata::new(), Moo::new(inner)))),
195 "sub_bool_expr" | "sub_arith_expr" => Ok(Some(inner)),
196 _ => Err(FatalParseError::internal_error(
197 format!("Unrecognised unary operation: '{}'", node.kind()),
198 Some(node.range()),
199 )),
200 }
201}
202
203pub fn parse_binary_expression(
204 ctx: &mut ParseContext,
205 node: &Node,
206) -> Result<Option<Expression>, FatalParseError> {
207 let mut parse_subexpr = |expr: Node| parse_expression(ctx, expr);
208
209 let Some(left) = parse_subexpr(field!(node, "left"))? else {
210 return Ok(None);
211 };
212 let Some(right) = parse_subexpr(field!(node, "right"))? else {
213 return Ok(None);
214 };
215
216 let op_node = field!(node, "operator");
217 let op_str = &ctx.source_code[op_node.start_byte()..op_node.end_byte()];
218
219 let mut description = format!("Operator '{op_str}'");
220 let expr = match op_str {
221 "+" => Ok(Some(Expression::Sum(
226 Metadata::new(),
227 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
228 ))),
229 "-" => Ok(Some(Expression::Minus(
230 Metadata::new(),
231 Moo::new(left),
232 Moo::new(right),
233 ))),
234 "*" => Ok(Some(Expression::Product(
235 Metadata::new(),
236 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
237 ))),
238 "/\\" => Ok(Some(Expression::And(
239 Metadata::new(),
240 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
241 ))),
242 "\\/" => Ok(Some(Expression::Or(
243 Metadata::new(),
244 Moo::new(matrix_expr![left, right; domain_int!(1..)]),
245 ))),
246 "**" => Ok(Some(Expression::UnsafePow(
247 Metadata::new(),
248 Moo::new(left),
249 Moo::new(right),
250 ))),
251 "/" => {
252 Ok(Some(Expression::UnsafeDiv(
254 Metadata::new(),
255 Moo::new(left),
256 Moo::new(right),
257 )))
258 }
259 "%" => {
260 Ok(Some(Expression::UnsafeMod(
262 Metadata::new(),
263 Moo::new(left),
264 Moo::new(right),
265 )))
266 }
267 "=" => Ok(Some(Expression::Eq(
268 Metadata::new(),
269 Moo::new(left),
270 Moo::new(right),
271 ))),
272 "!=" => Ok(Some(Expression::Neq(
273 Metadata::new(),
274 Moo::new(left),
275 Moo::new(right),
276 ))),
277 "<=" => Ok(Some(Expression::Leq(
278 Metadata::new(),
279 Moo::new(left),
280 Moo::new(right),
281 ))),
282 ">=" => Ok(Some(Expression::Geq(
283 Metadata::new(),
284 Moo::new(left),
285 Moo::new(right),
286 ))),
287 "<" => Ok(Some(Expression::Lt(
288 Metadata::new(),
289 Moo::new(left),
290 Moo::new(right),
291 ))),
292 ">" => Ok(Some(Expression::Gt(
293 Metadata::new(),
294 Moo::new(left),
295 Moo::new(right),
296 ))),
297 "->" => Ok(Some(Expression::Imply(
298 Metadata::new(),
299 Moo::new(left),
300 Moo::new(right),
301 ))),
302 "<->" => Ok(Some(Expression::Iff(
303 Metadata::new(),
304 Moo::new(left),
305 Moo::new(right),
306 ))),
307 "<lex" => Ok(Some(Expression::LexLt(
308 Metadata::new(),
309 Moo::new(left),
310 Moo::new(right),
311 ))),
312 ">lex" => Ok(Some(Expression::LexGt(
313 Metadata::new(),
314 Moo::new(left),
315 Moo::new(right),
316 ))),
317 "<=lex" => Ok(Some(Expression::LexLeq(
318 Metadata::new(),
319 Moo::new(left),
320 Moo::new(right),
321 ))),
322 ">=lex" => Ok(Some(Expression::LexGeq(
323 Metadata::new(),
324 Moo::new(left),
325 Moo::new(right),
326 ))),
327 "in" => Ok(Some(Expression::In(
328 Metadata::new(),
329 Moo::new(left),
330 Moo::new(right),
331 ))),
332 "subset" => Ok(Some(Expression::Subset(
333 Metadata::new(),
334 Moo::new(left),
335 Moo::new(right),
336 ))),
337 "subsetEq" => Ok(Some(Expression::SubsetEq(
338 Metadata::new(),
339 Moo::new(left),
340 Moo::new(right),
341 ))),
342 "supset" => Ok(Some(Expression::Supset(
343 Metadata::new(),
344 Moo::new(left),
345 Moo::new(right),
346 ))),
347 "supsetEq" => Ok(Some(Expression::SupsetEq(
348 Metadata::new(),
349 Moo::new(left),
350 Moo::new(right),
351 ))),
352 "union" => {
353 description = "set union: combines the elements from both operands".to_string();
354 Ok(Some(Expression::Union(
355 Metadata::new(),
356 Moo::new(left),
357 Moo::new(right),
358 )))
359 }
360 "intersect" => {
361 description =
362 "set intersection: keeps only elements common to both operands".to_string();
363 Ok(Some(Expression::Intersect(
364 Metadata::new(),
365 Moo::new(left),
366 Moo::new(right),
367 )))
368 }
369 _ => Err(FatalParseError::internal_error(
370 format!("Invalid operator: '{op_str}'"),
371 Some(op_node.range()),
372 )),
373 };
374
375 if expr.is_ok() {
376 let hover = HoverInfo {
377 description,
378 kind: Some(SymbolKind::Function),
379 ty: None,
380 decl_span: None,
381 };
382 span_with_hover(&op_node, ctx.source_code, ctx.source_map, hover);
383 }
384
385 expr
386}