1
use crate::errors::{FatalParseError, RecoverableParseError};
2
use crate::parser::atom::parse_atom;
3
use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
4
use crate::{field, named_child};
5
use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTablePtr};
6
use conjure_cp_core::{domain_int, matrix_expr, range};
7
use tree_sitter::Node;
8

            
9
/// Parse an Essence expression into its Conjure AST representation.
10
4399
pub fn parse_expression(
11
4399
    node: Node,
12
4399
    source_code: &str,
13
9254
    root: &Node,
14
9254
    symbols_ptr: Option<SymbolTablePtr>,
15
9254
    errors: &mut Vec<RecoverableParseError>,
16
9254
) -> Result<Expression, FatalParseError> {
17
9254
    match node.kind() {
18
9254
        "atom" => parse_atom(&node, source_code, root, symbols_ptr, errors),
19
2340
        "bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr, errors),
20
1090
        "arithmetic_expr" => {
21
500
            parse_arithmetic_expression(&node, source_code, root, symbols_ptr, errors)
22
        }
23
355
        "comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr, errors),
24
        "dominance_relation" => {
25
            parse_dominance_relation(&node, source_code, root, symbols_ptr, errors)
26
        }
27
        "ERROR" => {
28
4855
            errors.push(RecoverableParseError::new(
29
                format!(
30
                    "'{}' is not a valid expression",
31
                    &source_code[node.start_byte()..node.end_byte()]
32
                ),
33
                Some(node.range()),
34
            ));
35
            // Return a placeholder - actual error is in the errors vector
36
            // TODO: figure out how to return when recoverable error is found
37
            Ok(Expression::Atomic(
38
                Metadata::new(),
39
                conjure_cp_core::ast::Atom::Literal(conjure_cp_core::ast::Literal::Bool(false)),
40
            ))
41
        }
42
        _ => {
43
            errors.push(RecoverableParseError::new(
44
                format!("Unknown expression kind: '{}'", node.kind()),
45
                Some(node.range()),
46
            ));
47
            // Return a placeholder
48
            Ok(Expression::Atomic(
49
                Metadata::new(),
50
                conjure_cp_core::ast::Atom::Literal(conjure_cp_core::ast::Literal::Bool(false)),
51
            ))
52
        }
53
    }
54
4399
}
55

            
56
fn parse_dominance_relation(
57
    node: &Node,
58
    source_code: &str,
59
    root: &Node,
60
    symbols_ptr: Option<SymbolTablePtr>,
61
    errors: &mut Vec<RecoverableParseError>,
62
235
) -> Result<Expression, FatalParseError> {
63
235
    if root.kind() == "dominance_relation" {
64
235
        return Err(FatalParseError::syntax_error(
65
235
            "Nested dominance relations are not allowed".to_string(),
66
235
            Some(node.range()),
67
235
        ));
68
235
    }
69
235

            
70
    // NB: In all other cases, we keep the root the same;
71
    // However, here we set the new root to `node` so downstream functions
72
    // know we are inside a dominance relation
73
7
    let inner = parse_expression(
74
        field!(node, "expression"),
75
        source_code,
76
        node,
77
        symbols_ptr,
78
        errors,
79
    )?;
80
235
    Ok(Expression::DominanceRelation(
81
        Metadata::new(),
82
663
        Moo::new(inner),
83
663
    ))
84
663
}
85
663

            
86
797
fn parse_arithmetic_expression(
87
797
    node: &Node,
88
797
    source_code: &str,
89
797
    root: &Node,
90
539
    symbols_ptr: Option<SymbolTablePtr>,
91
509
    errors: &mut Vec<RecoverableParseError>,
92
134
) -> Result<Expression, FatalParseError> {
93
164
    let inner = named_child!(node);
94
156
    match inner.kind() {
95
134
        "atom" => parse_atom(&inner, source_code, root, symbols_ptr, errors),
96
134
        "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
97
46
            parse_unary_expression(&inner, source_code, root, symbols_ptr, errors)
98
        }
99
88
        "exponent" | "product_expr" | "sum_expr" => {
100
746
            parse_binary_expression(&inner, source_code, root, symbols_ptr, errors)
101
        }
102
20
        "list_combining_expr_arith" => {
103
20
            parse_list_combining_expression(&inner, source_code, root, symbols_ptr, errors)
104
15
        }
105
15
        "aggregate_expr" => {
106
15
            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr, errors)
107
15
        }
108
        _ => Err(FatalParseError::syntax_error(
109
15
            format!("Expected arithmetic expression, found: {}", inner.kind()),
110
            Some(inner.range()),
111
        )),
112
    }
