1
use std::collections::BTreeMap;
2
use std::sync::{Arc, RwLock};
3
use std::{fs, vec};
4

            
5
use conjure_cp_core::Model;
6
use conjure_cp_core::ast::DeclarationPtr;
7
use conjure_cp_core::ast::assertions::debug_assert_model_well_formed;
8
use conjure_cp_core::context::Context;
9
#[allow(unused)]
10
use uniplate::Uniplate;
11

            
12
use super::ParseContext;
13
use super::find::{parse_find_statement, parse_given_statement};
14
use super::letting::parse_letting_statement;
15
use super::util::{TypecheckingContext, get_tree};
16
use crate::diagnostics::source_map::SourceMap;
17
use crate::errors::{FatalParseError, ParseErrorCollection, RecoverableParseError};
18
use crate::expression::parse_expression;
19
use crate::parser::keyword_checks::keyword_as_identifier;
20
use crate::syntax_errors::detect_syntactic_errors;
21
use tree_sitter::Tree;
22

            
23
/// Parse an Essence file into a Model using the tree-sitter parser.
24
18312
pub fn parse_essence_file_native(
25
18312
    path: &str,
26
18312
    context: Arc<RwLock<Context<'static>>>,
27
18312
) -> Result<Model, Box<ParseErrorCollection>> {
28
18312
    let source_code = fs::read_to_string(path)
29
18312
        .unwrap_or_else(|_| panic!("Failed to read the source code file {path}"));
30

            
31
18312
    let mut errors = vec![];
32
18312
    let model = parse_essence_with_context(&source_code, context, &mut errors);
33

            
34
18312
    match model {
35
17632
        Ok(Some(m)) => {
36
17632
            debug_assert_model_well_formed(&m, "tree-sitter");
37
17632
            Ok(m)
38
        }
39
        Ok(None) => {
40
            // Recoverable errors were found, return them as a ParseErrorCollection
41
680
            Err(Box::new(ParseErrorCollection::multiple(
42
680
                errors,
43
680
                Some(source_code),
44
680
                Some(path.to_string()),
45
680
            )))
46
        }
47
        Err(fatal) => {
48
            // Fatal error - wrap in ParseErrorCollection::Fatal
49
            Err(Box::new(ParseErrorCollection::fatal(fatal)))
50
        }
51
    }
52
18312
}
53

            
54
18312
pub fn parse_essence_with_context(
55
18312
    src: &str,
56
18312
    context: Arc<RwLock<Context<'static>>>,
57
18312
    errors: &mut Vec<RecoverableParseError>,
58
18312
) -> Result<Option<Model>, FatalParseError> {
59
18312
    match parse_essence_with_context_and_map(src, context, errors, None)? {
60
17632
        (Some(model), _source_map) => Ok(Some(model)),
61
680
        (None, _source_map) => Ok(None),
62
    }
63
18312
}
64

            
65
/*
66
    this function is used by both the file-based parser and the LSP parser (which needs the source map)
67
    the LSP parser can also optionally pass in a pre-parsed tree to avoid parsing twice (which is how caching is implemented)
68
    if the tree is not passed in, we will parse it from scratch (this is what the file-based parser does)
69
    when cache is dirty, LSP has to call parse_essence_with_context_and_map with None for the tree,
70
    which will cause it to re-parse the source code and update the cache (Model = ast, SorceMap = map)
71
*/
72
18785
pub fn parse_essence_with_context_and_map(
73
18785
    src: &str,
74
18785
    context: Arc<RwLock<Context<'static>>>,
75
18785
    errors: &mut Vec<RecoverableParseError>,
76
18785
    tree: Option<&Tree>,
77
18785
) -> Result<(Option<Model>, SourceMap), FatalParseError> {
78
18785
    let (tree, source_code) = if let Some(tree) = tree {
79
468
        (tree.clone(), src.to_string())
80
    } else {
81
18317
        match get_tree(src) {
82
18317
            Some(tree) => tree,
83
            None => {
84
                return Err(FatalParseError::TreeSitterError(
85
                    "Failed to parse source code".to_string(),
86
                ));
87
            }
88
        }
89
    };
90

            
91
18785
    let has_syntax_errors = tree.root_node().has_error();
92
18785
    if has_syntax_errors {
93
732
        detect_syntactic_errors(src, &tree, errors);
94
18053
    }
95

            
96
    // don't detect semantic errors if there are syntactic errors, but still parse for source map.
97
18785
    let mut suppressed_semantic_errors = Vec::new();
98
18785
    let semantic_errors: &mut Vec<RecoverableParseError> = if has_syntax_errors {
99
732
        &mut suppressed_semantic_errors
100
    } else {
101
18053
        errors
102
    };
103

            
104
18785
    keyword_as_identifier(tree.root_node(), src, semantic_errors);
105

            
106
18785
    let mut model = Model::new(context);
107
18785
    let mut source_map = SourceMap::default();
108
18785
    let mut declaration_spans = BTreeMap::new();
109
18785
    let root_node = tree.root_node();
110

            
111
    // Create a ParseContext
112
18785
    let mut ctx = ParseContext::new(
113
18785
        &source_code,
114
18785
        &root_node,
115
18785
        Some(model.symbols_ptr_unchecked().clone()),
116
18785
        semantic_errors,
117
18785
        &mut source_map,
118
18785
        &mut declaration_spans,
119
    );
120

            
121
18785
    let mut cursor = root_node.walk();
122
164718
    for statement in root_node.children(&mut cursor) {
123
164718
        if !statement.is_named() || statement.is_error() || statement.kind() == "ERROR" {
124
34937
            continue;
125
129781
        }
126

            
127
129781
        match statement.kind() {
128
129781
            "single_line_comment" => {}
129
86865
            "language_declaration" => {}
130
80501
            "find_statement" => {
131
37928
                let var_hashmap = parse_find_statement(&mut ctx, statement)?;
132
44668
                for (name, domain) in var_hashmap {
133
44668
                    model
134
44668
                        .symbols_mut()
135
44668
                        .insert(DeclarationPtr::new_find(name, domain));
136
44668
                }
137
            }
138
42573
            "given_statement" => {
139
142
                let var_hashmap = parse_given_statement(&mut ctx, statement)?;
140
142
                for (name, domain) in var_hashmap {
141
129
                    model
142
129
                        .symbols_mut()
143
129
                        .insert(DeclarationPtr::new_given(name, domain));
144
129
                }
145
            }
146
42431
            "bool_expr" | "atom" | "comparison_expr" => {
147
34361
                ctx.typechecking_context = TypecheckingContext::Boolean;
148
34361
                let Some(expr) = parse_expression(&mut ctx, statement)? else {
149
429
                    continue;
150
                };
151
33932
                model.add_constraint(expr);
152
            }
153
8070
            "language_label" => {}
154
8070
            "letting_statement" => {
155
2373
                let Some(letting_vars) = parse_letting_statement(&mut ctx, statement)? else {
156
                    continue;
157
                };
158
2373
                model.symbols_mut().extend(letting_vars);
159
            }
160
5697
            "dominance_relation" => {
161
5697
                let Some(dominance) = parse_expression(&mut ctx, statement)? else {
162
                    continue;
163
                };
164
5697
                if model.dominance.is_some() {
165
                    ctx.record_error(RecoverableParseError::new(
166
                        "Duplicate dominance relation".to_string(),
167
                        None,
168
                    ));
169
                    continue;
170
5697
                }
171
5697
                model.dominance = Some(dominance);
172
            }
173
            _ => {
174
                ctx.record_error(RecoverableParseError::new(
175
                    format!("Unexpected top-level statement: {}", statement.kind()),
176
                    Some(statement.range()),
177
                ));
178
                continue;
179
            }
180
        }
181
    }
182

            
183
    // Check if there were any recoverable errors
184
18785
    if !errors.is_empty() {
185
1135
        return Ok((None, source_map));
186
17650
    }
187
    // otherwise return the model
188
17650
    Ok((Some(model), source_map))
189
18785
}
190

            
191
5
pub fn parse_essence(src: &str) -> Result<(Model, SourceMap), Box<ParseErrorCollection>> {
192
5
    let context = Arc::new(RwLock::new(Context::default()));
193
5
    let mut errors = vec![];
194
5
    match parse_essence_with_context_and_map(src, context, &mut errors, None) {
195
5
        Ok((Some(model), source_map)) => {
196
5
            debug_assert_model_well_formed(&model, "tree-sitter");
197
5
            Ok((model, source_map))
198
        }
199
        Ok((None, _source_map)) => {
200
            // Recoverable errors were found, return them as a ParseErrorCollection
201
            Err(Box::new(ParseErrorCollection::multiple(
202
                errors,
203
                Some(src.to_string()),
204
                None,
205
            )))
206
        }
207
        Err(fatal) => Err(Box::new(ParseErrorCollection::fatal(fatal))),
208
    }
209
5
}
210

            
211
/// Unit tests for the top-level Essence parser.
///
/// Gated behind `cfg(test)` so the module (and its imports) is only compiled for test builds,
/// rather than silencing unused-import warnings in normal builds.
#[cfg(test)]
mod test {
    use std::ops::Deref;

