1
use crate::errors::EssenceParseError;
2
use crate::parser::atom::parse_atom;
3
use crate::parser::comprehension::parse_quantifier_or_aggregate_expr;
4
use crate::{field, named_child};
5
use conjure_cp_core::ast::{Expression, Metadata, Moo, SymbolTable};
6
use conjure_cp_core::{domain_int, matrix_expr, range};
7
use std::cell::RefCell;
8
use std::rc::Rc;
9
use tree_sitter::Node;
10

            
11
/// Parse an Essence expression into its Conjure AST representation.
12
613
pub fn parse_expression(
13
613
    node: Node,
14
613
    source_code: &str,
15
613
    root: &Node,
16
613
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
17
613
) -> Result<Expression, EssenceParseError> {
18
613
    match node.kind() {
19
613
        "atom" => parse_atom(&node, source_code, root, symbols_ptr),
20
588
        "bool_expr" => parse_boolean_expression(&node, source_code, root, symbols_ptr),
21
453
        "arithmetic_expr" => parse_arithmetic_expression(&node, source_code, root, symbols_ptr),
22
86
        "comparison_expr" => parse_binary_expression(&node, source_code, root, symbols_ptr),
23
        "dominance_relation" => parse_dominance_relation(&node, source_code, root, symbols_ptr),
24
        "ERROR" => Err(EssenceParseError::syntax_error(
25
            format!(
26
                "'{}' is not a valid expression",
27
                &source_code[node.start_byte()..node.end_byte()]
28
            ),
29
            Some(node.range()),
30
        )),
31
        _ => Err(EssenceParseError::syntax_error(
32
            format!("Unknown expression kind: '{}'", node.kind()),
33
            Some(node.range()),
34
        )),
35
    }
36
613
}
37

            
38
fn parse_dominance_relation(
39
    node: &Node,
40
    source_code: &str,
41
    root: &Node,
42
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
43
) -> Result<Expression, EssenceParseError> {
44
    if root.kind() == "dominance_relation" {
45
        return Err(EssenceParseError::syntax_error(
46
            "Nested dominance relations are not allowed".to_string(),
47
            Some(node.range()),
48
        ));
49
    }
50

            
51
    // NB: In all other cases, we keep the root the same;
52
    // However, here we set the new root to `node` so downstream functions
53
    // know we are inside a dominance relation
54
    let inner = parse_expression(field!(node, "expression"), source_code, node, symbols_ptr)?;
55
    Ok(Expression::DominanceRelation(
56
        Metadata::new(),
57
        Moo::new(inner),
58
    ))
59
}
60

            
61
367
fn parse_arithmetic_expression(
62
367
    node: &Node,
63
367
    source_code: &str,
64
367
    root: &Node,
65
367
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
66
367
) -> Result<Expression, EssenceParseError> {
67
367
    let inner = named_child!(node);
68
367
    match inner.kind() {
69
367
        "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
70
76
        "negative_expr" | "abs_value" | "sub_arith_expr" | "toInt_expr" => {
71
35
            parse_unary_expression(&inner, source_code, root, symbols_ptr)
72
        }
73
41
        "exponent" | "product_expr" | "sum_expr" => {
74
41
            parse_binary_expression(&inner, source_code, root, symbols_ptr)
75
        }
76
        "list_combining_expr_arith" => {
77
            parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
78
        }
79
        "aggregate_expr" => {
80
            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
81
        }
82
        _ => Err(EssenceParseError::syntax_error(
83
            format!("Expected arithmetic expression, found: {}", inner.kind()),
84
            Some(inner.range()),
85
        )),
86
    }
87
367
}
88

            
89
135
fn parse_boolean_expression(
90
135
    node: &Node,
91
135
    source_code: &str,
92
135
    root: &Node,
93
135
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
94
135
) -> Result<Expression, EssenceParseError> {
95
135
    let inner = named_child!(node);
96
135
    match inner.kind() {
97
135
        "atom" => parse_atom(&inner, source_code, root, symbols_ptr),
98
85
        "not_expr" | "sub_bool_expr" => {
99
40
            parse_unary_expression(&inner, source_code, root, symbols_ptr)
100
        }
101
45
        "and_expr" | "or_expr" | "implication" | "iff_expr" | "set_operation_bool" => {
102
45
            parse_binary_expression(&inner, source_code, root, symbols_ptr)
103
        }
104
        "list_combining_expr_bool" => {
105
            parse_list_combining_expression(&inner, source_code, root, symbols_ptr)
106
        }
107
        "quantifier_expr" => {
108
            parse_quantifier_or_aggregate_expr(&inner, source_code, root, symbols_ptr)
109
        }
110
        _ => Err(EssenceParseError::syntax_error(
111
            format!("Expected boolean expression, found '{}'", inner.kind()),
112
            Some(inner.range()),
113
        )),
114
    }
115
135
}
116

            
117
fn parse_list_combining_expression(
118
    node: &Node,
119
    source_code: &str,
120
    root: &Node,
121
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
122
) -> Result<Expression, EssenceParseError> {
123
    let operator_node = field!(node, "operator");
124
    let operator_str = &source_code[operator_node.start_byte()..operator_node.end_byte()];
125

            
126
    let inner = parse_atom(&field!(node, "arg"), source_code, root, symbols_ptr)?;
127

            
128
    match operator_str {
129
        "and" => Ok(Expression::And(Metadata::new(), Moo::new(inner))),
130
        "or" => Ok(Expression::Or(Metadata::new(), Moo::new(inner))),
131
        "sum" => Ok(Expression::Sum(Metadata::new(), Moo::new(inner))),
132
        "product" => Ok(Expression::Product(Metadata::new(), Moo::new(inner))),
133
        "min" => Ok(Expression::Min(Metadata::new(), Moo::new(inner))),
134
        "max" => Ok(Expression::Max(Metadata::new(), Moo::new(inner))),
135
        "allDiff" => Ok(Expression::AllDiff(Metadata::new(), Moo::new(inner))),
136
        _ => Err(EssenceParseError::syntax_error(
137
            format!("Invalid operator: '{operator_str}'"),
138
            Some(operator_node.range()),
139
        )),
140
    }
141
}
142

            
143
75
fn parse_unary_expression(
144
75
    node: &Node,
145
75
    source_code: &str,
146
75
    root: &Node,
147
75
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
148
75
) -> Result<Expression, EssenceParseError> {
149
75
    let inner = parse_expression(field!(node, "expression"), source_code, root, symbols_ptr)?;
150
75
    match node.kind() {
151
75
        "negative_expr" => Ok(Expression::Neg(Metadata::new(), Moo::new(inner))),
152
65
        "abs_value" => Ok(Expression::Abs(Metadata::new(), Moo::new(inner))),
153
65
        "not_expr" => Ok(Expression::Not(Metadata::new(), Moo::new(inner))),
154
40
        "toInt_expr" => Ok(Expression::ToInt(Metadata::new(), Moo::new(inner))),
155
40
        "sub_bool_expr" | "sub_arith_expr" => Ok(inner),
156
        _ => Err(EssenceParseError::syntax_error(
157
            format!("Unrecognised unary operation: '{}'", node.kind()),
158
            Some(node.range()),
159
        )),
160
    }
161
75
}
162

            
163
172
pub fn parse_binary_expression(
164
172
    node: &Node,
165
172
    source_code: &str,
166
172
    root: &Node,
167
172
    symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
168
172
) -> Result<Expression, EssenceParseError> {
169
344
    let parse_subexpr = |expr: Node| parse_expression(expr, source_code, root, symbols_ptr.clone());
170

            
171
172
    let left = parse_subexpr(field!(node, "left"))?;
172
172
    let right = parse_subexpr(field!(node, "right"))?;
173

            
174
172
    let op_node = field!(node, "operator");
175
172
    let op_str = &source_code[op_node.start_byte()..op_node.end_byte()];
176

            
177
172
    match op_str {
178
        // NB: We are deliberately setting the index domain to 1.., not 1..2.
179
        // Semantically, this means "a list that can grow/shrink arbitrarily".
180
        // This is expected by rules which will modify the terms of the sum expression
181
        // (e.g. by partially evaluating them).
182
172
        "+" => Ok(Expression::Sum(
183
20
            Metadata::new(),
184
20
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
185
20
        )),
186
152
        "-" => Ok(Expression::Minus(
187
10
            Metadata::new(),
188
10
            Moo::new(left),
189
10
            Moo::new(right),
190
10
        )),
191
142
        "*" => Ok(Expression::Product(
192
10
            Metadata::new(),
193
10
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
194
10
        )),
195
132
        "/\\" => Ok(Expression::And(
196
20
            Metadata::new(),
197
20
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
198
20
        )),
199
112
        "\\/" => Ok(Expression::Or(
200
5
            Metadata::new(),
201
5
            Moo::new(matrix_expr![left, right; domain_int!(1..)]),
202
5
        )),
203
107
        "**" => Ok(Expression::UnsafePow(
204
            Metadata::new(),
205
            Moo::new(left),
206
            Moo::new(right),
207
        )),
208
107
        "/" => {
209
            //TODO: add checks for if division is safe or not
210
1
            Ok(Expression::UnsafeDiv(
211
1
                Metadata::new(),
212
1
                Moo::new(left),
213
1
                Moo::new(right),
214
1
            ))
215
        }
216
106
        "%" => {
217
            //TODO: add checks for if mod is safe or not
218
            Ok(Expression::UnsafeMod(
219
                Metadata::new(),
220
                Moo::new(left),
221
                Moo::new(right),
222
            ))
223
        }
224
106
        "=" => Ok(Expression::Eq(
225
31
            Metadata::new(),
226
31
            Moo::new(left),
227
31
            Moo::new(right),
228
31
        )),
229
75
        "!=" => Ok(Expression::Neq(
230
10
            Metadata::new(),
231
10
            Moo::new(left),
232
10
            Moo::new(right),
233
10
        )),
234
65
        "<=" => Ok(Expression::Leq(
235
20
            Metadata::new(),
236
20
            Moo::new(left),
237
20
            Moo::new(right),
238
20
        )),
239
45
        ">=" => Ok(Expression::Geq(
240
20
            Metadata::new(),
241
20
            Moo::new(left),
242
20
            Moo::new(right),
243
20
        )),
244
25
        "<" => Ok(Expression::Lt(
245
5
            Metadata::new(),
246
5
            Moo::new(left),
247
5
            Moo::new(right),
248
5
        )),
249
20
        ">" => Ok(Expression::Gt(
250
            Metadata::new(),
251
            Moo::new(left),
252
            Moo::new(right),
253
        )),
254
20
        "->" => Ok(Expression::Imply(
255
20
            Metadata::new(),
256
20
            Moo::new(left),
257
20
            Moo::new(right),
258
20
        )),
259
        "<->" => Ok(Expression::Iff(
260
            Metadata::new(),
261
            Moo::new(left),
262
            Moo::new(right),
263
        )),
264
        "<lex" => Ok(Expression::LexLt(
265
            Metadata::new(),
266
            Moo::new(left),
267
            Moo::new(right),
268
        )),
269
        ">lex" => Ok(Expression::LexGt(
270
            Metadata::new(),
271
            Moo::new(left),
272
            Moo::new(right),
273
        )),
274
        "<=lex" => Ok(Expression::LexLeq(
275
            Metadata::new(),
276
            Moo::new(left),
277
            Moo::new(right),
278
        )),
279
        ">=lex" => Ok(Expression::LexGeq(
280
            Metadata::new(),
281
            Moo::new(left),
282
            Moo::new(right),
283
        )),
284
        "in" => Ok(Expression::In(
285
            Metadata::new(),
286
            Moo::new(left),
287
            Moo::new(right),
288
        )),
289
        "subset" => Ok(Expression::Subset(
290
            Metadata::new(),
291
            Moo::new(left),
292
            Moo::new(right),
293
        )),
294
        "subsetEq" => Ok(Expression::SubsetEq(
295
            Metadata::new(),
296
            Moo::new(left),
297
            Moo::new(right),
298
        )),
299
        "supset" => Ok(Expression::Supset(
300
            Metadata::new(),
301
            Moo::new(left),
302
            Moo::new(right),
303
        )),
304
        "supsetEq" => Ok(Expression::SupsetEq(
305
            Metadata::new(),
306
            Moo::new(left),
307
            Moo::new(right),
308
        )),
309
        "union" => Ok(Expression::Union(
310
            Metadata::new(),
311
            Moo::new(left),
312
            Moo::new(right),
313
        )),
314
        "intersect" => Ok(Expression::Intersect(
315
            Metadata::new(),
316
            Moo::new(left),
317
            Moo::new(right),
318
        )),
319
        _ => Err(EssenceParseError::syntax_error(
320
            format!("Invalid operator: '{op_str}'"),
321
            Some(op_node.range()),
322
        )),
323
    }
324
172
}