113
149
}
114
15

            
115
595
fn parse_boolean_expression(
116
595
    node: &Node,
117
588
    source_code: &str,
118
588
    root: &Node,
119
588
    symbols_ptr: Option<SymbolTablePtr>,
120
588
    errors: &mut Vec<RecoverableParseError>,
121
587
) -> Result<Expression, FatalParseError> {
122
587
    let inner = named_child!(node);
123
587
    match inner.kind() {
124
587
        "atom" => parse_atom(&inner, source_code, root, symbols_ptr, errors),
125
587
        "not_expr" | "sub_bool_expr" => {
126
259
            parse_unary_expression(&inner, source_code, root, symbols_ptr, errors)
127
        }
128
661
        "and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
129
655
            parse_binary_expression(&inner, source_code, root, symbols_ptr, errors)
130
318
        }
131
324
        "list_combining_expr_bool" => {
132
324
            parse_list_combining_expression(&inner, source_code, root, symbols_ptr, errors)
133
        }
134
        "quantifier_expr" => {
135
318
            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr, errors)
136
318
        }
137
304
        _ => Err(FatalParseError::syntax_error(
138
293
            format!("Expected boolean expression, found '{}'", inner.kind()),
139
237
            Some(inner.range()),
140
223
        )),
141
    }
142
587
}
143

            
144
11
fn parse_list_combining_expression(
145
11
    node: &Node,
146
329
    source_code: &str,
147
11
    root: &Node,
148
986
    symbols_ptr: Option<SymbolTablePtr>,
149
986
    errors: &mut Vec<RecoverableParseError>,
150
986
) -> Result<Expression, FatalParseError> {
151
986
    let operator_node = field!(node, "operator");
152
1961
    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
153

            
154
986
    let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr, errors)?;
155

            
156
11
    match operator_str {
157
986
        "and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
158
28
        "or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
159
6
        "sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
160
1
        "product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
161
954
        "min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
162
954
        "max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
163
1
        "allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
164
953
        _ => Err(FatalParseError::syntax_error(
165
953
            format!("Invalid operator: '{operator_str}'"),
166
            Some(operator_node.range()),
167
        )),
168
    }
169
11
}
170
953

            
171
397
fn parse_unary_expression(
172
397
    node: &Node,
173
397
    source_code: &str,
174
1136
    root: &Node,
175
326
    symbols_ptr: Option<SymbolTablePtr>,
176
326
    errors: &mut Vec<RecoverableParseError>,
177
326
) -> Result<Expression, FatalParseError> {
178
326
    let inner = parse_expression(
179
1100
        field!(node, "expression"),
180
311
        source_code,
181
311
        root,
182
311
        symbols_ptr,
183
1079
        errors,
184
28
    )?;
185
318
    match node.kind() {
186
318
        "negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
187
1041
        "abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
188
276
        "not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
189
236
        "toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
190
226
        "sub_bool_expr" | "sub_arith_expr" => Ok(inner),
191
754
        _ => Err(FatalParseError::syntax_error(
192
            format!("Unrecognised unary operation: '{}'", node.kind()),
193
            Some(node.range()),
194
        )),
195
    }
196
1044
}
197

            
198
845
pub fn parse_binary_expression(
199
845
    node: &Node,
200
845
    source_code: &str,
201
845
    root: &Node,
202
845
    symbols_ptr: Option<SymbolTablePtr>,
203
841
    errors: &mut Vec<RecoverableParseError>,
204
1591
) -> Result<Expression, FatalParseError> {
205
841
    let mut parse_subexpr =
206
1682
        |expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone(), errors);
207

            
208
841
    let left = parse_subexpr(field!(node, "left"))?;
209
841
    let right = parse_subexpr(field!(node, "right"))?;
210

            
211
819
    let op_node = field!(node, "operator");
212
1569
    let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
