Skip to main content

conjure_cp_essence_parser/parser/
parse_model.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, RwLock};
3use std::{fs, vec};
4
5use conjure_cp_core::Model;
6use conjure_cp_core::ast::DeclarationPtr;
7use conjure_cp_core::ast::assertions::debug_assert_model_well_formed;
8use conjure_cp_core::context::Context;
9#[allow(unused)]
10use uniplate::Uniplate;
11
12use super::ParseContext;
13use super::dominance::parse_dominance_relation;
14use super::find::{parse_find_statement, parse_given_statement};
15use super::letting::parse_letting_statement;
16use super::util::{TypecheckingContext, get_tree};
17use crate::diagnostics::source_map::SourceMap;
18use crate::errors::{FatalParseError, ParseErrorCollection, RecoverableParseError};
19use crate::expression::parse_expression;
20use crate::parser::keyword_checks::keyword_as_identifier;
21use crate::syntax_errors::detect_syntactic_errors;
22use tree_sitter::Tree;
23
24/// Parse an Essence file into a Model using the tree-sitter parser.
25pub fn parse_essence_file_native(
26    path: &str,
27    context: Arc<RwLock<Context<'static>>>,
28) -> Result<Model, Box<ParseErrorCollection>> {
29    let source_code = fs::read_to_string(path)
30        .unwrap_or_else(|_| panic!("Failed to read the source code file {path}"));
31
32    let mut errors = vec![];
33    let model = parse_essence_with_context(&source_code, context, &mut errors);
34
35    match model {
36        Ok(Some(m)) => {
37            debug_assert_model_well_formed(&m, "tree-sitter");
38            Ok(m)
39        }
40        Ok(None) => {
41            // Recoverable errors were found, return them as a ParseErrorCollection
42            Err(Box::new(ParseErrorCollection::multiple(
43                errors,
44                Some(source_code),
45                Some(path.to_string()),
46            )))
47        }
48        Err(fatal) => {
49            // Fatal error - wrap in ParseErrorCollection::Fatal
50            Err(Box::new(ParseErrorCollection::fatal(fatal)))
51        }
52    }
53}
54
55pub fn parse_essence_with_context(
56    src: &str,
57    context: Arc<RwLock<Context<'static>>>,
58    errors: &mut Vec<RecoverableParseError>,
59) -> Result<Option<Model>, FatalParseError> {
60    match parse_essence_with_context_and_map(src, context, errors, None)? {
61        (Some(model), _source_map) => Ok(Some(model)),
62        (None, _source_map) => Ok(None),
63    }
64}
65
66/*
67    this function is used by both the file-based parser and the LSP parser (which needs the source map)
68    the LSP parser can also optionally pass in a pre-parsed tree to avoid parsing twice (which is how caching is implemented)
69    if the tree is not passed in, we will parse it from scratch (this is what the file-based parser does)
70    when cache is dirty, LSP has to call parse_essence_with_context_and_map with None for the tree,
71    which will cause it to re-parse the source code and update the cache (Model = ast, SorceMap = map)
72*/
73pub fn parse_essence_with_context_and_map(
74    src: &str,
75    context: Arc<RwLock<Context<'static>>>,
76    errors: &mut Vec<RecoverableParseError>,
77    tree: Option<&Tree>,
78) -> Result<(Option<Model>, SourceMap), FatalParseError> {
79    let (tree, source_code) = if let Some(tree) = tree {
80        (tree.clone(), src.to_string())
81    } else {
82        match get_tree(src) {
83            Some(tree) => tree,
84            None => {
85                return Err(FatalParseError::TreeSitterError(
86                    "Failed to parse source code".to_string(),
87                ));
88            }
89        }
90    };
91
92    let has_syntax_errors = tree.root_node().has_error();
93    if has_syntax_errors {
94        detect_syntactic_errors(src, &tree, errors);
95    }
96
97    // don't detect semantic errors if there are syntactic errors, but still parse for source map.
98    let mut suppressed_semantic_errors = Vec::new();
99    let semantic_errors: &mut Vec<RecoverableParseError> = if has_syntax_errors {
100        &mut suppressed_semantic_errors
101    } else {
102        errors
103    };
104
105    keyword_as_identifier(tree.root_node(), src, semantic_errors);
106
107    let mut model = Model::new(context);
108    let mut source_map = SourceMap::default();
109    let mut declaration_spans = BTreeMap::new();
110    let root_node = tree.root_node();
111
112    // Create a ParseContext
113    let mut ctx = ParseContext::new(
114        &source_code,
115        &root_node,
116        Some(model.symbols_ptr_unchecked().clone()),
117        semantic_errors,
118        &mut source_map,
119        &mut declaration_spans,
120    );
121
122    let mut cursor = root_node.walk();
123    for statement in root_node.children(&mut cursor) {
124        if !statement.is_named() || statement.is_error() || statement.kind() == "ERROR" {
125            continue;
126        }
127
128        ctx.typechecking_context = TypecheckingContext::Unknown;
129        ctx.inner_typechecking_context = TypecheckingContext::Unknown;
130
131        match statement.kind() {
132            "single_line_comment" => {}
133            "language_declaration" => {}
134            "find_statement" => {
135                let var_hashmap = parse_find_statement(&mut ctx, statement)?;
136                for (name, domain) in var_hashmap {
137                    model
138                        .symbols_mut()
139                        .insert(DeclarationPtr::new_find(name, domain));
140                }
141            }
142            "given_statement" => {
143                let var_hashmap = parse_given_statement(&mut ctx, statement)?;
144                for (name, domain) in var_hashmap {
145                    model
146                        .symbols_mut()
147                        .insert(DeclarationPtr::new_given(name, domain));
148                }
149            }
150            "bool_expr" | "atom" | "comparison_expr" => {
151                ctx.typechecking_context = TypecheckingContext::Boolean;
152                let Some(expr) = parse_expression(&mut ctx, statement)? else {
153                    continue;
154                };
155                model.add_constraint(expr);
156            }
157            "language_label" => {}
158            "letting_statement" => {
159                let Some(letting_vars) = parse_letting_statement(&mut ctx, statement)? else {
160                    continue;
161                };
162                model.symbols_mut().extend(letting_vars);
163            }
164            "dominance_relation" => {
165                let Some(dominance) = parse_dominance_relation(&mut ctx, &statement)? else {
166                    continue;
167                };
168                if model.dominance.is_some() {
169                    ctx.record_error(RecoverableParseError::new(
170                        "Duplicate dominance relation".to_string(),
171                        None,
172                    ));
173                    continue;
174                }
175                model.dominance = Some(dominance);
176            }
177            _ => {
178                ctx.record_error(RecoverableParseError::new(
179                    format!("Unexpected top-level statement: {}", statement.kind()),
180                    Some(statement.range()),
181                ));
182                continue;
183            }
184        }
185    }
186
187    // Check if there were any recoverable errors
188    if !errors.is_empty() {
189        return Ok((None, source_map));
190    }
191    // otherwise return the model
192    Ok((Some(model), source_map))
193}
194
195pub fn parse_essence(src: &str) -> Result<(Model, SourceMap), Box<ParseErrorCollection>> {
196    let context = Arc::new(RwLock::new(Context::default()));
197    let mut errors = vec![];
198    match parse_essence_with_context_and_map(src, context, &mut errors, None) {
199        Ok((Some(model), source_map)) => {
200            debug_assert_model_well_formed(&model, "tree-sitter");
201            Ok((model, source_map))
202        }
203        Ok((None, _source_map)) => {
204            // Recoverable errors were found, return them as a ParseErrorCollection
205            Err(Box::new(ParseErrorCollection::multiple(
206                errors,
207                Some(src.to_string()),
208                None,
209            )))
210        }
211        Err(fatal) => Err(Box::new(ParseErrorCollection::fatal(fatal))),
212    }
213}
214
/// Unit tests for the top-level Essence parser entry points.
///
/// Gated on `cfg(test)` so the module and its imports are only compiled for
/// test builds.
#[cfg(test)]
mod test {
    use crate::parse_essence;
    use conjure_cp_core::ast::{Atom, Expression, Metadata, Moo, Name};
    // NOTE(review): `range` is not named directly in this module; it is presumably
    // kept in scope for the `domain_int!` expansion — confirm before removing.
    #[allow(unused_imports)]
    use conjure_cp_core::{domain_int, matrix_expr, range};
    use std::ops::Deref;

