1
use crate::errors::EssenceParseError;
2
use crate::parser::atom::parse_atom;
3
use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
4
use crate::{field, named_child};
5
use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTablePtr};
6
use conjure_cp_core::{domain_int, matrix_expr, range};
7
use tree_sitter::Node;
8

            
9
/// Parse an Essence expression into its Conjure AST representation.
10
4429
pub fn parse_expression(
11
4429
    node: Node,
12
4429
    source_code: &str,
13
4429
    root: &Node,
14
4429
    symbols_ptr: Option<SymbolTablePtr>,
15
4429
) -> Result<Expression, EssenceParseError> {
16
4429
    match node.kind() {
17
4429
        "atom" => parse_atom(&node, source_code, root, symbols_ptr),
18
1089
        "bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr),
19
491
        "arithmetic_expr" => parse_arithmetic_expression(&node, source_code, root, symbols_ptr),
20
357
        "comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr),
21
        "dominance_relation" => parse_dominance_relation(&node, source_code, root, symbols_ptr),
22
        "ERROR" => Err(EssenceParseError::syntax_error(
23
            format!(
24
                "'{}' is not a valid expression",
25
                &source_code[node.start_byte()..node.end_byte()]
26
            ),
27
            Some(node.range()),
28
        )),
29
        _ => Err(EssenceParseError::syntax_error(
30
            format!("Unknown expression kind: '{}'", node.kind()),
31
            Some(node.range()),
32
        )),
33
    }
34
4429
}
35

            
36
fn parse_dominance_relation(
37
    node: &Node,
38
    source_code: &str,
39
    root: &Node,
40
    symbols_ptr: Option<SymbolTablePtr>,
41
) -> Result<Expression, EssenceParseError> {
42
    if root.kind() == "dominance_relation" {
43
        return Err(EssenceParseError::syntax_error(
44
            "Nested dominance relations are not allowed".to_string(),
45
            Some(node.range()),
46
        ));
47
    }
48

            
49
    // NB: In all other cases, we keep the root the same;
50
    // However, here we set the new root to `node` so downstream functions
51
    // know we are inside a dominance relation
52
    let inner = parse_expression(field!(node, "expression"), source_code, node, symbols_ptr)?;
53
    Ok(Expression::DominanceRelation(
54
        Metadata::new(),
55
        Moo::new(inner),
56
    ))
57
}
58

            
59
134
fn parse_arithmetic_expression(
60
134
    node: &Node,
61
134
    source_code: &str,
62
134
    root: &Node,
63
134
    symbols_ptr: Option<SymbolTablePtr>,
64
134
) -> Result<Expression, EssenceParseError> {
65
134
    let inner = named_child!(node);
66
134
    match inner.kind() {
67
134
        "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
68
134
        "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
69
46
            parse_unary_expression(&inner, source_code, root, symbols_ptr)
70
        }
71
88
        "exponent" | "product_expr" | "sum_expr" => {
72
83
            parse_binary_expression(&inner, source_code, root, symbols_ptr)
73
        }
74
5
        "list_combining_expr_arith" => {
75
5
            parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
76
        }
77
        "aggregate_expr" => {
78
            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
79
        }
80
        _ => Err(EssenceParseError::syntax_error(
81
            format!("Expected arithmetic expression, found: {}", inner.kind()),
82
            Some(inner.range()),
83
        )),
84
    }
85
134
}
86

            
87
598
fn parse_boolean_expression(
88
598
    node: &Node,
89
598
    source_code: &str,
90
598
    root: &Node,
91
598
    symbols_ptr: Option<SymbolTablePtr>,
92
598
) -> Result<Expression, EssenceParseError> {
93
598
    let inner = named_child!(node);
94
598
    match inner.kind() {
95
598
        "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
96
598
        "not_expr" | "sub_bool_expr" => {
97
244
            parse_unary_expression(&inner, source_code, root, symbols_ptr)
98
        }
99
354
        "and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
100
348
            parse_binary_expression(&inner, source_code, root, symbols_ptr)
101
        }
102
6
        "list_combining_expr_bool" => {
103
6
            parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
104
        }
105
        "quantifier_expr" => {
106
            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
107
        }
108
        _ => Err(EssenceParseError::syntax_error(
109
            format!("Expected boolean expression, found '{}'", inner.kind()),
110
            Some(inner.range()),
111
        )),
112
    }