213
260

            
214
1079
    match op_str {
215
        // NB: We are deliberately setting the index domain to 1.., not 1..2.
216
        // Semantically, this means "a list that can grow/shrink arbitrarily".
217
        // This is expected by rules which will modify the terms of the sum expression
218
        // (e.g. by partially evaluating them).
219
833
        "+" => Ok(Expression::Sum(
220
70
            Metadata::new(),
221
70
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
222
532
        )),
223
791
        "-" => Ok(Expression::Minus(
224
38
            Metadata::new(),
225
38
            Moo::new(left),
226
38
            Moo::new(right),
227
458
        )),
228
776
        "*" => Ok(Expression::Product(
229
38
            Metadata::new(),
230
38
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
231
38
        )),
232
1163
        "/\\" => Ok(Expression::And(
233
28
            Metadata::new(),
234
28
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
235
28
        )),
236
726
        "\\/" => Ok(Expression::Or(
237
422
            Metadata::new(),
238
16
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
239
16
        )),
240
724
        "**" => Ok(Expression::UnsafePow(
241
11
            Metadata::new(),
242
406
            Moo::new(left),
243
35
            Moo::new(right),
244
35
        )),
245
748
        "/" => {
246
            //TODO: add checks for if division is safe or not
247
373
            Ok(Expression::UnsafeDiv(
248
9
                Metadata::new(),
249
9
                Moo::new(left),
250
9
                Moo::new(right),
251
9
            ))
252
364
        }
253
711
        "%" => {
254
            //TODO: add checks for if mod is safe or not
255
            Ok(Expression::UnsafeMod(
256
                Metadata::new(),
257
364
                Moo::new(left),
258
                Moo::new(right),
259
            ))
260
        }
261
711
        "=" => Ok(Expression::Eq(
262
589
            Metadata::new(),
263
225
            Moo::new(left),
264
225
            Moo::new(right),
265
225
        )),
266
486
        "!=" => Ok(Expression::Neq(
267
396
            Metadata::new(),
268
32
            Moo::new(left),
269
32
            Moo::new(right),
270
32
        )),
271
454
        "<=" => Ok(Expression::Leq(
272
384
            Metadata::new(),
273
131
            Moo::new(left),
274
131
            Moo::new(right),
275
131
        )),
276
545
        ">=" => Ok(Expression::Geq(
277
270
            Metadata::new(),
278
72
            Moo::new(left),
279
72
            Moo::new(right),
280
72
        )),
281
472
        "<" => Ok(Expression::Lt(
282
204
            Metadata::new(),
283
50
            Moo::new(left),
284
50
            Moo::new(right),
285
50
        )),
286
455
        ">" => Ok(Expression::Gt(
287
187
            Metadata::new(),
288
77
            Moo::new(left),
289
77
            Moo::new(right),
290
77
        )),
291
422
        "->" => Ok(Expression::Imply(
292
135
            Metadata::new(),
293
69
            Moo::new(left),
294
69
            Moo::new(right),
295
69
        )),
296
397
        "<->" => Ok(Expression::Iff(
297
71
            Metadata::new(),
298
38
            Moo::new(left),
299
38
            Moo::new(right),
300
38
        )),
301
381
        "<lex" => Ok(Expression::LexLt(
302
33
            Metadata::new(),
303
33
            Moo::new(left),
304
            Moo::new(right),
305
33
        )),
306
381
        ">lex" => Ok(Expression::LexGt(
307
33
            Metadata::new(),
308
33
            Moo::new(left),
309
33
            Moo::new(right),
310
33
        )),
311
381
        "<=lex" => Ok(Expression::LexLeq(
312
33
            Metadata::new(),
313
            Moo::new(left),
314
            Moo::new(right),
315
        )),
316
348
        ">=lex" => Ok(Expression::LexGeq(
317
            Metadata::new(),
318
            Moo::new(left),
319
            Moo::new(right),
320
953
        )),
321
1301
        "in" => Ok(Expression::In(
322
1048
            Metadata::new(),
323
1048
            Moo::new(left),
324
1048
            Moo::new(right),
325
1048
        )),
326
1206
        "subset" => Ok(Expression::Subset(
327
1008
            Metadata::new(),
328
1008
            Moo::new(left),
329
55
            Moo::new(right),
330
1008
        )),
331
1173
        "subsetEq" => Ok(Expression::SubsetEq(
332
44
            Metadata::new(),
333
44
            Moo::new(left),
334
44
            Moo::new(right),
335
44
        )),
336
154
        "supset" => Ok(Expression::Supset(
337
44
            Metadata::new(),
338
44
            Moo::new(left),
339
44
            Moo::new(right),
340
44
        )),
341
110
        "supsetEq" => Ok(Expression::SupsetEq(
342
44
            Metadata::new(),
343
44
            Moo::new(left),
344
44
            Moo::new(right),
345
44
        )),
346
66
        "union" => Ok(Expression::Union(
347
33
            Metadata::new(),
348
33
            Moo::new(left),
349
33
            Moo::new(right),
350
33
        )),
351
33
        "intersect" => Ok(Expression::Intersect(
352
33
            Metadata::new(),
353
33
            Moo::new(left),
354
33
            Moo::new(right),
355
33
        )),
356
        _ => Err(FatalParseError::syntax_error(
357
            format!("Invalid operator: '{op_str}'"),
358
            Some(op_node.range()),
359
        )),
360
    }
361
841
}