Skip to main content

conjure_cp_essence_parser/parser/
parse_model.rs

1use std::sync::{Arc, RwLock};
2use std::{fs, vec};
3
4use conjure_cp_core::Model;
5use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo};
6use conjure_cp_core::context::Context;
7#[allow(unused)]
8use uniplate::Uniplate;
9
10use super::find::parse_find_statement;
11use super::letting::parse_letting_statement;
12use super::util::{get_tree, named_children};
13use crate::errors::{FatalParseError, ParseErrorCollection, RecoverableParseError};
14use crate::expression::parse_expression;
15
16/// Parse an Essence file into a Model using the tree-sitter parser.
17pub fn parse_essence_file_native(
18    path: &str,
19    context: Arc<RwLock<Context<'static>>>,
20) -> Result<Model, Box<ParseErrorCollection>> {
21    let source_code = fs::read_to_string(path)
22        .unwrap_or_else(|_| panic!("Failed to read the source code file {path}"));
23
24    let mut errors = vec![];
25    let model = parse_essence_with_context(&source_code, context, &mut errors);
26
27    match model {
28        Ok(m) => {
29            // Check if there were any recoverable errors
30            if !errors.is_empty() {
31                return Err(Box::new(ParseErrorCollection::multiple(
32                    errors,
33                    Some(source_code),
34                    Some(path.to_string()),
35                )));
36            }
37            // Return model if no errors
38            Ok(m)
39        }
40        Err(fatal) => {
41            // Fatal error - wrap in ParseErrorCollection::Fatal
42            Err(Box::new(ParseErrorCollection::fatal(fatal)))
43        }
44    }
45}
46
47pub fn parse_essence_with_context(
48    src: &str,
49    context: Arc<RwLock<Context<'static>>>,
50    errors: &mut Vec<RecoverableParseError>,
51) -> Result<Model, FatalParseError> {
52    let (tree, source_code) = match get_tree(src) {
53        Some(tree) => tree,
54        None => {
55            return Err(FatalParseError::TreeSitterError(
56                "Failed to parse source code".to_string(),
57            ));
58        }
59    };
60
61    let mut model = Model::new(context);
62    let root_node = tree.root_node();
63    let symbols_ptr = model.as_submodel().symbols_ptr_unchecked().clone();
64    for statement in named_children(&root_node) {
65        match statement.kind() {
66            "single_line_comment" => {}
67            "language_declaration" => {}
68            "find_statement" => {
69                let var_hashmap = parse_find_statement(
70                    statement,
71                    &source_code,
72                    Some(symbols_ptr.clone()),
73                    errors,
74                )?;
75                for (name, domain) in var_hashmap {
76                    model
77                        .as_submodel_mut()
78                        .symbols_mut()
79                        .insert(DeclarationPtr::new_find(name, domain));
80                }
81            }
82            "bool_expr" | "atom" | "comparison_expr" => {
83                model.as_submodel_mut().add_constraint(parse_expression(
84                    statement,
85                    &source_code,
86                    &statement,
87                    Some(symbols_ptr.clone()),
88                    errors,
89                )?);
90            }
91            "language_label" => {}
92            "letting_statement" => {
93                let letting_vars = parse_letting_statement(
94                    statement,
95                    &source_code,
96                    Some(symbols_ptr.clone()),
97                    errors,
98                )?;
99                model.as_submodel_mut().symbols_mut().extend(letting_vars);
100            }
101            "dominance_relation" => {
102                let inner = statement
103                    .child_by_field_name("expression")
104                    .expect("Expected a sub-expression inside `dominanceRelation`");
105                let expr = parse_expression(
106                    inner,
107                    &source_code,
108                    &statement,
109                    Some(symbols_ptr.clone()),
110                    errors,
111                )?;
112                let dominance = Expression::DominanceRelation(Metadata::new(), Moo::new(expr));
113                if model.dominance.is_some() {
114                    errors.push(RecoverableParseError::new(
115                        "Duplicate dominance relation".to_string(),
116                        None,
117                    ));
118                    continue;
119                }
120                model.dominance = Some(dominance);
121            }
122            "ERROR" => {
123                let raw_expr = &source_code[statement.start_byte()..statement.end_byte()];
124                errors.push(RecoverableParseError::new(
125                    format!("'{raw_expr}' is not a valid expression"),
126                    Some(statement.range()),
127                ));
128            }
129            _ => {
130                let kind = statement.kind();
131                errors.push(RecoverableParseError::new(
132                    format!("Unrecognized top level statement kind: {kind}"),
133                    Some(statement.range()),
134                ));
135            }
136        }
137
138        // check for errors (keyword as identifier)
139        keyword_as_identifier(root_node, &source_code, errors);
140    }
141    Ok(model)
142}
143
/// Words that are rejected when they appear as the text of a `variable`,
/// `identifier` or `parameter` node (see `keyword_as_identifier`).
const KEYWORDS: [&str; 21] = [
    // quantifiers
    "forall",
    "exists",
    // statement and clause keywords
    "such",
    "that",
    "letting",
    "find",
    "minimise",
    "maximise",
    "subject",
    "to",
    "where",
    // logical connectives
    "and",
    "or",
    "not",
    // conditionals
    "if",
    "then",
    "else",
    // membership, aggregates and types
    "in",
    "sum",
    "product",
    "bool",
];
148
149fn keyword_as_identifier(
150    root: tree_sitter::Node,
151    src: &str,
152    errors: &mut Vec<RecoverableParseError>,
153) {
154    let mut stack = vec![root];
155    while let Some(node) = stack.pop() {
156        if (node.kind() == "variable" || node.kind() == "identifier" || node.kind() == "parameter")
157            && let Ok(text) = node.utf8_text(src.as_bytes())
158        {
159            let ident = text.trim();
160            if KEYWORDS.contains(&ident) {
161                let start_point = node.start_position();
162                let end_point = node.end_position();
163                errors.push(RecoverableParseError::new(
164                    format!("Keyword '{ident}' used as identifier"),
165                    Some(tree_sitter::Range {
166                        start_byte: node.start_byte(),
167                        end_byte: node.end_byte(),
168                        start_point,
169                        end_point,
170                    }),
171                ));
172            }
173        }
174
175        // push children onto stack
176        for i in 0..node.child_count() {
177            if let Some(child) = u32::try_from(i).ok().and_then(|i| node.child(i)) {
178                stack.push(child);
179            }
180        }
181    }
182}
183
184pub fn parse_essence(src: &str) -> Result<Model, Box<ParseErrorCollection>> {
185    let context = Arc::new(RwLock::new(Context::default()));
186    let mut errors = vec![];
187    match parse_essence_with_context(src, context, &mut errors) {
188        Ok(model) => {
189            if !errors.is_empty() {
190                Err(Box::new(ParseErrorCollection::multiple(
191                    errors,
192                    Some(src.to_string()),
193                    None,
194                )))
195            } else {
196                Ok(model)
197            }
198        }
199        Err(fatal) => Err(Box::new(ParseErrorCollection::fatal(fatal))),
200    }
201}
202
/// Unit tests for [`parse_essence`].
///
/// Gated behind `cfg(test)` so the module and its imports are only compiled
/// for test builds.
#[cfg(test)]
mod test {
    use crate::parse_essence;
    use conjure_cp_core::ast::{Atom, Expression, Metadata, Moo, Name};
    // NOTE(review): `range` is not named directly in this module; presumably it
    // is needed by the expansion of `domain_int!`. The allow is kept until that
    // is confirmed.
    #[allow(unused_imports)]
    use conjure_cp_core::{domain_int, matrix_expr, range};
    use std::ops::Deref;