    /// Three finds sharing one domain plus two constraints: checks the declared
    /// domains and the exact shape of both parsed constraint expressions.
    #[test]
    pub fn test_parse_xyz() {
        let src = "
        find x, y, z : int(1..4)
        such that x + y + z = 4
        such that x >= y
        ";

        let (model, _source_map) = parse_essence(src).unwrap();

        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let y = st.lookup(&Name::user("y")).unwrap();
        let z = st.lookup(&Name::user("z")).unwrap();
        assert_eq!(x.domain(), Some(domain_int!(1..4)));
        assert_eq!(y.domain(), Some(domain_int!(1..4)));
        assert_eq!(z.domain(), Some(domain_int!(1..4)));

        let constraints = model.constraints();
        assert_eq!(constraints.len(), 2);

        let c1 = constraints[0].clone();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x));
        let y_e = Expression::Atomic(Metadata::new(), Atom::new_ref(y));
        let z_e = Expression::Atomic(Metadata::new(), Atom::new_ref(z));
        // `x + y + z` is expected as a nested sum: (x + y) + z.
        assert_eq!(
            c1,
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Sum(
                            Metadata::new(),
                            Moo::new(matrix_expr!(x_e.clone(), y_e.clone()))
                        ),
                        z_e
                    ))
                )),
                Moo::new(Expression::Atomic(Metadata::new(), 4.into()))
            )
        );

        let c2 = constraints[1].clone();
        assert_eq!(
            c2,
            Expression::Geq(Metadata::new(), Moo::new(x_e), Moo::new(y_e))
        );
    }

    /// A value letting bound to a 2D matrix literal with explicit index domains:
    /// checks the value stored for `a`.
    #[test]
    pub fn test_parse_letting_index() {
        let src = "
        letting a be [ [ 1,2,3 ; int(1,2,4) ], [ 1,3,2 ; int(1,2,4) ], [ 3,2,1 ; int(1,2,4) ] ; int(-2..0) ]
        find b: int(1..5)
        such that
        b < a[-2,2],
        allDiff(a[-2,..])
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let a_decl = st.lookup(&Name::user("a")).unwrap();
        let a = a_decl.as_value_letting().unwrap().deref().clone();
        assert_eq!(
            a,
            matrix_expr!(
                matrix_expr!(1.into(), 2.into(), 3.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(1.into(), 3.into(), 2.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(3.into(), 2.into(), 1.into() ; domain_int!(1, 2, 4));
                domain_int!(-2..0)
            )
        );
    }

    /// `pareto(minimising x)` is expected to parse to `x <= prev(x) /\ x < prev(x)`,
    /// where `prev(x)` is the `FromSolution` form of `x`.
    #[test]
    pub fn test_parse_pareto_in_dominance_relation() {
        let src = "
        find x : int(0..3)

        dominance relation
            pareto(minimising x)
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x.clone()));
        let x_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(x)));

        assert_eq!(
            model.dominance,
            Some(Expression::DominanceRelation(
                Metadata::new(),
                Moo::new(Expression::And(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Leq(
                            Metadata::new(),
                            Moo::new(x_e.clone()),
                            Moo::new(x_prev.clone())
                        ),
                        Expression::Lt(Metadata::new(), Moo::new(x_e), Moo::new(x_prev))
                    ))
                ))
            ))
        );
    }

    /// Mixed `minimising`/`maximising` components: one non-strict comparison per
    /// component, followed by a disjunction of the strict comparisons.
    #[test]
    pub fn test_parse_pareto_with_mixed_directions() {
        let src = "
        find x : int(0..3)
        find y : int(0..3)

        dominance relation
            pareto(minimising x, maximising y)
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let y = st.lookup(&Name::user("y")).unwrap();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x.clone()));
        let y_e = Expression::Atomic(Metadata::new(), Atom::new_ref(y.clone()));
        let x_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(x)));
        let y_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(y)));

        assert_eq!(
            model.dominance,
            Some(Expression::DominanceRelation(
                Metadata::new(),
                Moo::new(Expression::And(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Leq(
                            Metadata::new(),
                            Moo::new(x_e.clone()),
                            Moo::new(x_prev.clone())
                        ),
                        Expression::Geq(
                            Metadata::new(),
                            Moo::new(y_e.clone()),
                            Moo::new(y_prev.clone())
                        ),
                        Expression::Or(
                            Metadata::new(),
                            Moo::new(matrix_expr!(
                                Expression::Lt(Metadata::new(), Moo::new(x_e), Moo::new(x_prev)),
                                Expression::Gt(Metadata::new(), Moo::new(y_e), Moo::new(y_prev))
                            ))
                        )
                    ))
                ))
            ))
        );
    }

    /// A pareto component that is a compound expression (`x + 1`): the previous-
    /// solution side is expected to be the same sum built over `FromSolution(x)`.
    #[test]
    pub fn test_parse_pareto_over_expression_component() {
        let src = "
        find x : int(0..3)

        dominance relation
            pareto(minimising x + 1)
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x.clone()));
        let x_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(x)));
        let one = Expression::Atomic(Metadata::new(), 1.into());
        let current = Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr!(x_e.clone(), one.clone())),
        );
        let previous = Expression::Sum(Metadata::new(), Moo::new(matrix_expr!(x_prev, one)));

        assert_eq!(
            model.dominance,
            Some(Expression::DominanceRelation(
                Metadata::new(),
                Moo::new(Expression::And(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Leq(
                            Metadata::new(),
                            Moo::new(current.clone()),
                            Moo::new(previous.clone())
                        ),
                        Expression::Lt(Metadata::new(), Moo::new(current), Moo::new(previous))
                    ))
                ))
            ))
        );
    }
}