    use conjure_cp_core::ast::{Atom, Expression, Metadata, Moo, Name};
    // NOTE(review): `range` is not referenced directly by these tests; it is kept in case the
    // `domain_int!` expansion relies on it being in scope. Confirm, and drop if unneeded.
    #[allow(unused_imports)]
    use conjure_cp_core::{domain_int, matrix_expr, range};

    use crate::parse_essence;

    #[test]
    fn test_parse_xyz() {
        let src = "
        find x, y, z : int(1..4)
        such that x + y + z = 4
        such that x >= y
        ";

        let (model, _source_map) = parse_essence(src).unwrap();

        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let y = st.lookup(&Name::user("y")).unwrap();
        let z = st.lookup(&Name::user("z")).unwrap();
        assert_eq!(x.domain(), Some(domain_int!(1..4)));
        assert_eq!(y.domain(), Some(domain_int!(1..4)));
        assert_eq!(z.domain(), Some(domain_int!(1..4)));

        let constraints = model.constraints();
        assert_eq!(constraints.len(), 2);

        let c1 = constraints[0].clone();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x));
        let y_e = Expression::Atomic(Metadata::new(), Atom::new_ref(y));
        let z_e = Expression::Atomic(Metadata::new(), Atom::new_ref(z));
        assert_eq!(
            c1,
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Sum(
                            Metadata::new(),
                            Moo::new(matrix_expr!(x_e.clone(), y_e.clone()))
                        ),
                        z_e
                    ))
                )),
                Moo::new(Expression::Atomic(Metadata::new(), 4.into()))
            )
        );

        let c2 = constraints[1].clone();
        assert_eq!(
            c2,
            Expression::Geq(Metadata::new(), Moo::new(x_e), Moo::new(y_e))
        );
    }

    #[test]
    fn test_parse_letting_index() {
        let src = "
        letting a be [ [ 1,2,3 ; int(1,2,4) ], [ 1,3,2 ; int(1,2,4) ], [ 3,2,1 ; int(1,2,4) ] ; int(-2..0) ]
        find b: int(1..5)
        such that
        b < a[-2,2],
        allDiff(a[-2,..])
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let a_decl = st.lookup(&Name::user("a")).unwrap();
        let a = a_decl.as_value_letting().unwrap().deref().clone();
        assert_eq!(
            a,
            matrix_expr!(
                matrix_expr!(1.into(), 2.into(), 3.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(1.into(), 3.into(), 2.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(3.into(), 2.into(), 1.into() ; domain_int!(1, 2, 4));
                domain_int!(-2..0)
            )
        )
    }

    #[test]
    fn test_parse_pareto_in_dominance_relation() {
        let src = "
        find x : int(0..3)

        dominance relation
            pareto(minimising x)
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x.clone()));
        let x_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(x)));

        assert_eq!(
            model.dominance,
            Some(Expression::DominanceRelation(
                Metadata::new(),
                Moo::new(Expression::And(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Leq(
                            Metadata::new(),
                            Moo::new(x_e.clone()),
                            Moo::new(x_prev.clone())
                        ),
                        Expression::Lt(Metadata::new(), Moo::new(x_e), Moo::new(x_prev))
                    ))
                ))
            ))
        );
    }

    #[test]
    fn test_parse_pareto_with_mixed_directions() {
        let src = "
        find x : int(0..3)
        find y : int(0..3)

        dominance relation
            pareto(minimising x, maximising y)
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let y = st.lookup(&Name::user("y")).unwrap();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x.clone()));
        let y_e = Expression::Atomic(Metadata::new(), Atom::new_ref(y.clone()));
        let x_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(x)));
        let y_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(y)));

        assert_eq!(
            model.dominance,
            Some(Expression::DominanceRelation(
                Metadata::new(),
                Moo::new(Expression::And(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Leq(
                            Metadata::new(),
                            Moo::new(x_e.clone()),
                            Moo::new(x_prev.clone())
                        ),
                        Expression::Geq(
                            Metadata::new(),
                            Moo::new(y_e.clone()),
                            Moo::new(y_prev.clone())
                        ),
                        Expression::Or(
                            Metadata::new(),
                            Moo::new(matrix_expr!(
                                Expression::Lt(Metadata::new(), Moo::new(x_e), Moo::new(x_prev)),
                                Expression::Gt(Metadata::new(), Moo::new(y_e), Moo::new(y_prev))
                            ))
                        )
                    ))
                ))
            ))
        );
    }

    #[test]
    fn test_parse_pareto_over_expression_component() {
        let src = "
        find x : int(0..3)

        dominance relation
            pareto(minimising x + 1)
        ";

        let (model, _source_map) = parse_essence(src).unwrap();
        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x.clone()));
        let x_prev = Expression::FromSolution(Metadata::new(), Moo::new(Atom::new_ref(x)));
        let one = Expression::Atomic(Metadata::new(), 1.into());
        let current = Expression::Sum(
            Metadata::new(),
            Moo::new(matrix_expr!(x_e.clone(), one.clone())),
        );
        let previous = Expression::Sum(Metadata::new(), Moo::new(matrix_expr!(x_prev, one)));

        assert_eq!(
            model.dominance,
            Some(Expression::DominanceRelation(
                Metadata::new(),
                Moo::new(Expression::And(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Leq(
                            Metadata::new(),
                            Moo::new(current.clone()),
                            Moo::new(previous.clone())
                        ),
                        Expression::Lt(Metadata::new(), Moo::new(current), Moo::new(previous))
                    ))
                ))
            ))
        );
    }
}