conjure_oxide/utils/
essence_parser.rs

1#![allow(clippy::legacy_numeric_constants)]
2use conjure_core::ast::Declaration;
3use conjure_core::error::Error;
4use std::fs;
5use std::rc::Rc;
6use std::sync::{Arc, RwLock};
7use tree_sitter::{Node, Parser, Tree};
8use tree_sitter_essence::LANGUAGE;
9
10use conjure_core::ast::{Atom, Domain, Expression, Literal, Name, Range, SymbolTable};
11
12use crate::utils::conjure::EssenceParseError;
13use conjure_core::context::Context;
14use conjure_core::metadata::Metadata;
15use conjure_core::{into_matrix_expr, matrix_expr, Model};
16use std::collections::{BTreeMap, BTreeSet};
17
18pub fn parse_essence_file_native(
19    path: &str,
20    filename: &str,
21    extension: &str,
22    context: Arc<RwLock<Context<'static>>>,
23) -> Result<Model, EssenceParseError> {
24    let (tree, source_code) = get_tree(path, filename, extension);
25
26    let mut model = Model::new(context);
27    let root_node = tree.root_node();
28    for statement in named_children(&root_node) {
29        match statement.kind() {
30            "single_line_comment" => {}
31            "find_statement_list" => {
32                let var_hashmap = parse_find_statement(statement, &source_code);
33                for (name, decision_variable) in var_hashmap {
34                    model
35                        .as_submodel_mut()
36                        .symbols_mut()
37                        .insert(Rc::new(Declaration::new_var(name, decision_variable)));
38                }
39            }
40            "constraint_list" => {
41                let mut constraint_vec: Vec<Expression> = Vec::new();
42                for constraint in named_children(&statement) {
43                    if constraint.kind() != "single_line_comment" {
44                        constraint_vec.push(parse_constraint(constraint, &source_code, &statement));
45                    }
46                }
47                model.as_submodel_mut().add_constraints(constraint_vec);
48            }
49            "e_prime_label" => {}
50            "letting_statement_list" => {
51                let letting_vars = parse_letting_statement(statement, &source_code);
52                model.as_submodel_mut().symbols_mut().extend(letting_vars);
53            }
54            "dominance_relation" => {
55                let inner = statement
56                    .child(1)
57                    .expect("Expected a sub-expression inside `dominanceRelation`");
58                let expr = parse_constraint(inner, &source_code, &statement);
59                let dominance = Expression::DominanceRelation(Metadata::new(), Box::new(expr));
60                if model.dominance.is_some() {
61                    return Err(EssenceParseError::ParseError(Error::Parse(
62                        "Duplicate dominance relation".to_owned(),
63                    )));
64                }
65                model.dominance = Some(dominance);
66            }
67            _ => {
68                let kind = statement.kind();
69                return Err(EssenceParseError::ParseError(Error::Parse(
70                    format!("Unrecognized top level statement kind: {kind}").to_owned(),
71                )));
72            }
73        }
74    }
75    Ok(model)
76}
77
78fn get_tree(path: &str, filename: &str, extension: &str) -> (Tree, String) {
79    let pth = format!("{path}/{filename}.{extension}");
80    let source_code = fs::read_to_string(&pth)
81        .unwrap_or_else(|_| panic!("Failed to read the source code file {}", pth));
82    let mut parser = Parser::new();
83    parser.set_language(&LANGUAGE.into()).unwrap();
84    (
85        parser
86            .parse(source_code.clone(), None)
87            .expect("Failed to parse"),
88        source_code,
89    )
90}
91
92fn parse_find_statement(find_statement_list: Node, source_code: &str) -> BTreeMap<Name, Domain> {
93    let mut vars = BTreeMap::new();
94
95    for find_statement in named_children(&find_statement_list) {
96        let mut temp_symbols = BTreeSet::new();
97
98        let variable_list = find_statement
99            .named_child(0)
100            .expect("No variable list found");
101        for variable in named_children(&variable_list) {
102            let variable_name = &source_code[variable.start_byte()..variable.end_byte()];
103            temp_symbols.insert(variable_name);
104        }
105
106        let domain = find_statement.named_child(1).expect("No domain found");
107        let domain = parse_domain(domain, source_code);
108
109        for name in temp_symbols {
110            vars.insert(Name::UserName(String::from(name)), domain.clone());
111        }
112    }
113    vars
114}
115
116fn parse_domain(domain: Node, source_code: &str) -> Domain {
117    let domain = domain.child(0).expect("No domain");
118    match domain.kind() {
119        "bool_domain" => Domain::BoolDomain,
120        "int_domain" => parse_int_domain(domain, source_code),
121        "variable" => {
122            let variable_name = &source_code[domain.start_byte()..domain.end_byte()];
123            Domain::DomainReference(Name::UserName(String::from(variable_name)))
124        }
125        _ => panic!("Not bool or int domain"),
126    }
127}
128
129fn parse_int_domain(int_domain: Node, source_code: &str) -> Domain {
130    if int_domain.child_count() == 1 {
131        Domain::IntDomain(vec![Range::Bounded(std::i32::MIN, std::i32::MAX)])
132    } else {
133        let mut ranges: Vec<Range<i32>> = Vec::new();
134        let range_list = int_domain
135            .named_child(0)
136            .expect("No range list found (expression ranges not supported yet");
137        for int_range in named_children(&range_list) {
138            match int_range.kind() {
139                "integer" => {
140                    let integer_value = &source_code[int_range.start_byte()..int_range.end_byte()]
141                        .parse::<i32>()
142                        .unwrap();
143                    ranges.push(Range::Single(*integer_value));
144                }
145                "int_range" => {
146                    let lower_bound: Option<i32>;
147                    let upper_bound: Option<i32>;
148                    let range_component = int_range.child(0).expect("Error with integer range");
149                    match range_component.kind() {
150                        "expression" => {
151                            lower_bound = Some(
152                                source_code
153                                    [range_component.start_byte()..range_component.end_byte()]
154                                    .parse::<i32>()
155                                    .unwrap(),
156                            );
157
158                            if let Some(range_component) = range_component.next_named_sibling() {
159                                upper_bound = Some(
160                                    source_code
161                                        [range_component.start_byte()..range_component.end_byte()]
162                                        .parse::<i32>()
163                                        .unwrap(),
164                                );
165                            } else {
166                                upper_bound = None;
167                            }
168                        }
169                        ".." => {
170                            lower_bound = None;
171                            let range_component = range_component
172                                .next_sibling()
173                                .expect("Error with integer range");
174                            upper_bound = Some(
175                                source_code
176                                    [range_component.start_byte()..range_component.end_byte()]
177                                    .parse::<i32>()
178                                    .unwrap(),
179                            );
180                        }
181                        _ => panic!("unsupported int range type"),
182                    }
183
184                    match (lower_bound, upper_bound) {
185                        (Some(lb), Some(ub)) => ranges.push(Range::Bounded(lb, ub)),
186                        (Some(lb), None) => ranges.push(Range::Bounded(lb, std::i32::MAX)),
187                        (None, Some(ub)) => ranges.push(Range::Bounded(std::i32::MIN, ub)),
188                        _ => panic!("Unsupported int range type"),
189                    }
190                }
191                _ => panic!("unsupported int range type"),
192            }
193        }
194        Domain::IntDomain(ranges)
195    }
196}
197
198fn parse_letting_statement(letting_statement_list: Node, source_code: &str) -> SymbolTable {
199    let mut symbol_table = SymbolTable::new();
200
201    for letting_statement in named_children(&letting_statement_list) {
202        let mut temp_symbols = BTreeSet::new();
203
204        let variable_list = letting_statement.child(0).expect("No variable list found");
205        for variable in named_children(&variable_list) {
206            let variable_name = &source_code[variable.start_byte()..variable.end_byte()];
207            temp_symbols.insert(variable_name);
208        }
209
210        let expr_or_domain = letting_statement
211            .named_child(1)
212            .expect("No domain or expression found for letting statement");
213        match expr_or_domain.kind() {
214            "expression" => {
215                for name in temp_symbols {
216                    symbol_table.insert(Rc::new(Declaration::new_value_letting(
217                        Name::UserName(String::from(name)),
218                        parse_constraint(expr_or_domain, source_code, &letting_statement_list),
219                    )));
220                }
221            }
222            "domain" => {
223                for name in temp_symbols {
224                    symbol_table.insert(Rc::new(Declaration::new_domain_letting(
225                        Name::UserName(String::from(name)),
226                        parse_domain(expr_or_domain, source_code),
227                    )));
228                }
229            }
230            _ => panic!("Unrecognized node in letting statement"),
231        }
232    }
233    symbol_table
234}
235
236fn parse_constraint(constraint: Node, source_code: &str, root: &Node) -> Expression {
237    match constraint.kind() {
238        "constraint" | "expression" => child_expr(constraint, source_code, root),
239        "not_expr" => Expression::Not(
240            Metadata::new(),
241            Box::new(child_expr(constraint, source_code, root)),
242        ),
243        "abs_value" => Expression::Abs(
244            Metadata::new(),
245            Box::new(child_expr(constraint, source_code, root)),
246        ),
247        "negative_expr" => Expression::Neg(
248            Metadata::new(),
249            Box::new(child_expr(constraint, source_code, root)),
250        ),
251        "exponent" | "product_expr" | "sum_expr" | "comparison" | "and_expr" | "or_expr"
252        | "implication" => {
253            let expr1 = child_expr(constraint, source_code, root);
254            let op = constraint.child(1).unwrap_or_else(|| {
255                panic!(
256                    "Error: missing node in expression of kind {}",
257                    constraint.kind()
258                )
259            });
260            let op_type = &source_code[op.start_byte()..op.end_byte()];
261            let expr2_node = constraint.child(2).unwrap_or_else(|| {
262                panic!(
263                    "Error: missing node in expression of kind {}",
264                    constraint.kind()
265                )
266            });
267            let expr2 = parse_constraint(expr2_node, source_code, root);
268
269            match op_type {
270                "**" => Expression::UnsafePow(Metadata::new(), Box::new(expr1), Box::new(expr2)),
271                "+" => Expression::Sum(Metadata::new(), vec![expr1, expr2]),
272                "-" => Expression::Minus(Metadata::new(), Box::new(expr1), Box::new(expr2)),
273                "*" => Expression::Product(Metadata::new(), vec![expr1, expr2]),
274                "/" => {
275                    //TODO: add checks for if division is safe or not
276                    Expression::UnsafeDiv(Metadata::new(), Box::new(expr1), Box::new(expr2))
277                }
278                "%" => {
279                    //TODO: add checks for if mod is safe or not
280                    Expression::UnsafeMod(Metadata::new(), Box::new(expr1), Box::new(expr2))
281                }
282                "=" => Expression::Eq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
283                "!=" => Expression::Neq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
284                "<=" => Expression::Leq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
285                ">=" => Expression::Geq(Metadata::new(), Box::new(expr1), Box::new(expr2)),
286                "<" => Expression::Lt(Metadata::new(), Box::new(expr1), Box::new(expr2)),
287                ">" => Expression::Gt(Metadata::new(), Box::new(expr1), Box::new(expr2)),
288                "/\\" => Expression::And(Metadata::new(), Box::new(matrix_expr![expr1, expr2])),
289                "\\/" => Expression::Or(Metadata::new(), Box::new(matrix_expr![expr1, expr2])),
290                "->" => Expression::Imply(Metadata::new(), Box::new(expr1), Box::new(expr2)),
291                _ => panic!("Error: unsupported operator"),
292            }
293        }
294        "quantifier_expr" => {
295            let mut expr_list = Vec::new();
296            for expr in named_children(&constraint) {
297                expr_list.push(parse_constraint(expr, source_code, root));
298            }
299
300            let quantifier = constraint.child(0).unwrap_or_else(|| {
301                panic!(
302                    "Error: missing node in expression of kind {}",
303                    constraint.kind()
304                )
305            });
306            let quantifier_type = &source_code[quantifier.start_byte()..quantifier.end_byte()];
307
308            match quantifier_type {
309                "and" => Expression::And(Metadata::new(), Box::new(into_matrix_expr![expr_list])),
310                "or" => Expression::Or(Metadata::new(), Box::new(into_matrix_expr![expr_list])),
311                "min" => Expression::Min(Metadata::new(), Box::new(into_matrix_expr![expr_list])),
312                "max" => Expression::Max(Metadata::new(), Box::new(into_matrix_expr![expr_list])),
313                "sum" => Expression::Sum(Metadata::new(), expr_list),
314                "allDiff" => {
315                    Expression::AllDiff(Metadata::new(), Box::new(into_matrix_expr![expr_list]))
316                }
317                _ => panic!("Error: unsupported quantifier"),
318            }
319        }
320        "constant" => {
321            let child = constraint.child(0).unwrap_or_else(|| {
322                panic!(
323                    "Error: missing node in expression of kind {}",
324                    constraint.kind()
325                )
326            });
327            match child.kind() {
328                "integer" => {
329                    let constant_value = &source_code[child.start_byte()..child.end_byte()]
330                        .parse::<i32>()
331                        .unwrap();
332                    Expression::Atomic(
333                        Metadata::new(),
334                        Atom::Literal(Literal::Int(*constant_value)),
335                    )
336                }
337                "TRUE" => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(true))),
338                "FALSE" => Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Bool(false))),
339                _ => panic!("Error"),
340            }
341        }
342        "variable" => {
343            let variable_name =
344                String::from(&source_code[constraint.start_byte()..constraint.end_byte()]);
345            Expression::Atomic(
346                Metadata::new(),
347                Atom::Reference(Name::UserName(variable_name)),
348            )
349        }
350        "from_solution" => match root.kind() {
351            "dominance_relation" => {
352                let inner = child_expr(constraint, source_code, root);
353                match inner {
354                    Expression::Atomic(_, _) => {
355                        Expression::FromSolution(Metadata::new(), Box::new(inner))
356                    }
357                    _ => panic!("Expression inside a `fromSolution()` must be a variable name"),
358                }
359            }
360            _ => panic!("`fromSolution()` is only allowed inside dominance relation definitions"),
361        },
362        _ => panic!("{} is not a recognized node kind", constraint.kind()),
363    }
364}
365
366fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
367    (0..node.named_child_count()).filter_map(|i| node.named_child(i))
368}
369
370fn child_expr(node: Node, source_code: &str, root: &Node) -> Expression {
371    let child = node
372        .named_child(0)
373        .unwrap_or_else(|| panic!("Error: missing node in expression of kind {}", node.kind()));
374    parse_constraint(child, source_code, root)
375}