    /// Three finds sharing one domain plus two constraints: checks the symbol
    /// table entries and the exact shape of both parsed constraints.
    #[test]
    pub fn test_parse_xyz() {
        let src = "
        find x, y, z : int(1..4)
        such that x + y + z = 4
        such that x >= y
        ";

        let model = parse_essence(src).unwrap();

        let st = model.as_submodel().symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let y = st.lookup(&Name::user("y")).unwrap();
        let z = st.lookup(&Name::user("z")).unwrap();
        assert_eq!(x.domain(), Some(domain_int!(1..4)));
        assert_eq!(y.domain(), Some(domain_int!(1..4)));
        assert_eq!(z.domain(), Some(domain_int!(1..4)));

        let constraints = model.as_submodel().constraints();
        assert_eq!(constraints.len(), 2);

        let c1 = constraints[0].clone();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x));
        let y_e = Expression::Atomic(Metadata::new(), Atom::new_ref(y));
        let z_e = Expression::Atomic(Metadata::new(), Atom::new_ref(z));
        // `x + y + z` is expected as a nested sum: (x + y) + z.
        assert_eq!(
            c1,
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Sum(
                            Metadata::new(),
                            Moo::new(matrix_expr!(x_e.clone(), y_e.clone()))
                        ),
                        z_e
                    ))
                )),
                Moo::new(Expression::Atomic(Metadata::new(), 4.into()))
            )
        );

        let c2 = constraints[1].clone();
        assert_eq!(
            c2,
            Expression::Geq(Metadata::new(), Moo::new(x_e), Moo::new(y_e))
        );
    }

    /// A `letting` bound to a 2-D matrix literal with explicit index domains:
    /// checks the value stored for `a` in the symbol table.
    #[test]
    pub fn test_parse_letting_index() {
        let src = "
        letting a be [ [ 1,2,3 ; int(1,2,4) ], [ 1,3,2 ; int(1,2,4) ], [ 3,2,1 ; int(1,2,4) ] ; int(-2..0) ]
        find b: int(1..5)
        such that
        b < a[-2,2],
        allDiff(a[-2,..])
        ";

        let model = parse_essence(src).unwrap();
        let st = model.as_submodel().symbols();
        let a_decl = st.lookup(&Name::user("a")).unwrap();
        let a = a_decl.as_value_letting().unwrap().deref().clone();
        assert_eq!(
            a,
            matrix_expr!(
                matrix_expr!(1.into(), 2.into(), 3.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(1.into(), 3.into(), 2.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(3.into(), 2.into(), 1.into() ; domain_int!(1, 2, 4));
                domain_int!(-2..0)
            )
        )
    }
}