113
598
}
114

            
115
11
fn parse_list_combining_expression(
116
11
    node: &Node,
117
11
    source_code: &str,
118
11
    root: &Node,
119
11
    symbols_ptr: Option<SymbolTablePtr>,
120
11
) -> Result<Expression, EssenceParseError> {
121
11
    let operator_node = field!(node, "operator");
122
11
    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
123

            
124
11
    let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr)?;
125

            
126
11
    match operator_str {
127
11
        "and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
128
6
        "or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
129
6
        "sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
130
1
        "product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
131
1
        "min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
132
1
        "max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
133
1
        "allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
134
        _ => Err(EssenceParseError::syntax_error(
135
            format!("Invalid operator: '{operator_str}'"),
136
            Some(operator_node.range()),
137
        )),
138
    }
139
11
}
140

            
141
290
fn parse_unary_expression(
142
290
    node: &Node,
143
290
    source_code: &str,
144
290
    root: &Node,
145
290
    symbols_ptr: Option<SymbolTablePtr>,
146
290
) -> Result<Expression, EssenceParseError> {
147
290
    let inner = parse_expression(field!(node, "expression"), source_code, root, symbols_ptr)?;
148
290
    match node.kind() {
149
290
        "negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
150
280
        "abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
151
269
        "not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
152
229
        "toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
153
219
        "sub_bool_expr" | "sub_arith_expr" => Ok(inner),
154
        _ => Err(EssenceParseError::syntax_error(
155
            format!("Unrecognised unary operation: '{}'", node.kind()),
156
            Some(node.range()),
157
        )),
158
    }
159
290
}
160

            
161
854
pub fn parse_binary_expression(
162
854
    node: &Node,
163
854
    source_code: &str,
164
854
    root: &Node,
165
854
    symbols_ptr: Option<SymbolTablePtr>,
166
854
) -> Result<Expression, EssenceParseError> {
167
1697
    let parse_subexpr = |expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone());
168

            
169
854
    let left = parse_subexpr(field!(node, "left"))?;
170
843
    let right = parse_subexpr(field!(node, "right"))?;
171

            
172
821
    let op_node = field!(node, "operator");
173
821
    let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
