1
#![allow(clippy::legacy_numeric_constants)]
2
use conjure_core::ast::declaration::Declaration;
3
use conjure_core::error::Error;
4
use std::fs;
5
use std::rc::Rc;
6
use std::sync::{Arc, RwLock};
7
use tree_sitter::{Node, Parser, Tree};
8
use tree_sitter_essence::LANGUAGE;
9

            
10
use conjure_core::ast::{Atom, Domain, Expression, Literal, Name, Range, SymbolTable};
11

            
12
use crate::utils::conjure::EssenceParseError;
13
use conjure_core::context::Context;
14
use conjure_core::metadata::Metadata;
15
use conjure_core::Model;
16
use std::collections::{BTreeMap, BTreeSet};
17

            
18
712
pub fn parse_essence_file_native(
19
712
    path: &str,
20
712
    filename: &str,
21
712
    extension: &str,
22
712
    context: Arc<RwLock<Context<'static>>>,
23
712
) -> Result<Model, EssenceParseError> {
24
712
    let (tree, source_code) = get_tree(path, filename, extension);
25
712

            
26
712
    let mut model = Model::new_empty(context);
27
712
    let root_node = tree.root_node();
28
2896
    for statement in named_children(&root_node) {
29
2896
        match statement.kind() {
30
2896
            "single_line_comment" => {}
31
2360
            "find_statement_list" => {
32
1152
                let var_hashmap = parse_find_statement(statement, &source_code);
33
3048
                for (name, decision_variable) in var_hashmap {
34
1896
                    model
35
1896
                        .symbols_mut()
36
1896
                        .insert(Rc::new(Declaration::new_var(name, decision_variable)));
37
1896
                }
38
            }
39
1208
            "constraint_list" => {
40
712
                let mut constraint_vec: Vec<Expression> = Vec::new();
41
1352
                for constraint in named_children(&statement) {
42
1352
                    if constraint.kind() != "single_line_comment" {
43
1080
                        constraint_vec.push(parse_constraint(constraint, &source_code));
44
1080
                    }
45
                }
46
712
                model.add_constraints(constraint_vec);
47
            }
48
496
            "e_prime_label" => {}
49
48
            "letting_statement_list" => {
50
48
                let letting_vars = parse_letting_statement(statement, &source_code);
51
48
                model.symbols_mut().extend(letting_vars);
52
48
            }
53
            _ => {
54
                let kind = statement.kind();
55
                return Err(EssenceParseError::ParseError(Error::Parse(
56
                    format!("Unrecognized top level statement kind: {kind}").to_owned(),
57
                )));
58
            }
59
        }
60
    }
61
712
    Ok(model)
62
712
}
63

            
64
712
/// Reads `{path}/{filename}.{extension}` and parses it with the tree-sitter
/// Essence grammar.
///
/// Returns the syntax tree together with the source text it was built from;
/// callers slice that text using the byte offsets stored in the tree's nodes.
///
/// # Panics
///
/// Panics if the file cannot be read, if the Essence grammar cannot be loaded
/// into the parser, or if the parser produces no tree.
fn get_tree(path: &str, filename: &str, extension: &str) -> (Tree, String) {
    let file_path = format!("{path}/{filename}.{extension}");
    let source_code = fs::read_to_string(&file_path).unwrap_or_else(|err| {
        panic!("Failed to read the source code file {file_path}: {err}")
    });

    let mut parser = Parser::new();
    parser
        .set_language(&LANGUAGE.into())
        .expect("Failed to load the Essence grammar into the parser");

    // The parser only needs to read the text, so lend it instead of cloning
    // the whole file contents.
    let tree = parser.parse(&source_code, None).expect("Failed to parse");

    (tree, source_code)
}
76

            
77
1152
fn parse_find_statement(find_statement_list: Node, source_code: &str) -> BTreeMap<Name, Domain> {
78
1152
    let mut vars = BTreeMap::new();
79

            
80
1176
    for find_statement in named_children(&find_statement_list) {
81
1176
        let mut temp_symbols = BTreeSet::new();
82
1176

            
83
1176
        let variable_list = find_statement
84
1176
            .named_child(0)
85
1176
            .expect("No variable list found");
86
1896
        for variable in named_children(&variable_list) {
87
1896
            let variable_name = &source_code[variable.start_byte()..variable.end_byte()];
88
1896
            temp_symbols.insert(variable_name);
89
1896
        }
90

            
91
1176
        let domain = find_statement.named_child(1).expect("No domain found");
92
1176
        let domain = parse_domain(domain, source_code);
93

            
94
3072
        for name in temp_symbols {
95
1896
            vars.insert(Name::UserName(String::from(name)), domain.clone());
96
1896
        }
97
    }
98
1152
    vars
99
1152
}
100

            
101
1192
fn parse_domain(domain: Node, source_code: &str) -> Domain {
102
1192
    let domain = domain.child(0).expect("No domain");
103
1192
    match domain.kind() {
104
1192
        "bool_domain" => Domain::BoolDomain,
105
1096
        "int_domain" => parse_int_domain(domain, source_code),
106
16
        "variable" => {
107
16
            let variable_name = &source_code[domain.start_byte()..domain.end_byte()];
108
16
            Domain::DomainReference(Name::UserName(String::from(variable_name)))
109
        }
110
        _ => panic!("Not bool or int domain"),
111
    }
112
1192
}
113

            
114
1080
fn parse_int_domain(int_domain: Node, source_code: &str) -> Domain {
115
1080
    if int_domain.child_count() == 1 {
116
        Domain::IntDomain(vec![Range::Bounded(std::i32::MIN, std::i32::MAX)])
117
    } else {
118
1080
        let mut ranges: Vec<Range<i32>> = Vec::new();
119
1080
        let range_list = int_domain
120
1080
            .named_child(0)
121
1080
            .expect("No range list found (expression ranges not supported yet");
122
1080
        for int_range in named_children(&range_list) {
123
1080
            match int_range.kind() {
124
1080
                "integer" => {
125
16
                    let integer_value = &source_code[int_range.start_byte()..int_range.end_byte()]
126
16
                        .parse::<i32>()
127
16
                        .unwrap();
128
16
                    ranges.push(Range::Single(*integer_value));
129
16
                }
130
1064
                "int_range" => {
131
                    let lower_bound: Option<i32>;
132
                    let upper_bound: Option<i32>;
133
1064
                    let range_component = int_range.child(0).expect("Error with integer range");
134
1064
                    match range_component.kind() {
135
1064
                        "expression" => {
136
1064
                            lower_bound = Some(
137
1064
                                source_code
138
1064
                                    [range_component.start_byte()..range_component.end_byte()]
139
1064
                                    .parse::<i32>()
140
1064
                                    .unwrap(),
141
1064
                            );
142

            
143
1064
                            if let Some(range_component) = range_component.next_named_sibling() {
144
1064
                                upper_bound = Some(
145
1064
                                    source_code
146
1064
                                        [range_component.start_byte()..range_component.end_byte()]
147
1064
                                        .parse::<i32>()
148
1064
                                        .unwrap(),
149
1064
                                );
150
1064
                            } else {
151
                                upper_bound = None;
152
                            }
153
                        }
154
                        ".." => {
155
                            lower_bound = None;
156
                            let range_component = range_component
157
                                .next_sibling()
158
                                .expect("Error with integer range");
159
                            upper_bound = Some(
160
                                source_code
161
                                    [range_component.start_byte()..range_component.end_byte()]
162
                                    .parse::<i32>()
163
                                    .unwrap(),
164
                            );
165
                        }
166
                        _ => panic!("unsupported int range type"),
167
                    }
168

            
169
1064
                    match (lower_bound, upper_bound) {
170
1064
                        (Some(lb), Some(ub)) => ranges.push(Range::Bounded(lb, ub)),
171
                        (Some(lb), None) => ranges.push(Range::Bounded(lb, std::i32::MAX)),
172
                        (None, Some(ub)) => ranges.push(Range::Bounded(std::i32::MIN, ub)),
173
                        _ => panic!("Unsupported int range type"),
174
                    }
175
                }
176
                _ => panic!("unsupported int range type"),
177
            }
178
        }
179
1080
        Domain::IntDomain(ranges)
180
    }
181
1080
}
182

            
183
48
fn parse_letting_statement(letting_statement_list: Node, source_code: &str) -> SymbolTable {
184
48
    let mut symbol_table = SymbolTable::new();
185

            
186
48
    for letting_statement in named_children(&letting_statement_list) {
187
48
        let mut temp_symbols = BTreeSet::new();
188
48

            
189
48
        let variable_list = letting_statement.child(0).expect("No variable list found");
190
48
        for variable in named_children(&variable_list) {
191
48
            let variable_name = &source_code[variable.start_byte()..variable.end_byte()];
192
48
            temp_symbols.insert(variable_name);
193
48
        }
194

            
195
48
        let expr_or_domain = letting_statement
196
48
            .named_child(1)
197
48
            .expect("No domain or expression found for letting statement");
198
48
        match expr_or_domain.kind() {
199
48
            "expression" => {
200
64
                for name in temp_symbols {
201
32
                    symbol_table.insert(Rc::new(Declaration::new_value_letting(
202
32
                        Name::UserName(String::from(name)),
203
32
                        parse_constraint(expr_or_domain, source_code),
204
32
                    )));
205
32
                }
206
            }
207
16
            "domain" => {
208
32
                for name in temp_symbols {
209
16
                    symbol_table.insert(Rc::new(Declaration::new_domain_letting(
210
16
                        Name::UserName(String::from(name)),
211
16
                        parse_domain(expr_or_domain, source_code),
212
16
                    )));
213
16
                }
214
            }
215
            _ => panic!("Unrecognized node in letting statement"),
216
        }
217
    }
218
48
    symbol_table
219
48
}
220

            
221
15768
fn parse_constraint(constraint: Node, source_code: &str) -> Expression {
222
15768
    match constraint.kind() {
223
15768
        "constraint" | "expression" => child_expr(constraint, source_code),
224
7416
        "not_expr" => Expression::Not(
225
152
            Metadata::new(),
226
152
            Box::new(child_expr(constraint, source_code)),
227
152
        ),
228
7264
        "abs_value" => Expression::Abs(
229
64
            Metadata::new(),
230
64
            Box::new(child_expr(constraint, source_code)),
231
64
        ),
232
7200
        "negative_expr" => Expression::Neg(
233
320
            Metadata::new(),
234
320
            Box::new(child_expr(constraint, source_code)),
235
320
        ),
236
6880
        "exponent" | "product_expr" | "sum_expr" | "comparison" | "and_expr" | "or_expr"
237
4432
        | "implication" => {
238
2752
            let expr1 = child_expr(constraint, source_code);
239
2752
            let op = constraint.child(1).unwrap_or_else(|| {
240
                panic!(
241
                    "Error: missing node in expression of kind {}",
242
                    constraint.kind()
243
                )
244
2752
            });
245
2752
            let op_type = &source_code[op.start_byte()..op.end_byte()];
246
2752
            let expr2_node = constraint.child(2).unwrap_or_else(|| {
247
                panic!(
248
                    "Error: missing node in expression of kind {}",
249
                    constraint.kind()
250
                )
251
2752
            });
252
2752
            let expr2 = parse_constraint(expr2_node, source_code);
253
2752

            
254
2752
            match op_type {
255
2752
                "**" => Expression::UnsafePow(Metadata::new(), Box::new(expr1), Box::new(expr2)),
256
2712
                "+" => Expression::Sum(Metadata::new(), vec![expr1, expr2]),
257
2320
                "-" => Expression::Minus(Metadata::new(), Box::new(expr1), Box::new(expr2)),
258
2288
                "*" => Expression::Product(Metadata::new(), vec![expr1, expr2]),
259
2120
                "/" => {
260
                    //TODO: add checks for if division is safe or not
261
296
                    Expression::UnsafeDiv(Metadata::new(), Box::new(expr1), Box::new(expr2))
262
                }
263
1824
                "%" => {
264
                    //TODO: add checks for if mod is safe or not
265
232
                    Expression::UnsafeMod(Metadata::new(), Box::new(expr1), Box::new(expr2))
266
                }
267
1592
                "=" => Expression::Eq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
268
960
                "!=" => Expression::Neq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
269
904
                "<=" => Expression::Leq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
270
808
                ">=" => Expression::Geq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
271
712
                "<" => Expression::Lt(Metadata::new(), Box::new(expr1), Box::new(expr2)),
272
544
                ">" => Expression::Gt(Metadata::new(), Box::new(expr1), Box::new(expr2)),
273
528
                "/\\" => Expression::And(Metadata::new(), vec![expr1, expr2]),
274
504
                "\\/" => Expression::Or(Metadata::new(), vec![expr1, expr2]),
275
304
                "->" => Expression::Imply(Metadata::new(), Box::new(expr1), Box::new(expr2)),
276
                _ => panic!("Error: unsupported operator"),
277
            }
278
        }
279
4128
        "quantifier_expr" => {
280
128
            let mut expr_list = Vec::new();
281
264
            for expr in named_children(&constraint) {
282
264
                expr_list.push(parse_constraint(expr, source_code));
283
264
            }
284

            
285
128
            let quantifier = constraint.child(0).unwrap_or_else(|| {
286
                panic!(
287
                    "Error: missing node in expression of kind {}",
288
                    constraint.kind()
289
                )
290
128
            });
291
128
            let quantifier_type = &source_code[quantifier.start_byte()..quantifier.end_byte()];
292
128

            
293
128
            match quantifier_type {
294
128
                "and" => Expression::And(Metadata::new(), expr_list),
295
128
                "or" => Expression::Or(Metadata::new(), expr_list),
296
120
                "min" => Expression::Min(Metadata::new(), expr_list),
297
64
                "max" => Expression::Max(Metadata::new(), expr_list),
298
40
                "sum" => Expression::Sum(Metadata::new(), expr_list),
299
16
                "allDiff" => Expression::AllDiff(Metadata::new(), expr_list),
300
                _ => panic!("Error: unsupported quantifier"),
301
            }
302
        }
303
4000
        "constant" => {
304
1584
            let child = constraint.child(0).unwrap_or_else(|| {
305
                panic!(
306
                    "Error: missing node in expression of kind {}",
307
                    constraint.kind()
308
                )
309
1584
            });
310
1584
            match child.kind() {
311
1584
                "integer" => {
312
1488
                    let constant_value = &source_code[child.start_byte()..child.end_byte()]
313
1488
                        .parse::<i32>()
314
1488
                        .unwrap();
315
1488
                    Expression::Atomic(
316
1488
                        Metadata::new(),
317
1488
                        Atom::Literal(Literal::Int(*constant_value)),
318
1488
                    )
319
                }
320
96
                "TRUE" => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
321
32
                "FALSE" => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
322
                _ => panic!("Error"),
323
            }
324
        }
325
2416
        "variable" => {
326
2416
            let variable_name =
327
2416
                String::from(&source_code[constraint.start_byte()..constraint.end_byte()]);
328
2416
            Expression::Atomic(
329
2416
                Metadata::new(),
330
2416
                Atom::Reference(Name::UserName(variable_name)),
331
2416
            )
332
        }
333
        _ => panic!("{} is not a recognized node kind", constraint.kind()),
334
    }
335
15768
}
336

            
337
5056
fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
338
8760
    (0..node.named_child_count()).filter_map(|i| node.named_child(i))
339
5056
}
340

            
341
11640
fn child_expr(node: Node, source_code: &str) -> Expression {
342
11640
    let child = node
343
11640
        .named_child(0)
344
11640
        .unwrap_or_else(|| panic!("Error: missing node in expression of kind {}", node.kind()));
345
11640
    parse_constraint(child, source_code)
346
11640
}