174

            
175
821
    match op_str {
176
        // NB: We are deliberately setting the index domain to 1.., not 1..2.
177
        // Semantically, this means "a list that can grow/shrink arbitrarily".
178
        // This is expected by rules which will modify the terms of the sum expression
179
        // (e.g. by partially evaluating them).
180
821
        "+" => Ok(Expression::Sum(
181
56
            Metadata::new(),
182
56
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
183
56
        )),
184
765
        "-" => Ok(Expression::Minus(
185
10
            Metadata::new(),
186
10
            Moo::new(left),
187
10
            Moo::new(right),
188
10
        )),
189
755
        "*" => Ok(Expression::Product(
190
15
            Metadata::new(),
191
15
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
192
15
        )),
193
740
        "/\\" => Ok(Expression::And(
194
20
            Metadata::new(),
195
20
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
196
20
        )),
197
720
        "\\/" => Ok(Expression::Or(
198
5
            Metadata::new(),
199
5
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
200
5
        )),
201
715
        "**" => Ok(Expression::UnsafePow(
202
            Metadata::new(),
203
            Moo::new(left),
204
            Moo::new(right),
205
        )),
206
715
        "/" => {
207
            //TODO: add checks for if division is safe or not
208
2
            Ok(Expression::UnsafeDiv(
209
2
                Metadata::new(),
210
2
                Moo::new(left),
211
2
                Moo::new(right),
212
2
            ))
213
        }
214
713
        "%" => {
215
            //TODO: add checks for if mod is safe or not
216
            Ok(Expression::UnsafeMod(
217
                Metadata::new(),
218
                Moo::new(left),
219
                Moo::new(right),
220
            ))
221
        }
222
713
        "=" => Ok(Expression::Eq(
223
225
            Metadata::new(),
224
225
            Moo::new(left),
225
225
            Moo::new(right),
226
225
        )),
227
488
        "!=" => Ok(Expression::Neq(
228
32
            Metadata::new(),
229
32
            Moo::new(left),
230
32
            Moo::new(right),
231
32
        )),
232
456
        "<=" => Ok(Expression::Leq(
233
20
            Metadata::new(),
234
20
            Moo::new(left),
235
20
            Moo::new(right),
236
20
        )),
237
436
        ">=" => Ok(Expression::Geq(
238
17
            Metadata::new(),
239
17
            Moo::new(left),
240
17
            Moo::new(right),
241
17
        )),
242
419
        "<" => Ok(Expression::Lt(
243
8
            Metadata::new(),
244
8
            Moo::new(left),
245
8
            Moo::new(right),
246
8
        )),
247
411
        ">" => Ok(Expression::Gt(
248
33
            Metadata::new(),
249
33
            Moo::new(left),
250
33
            Moo::new(right),
251
33
        )),
252
378
        "->" => Ok(Expression::Imply(
253
25
            Metadata::new(),
254
25
            Moo::new(left),
255
25
            Moo::new(right),
256
25
        )),
257
353
        "<->" => Ok(Expression::Iff(
258
5
            Metadata::new(),
259
5
            Moo::new(left),
260
5
            Moo::new(right),
261
5
        )),
262
348
        "<lex" => Ok(Expression::LexLt(
263
            Metadata::new(),
264
            Moo::new(left),
265
            Moo::new(right),
266
        )),
267
348
        ">lex" => Ok(Expression::LexGt(
268
            Metadata::new(),
269
            Moo::new(left),
270
            Moo::new(right),
271
        )),
272
348
        "<=lex" => Ok(Expression::LexLeq(
273
            Metadata::new(),
274
            Moo::new(left),
275
            Moo::new(right),
276
        )),
277
348
        ">=lex" => Ok(Expression::LexGeq(
278
            Metadata::new(),
279
            Moo::new(left),
280
            Moo::new(right),
281
        )),
282
348
        "in" => Ok(Expression::In(
283
95
            Metadata::new(),
284
95
            Moo::new(left),
285
95
            Moo::new(right),
286
95
        )),
287
253
        "subset" => Ok(Expression::Subset(
288
55
            Metadata::new(),
289
55
            Moo::new(left),
290
55
            Moo::new(right),
291
55
        )),
292
198
        "subsetEq" => Ok(Expression::SubsetEq(
293
44
            Metadata::new(),
294
44
            Moo::new(left),
295
44
            Moo::new(right),
296
44
        )),
297
154
        "supset" => Ok(Expression::Supset(
298
44
            Metadata::new(),
299
44
            Moo::new(left),
300
44
            Moo::new(right),
301
44
        )),
302
110
        "supsetEq" => Ok(Expression::SupsetEq(
303
44
            Metadata::new(),
304
44
            Moo::new(left),
305
44
            Moo::new(right),
306
44
        )),
307
66
        "union" => Ok(Expression::Union(
308
33
            Metadata::new(),
309
33
            Moo::new(left),
310
33
            Moo::new(right),
311
33
        )),
312
33
        "intersect" => Ok(Expression::Intersect(
313
33
            Metadata::new(),
314
33
            Moo::new(left),
315
33
            Moo::new(right),
316
33
        )),
317
        _ => Err(EssenceParseError::syntax_error(
318
            format!("Invalid operator: '{op_str}'"),
319
            Some(op_node.range()),
320
        )),
321
    }
322
854
}