1
use std::sync::{Arc, RwLock};
2
use std::{fs, vec};
3

            
4
use conjure_cp_core::Model;
5
use conjure_cp_core::ast::{DeclarationPtr, Expression, Metadata, Moo};
6
use conjure_cp_core::context::Context;
7
#[allow(unused)]
8
use uniplate::Uniplate;
9

            
10
use super::find::parse_find_statement;
11
use super::letting::parse_letting_statement;
12
use super::util::{get_tree, named_children};
13
use crate::errors::{FatalParseError, ParseErrorCollection, RecoverableParseError};
14
use crate::expression::parse_expression;
15
use crate::syntax_errors::detect_syntactic_errors;
16

            
17
/// Parse an Essence file into a Model using the tree-sitter parser.
18
344
pub fn parse_essence_file_native(
19
344
    path: &str,
20
344
    context: Arc<RwLock<Context<'static>>>,
21
344
) -> Result<Model, Box<ParseErrorCollection>> {
22
615
    let source_code = fs::read_to_string(path)
23
615
        .unwrap_or_else(|_| panic!("Failed to read the source code file {path}"));
24
271

            
25
615
    let mut errors = vec![];
26
615
    let model = parse_essence_with_context(&source_code, context, &mut errors);
27
271

            
28
344
    match model {
29
615
        Ok(m) => {
30
            // Check if there were any recoverable errors
31
344
            if !errors.is_empty() {
32
319
                return Err(Box::new(ParseErrorCollection::multiple(
33
317
                    errors,
34
48
                    Some(source_code),
35
48
                    Some(path.to_string()),
36
50
                )));
37
298
            }
38
            // Return model if no errors
39
298
            Ok(m)
40
2
        }
41
        Err(fatal) => {
42
            // Fatal error - wrap in ParseErrorCollection::Fatal
43
            Err(Box::new(ParseErrorCollection::fatal(fatal)))
44
        }
45
    }
46
344
}
47
271

            
48
731
pub fn parse_essence_with_context(
49
1387
    src: &str,
50
1387
    context: Arc<RwLock<Context<'static>>>,
51
1387
    errors: &mut Vec<RecoverableParseError>,
52
1387
) -> Result<Model, FatalParseError> {
53
1387
    let (tree, source_code) = match get_tree(src) {
54
1387
        Some(tree) => tree,
55
280
        None => {
56
376
            return Err(FatalParseError::TreeSitterError(
57
                "Failed to parse source code".to_string(),
58
656
            ));
59
        }
60
658
    };
61
658

            
62
1389
    if tree.root_node().has_error() {
63
1025
        detect_syntactic_errors(src, &tree, errors);
64
1025
        return Ok(Model::new(context));
65
1022
    }
66
658

            
67
364
    let mut model = Model::new(context);
68
364
    let root_node = tree.root_node();
69
364
    let symbols_ptr = model.symbols_ptr_unchecked().clone();
70
1073
    for statement in named_children(&root_node) {
71
1073
        match statement.kind() {
72
1073
            "single_line_comment" => {}
73
1073
            "language_declaration" => {}
74
1731
            "find_statement" => {
75
762
                let var_hashmap = parse_find_statement(
76
762
                    statement,
77
778
                    &source_code,
78
441
                    Some(symbols_ptr.clone()),
79
778
                    errors,
80
337
                )?;
81
835
                for (name, domain) in var_hashmap {
82
498
                    model
83
498
                        .symbols_mut()
84
835
                        .insert(DeclarationPtr::new_find(name, domain));
85
835
                }
86
337
            }
87
969
            "bool_expr" | "atom" | "comparison_expr" => {
88
693
                model.add_constraint(parse_expression(
89
693
                    statement,
90
356
                    &source_code,
91
356
                    &statement,
92
693
                    Some(symbols_ptr.clone()),
93
2470
                    errors,
94
22
                )?);
95
            }
96
276
            "language_label" => {}
97
276
            "letting_statement" => {
98
276
                let letting_vars = parse_letting_statement(
99
2390
                    statement,
100
690
                    &source_code,
101
690
                    Some(symbols_ptr.clone()),
102
690
                    errors,
103
414
                )?;
104
690
                model.symbols_mut().extend(letting_vars);
105
414
            }
106
414
            "dominance_relation" => {
107
414
                let inner = statement
108
414
                    .child_by_field_name("expression")
109
414
                    .expect("Expected a sub-expression inside `dominanceRelation`");
110
414
                let expr = parse_expression(
111
1700
                    inner,
112
309
                    &source_code,
113
309
                    &statement,
114
309
                    Some(symbols_ptr.clone()),
115
309
                    errors,
116
309
                )?;
117
309
                let dominance = Expression::DominanceRelation(Metadata::new(), Moo::new(expr));
118
309
                if model.dominance.is_some() {
119
309
                    errors.push(RecoverableParseError::new(
120
309
                        "Duplicate dominance relation".to_string(),
121
309
                        None,
122
309
                    ));
123
1391
                    continue;
124
                }
125
2114
                model.dominance = Some(dominance);
126
1057
            }
127
            // these should be detected at an earlier stage
128
            "ERROR" => {
129
1057
                let raw_expr = &source_code[statement.start_byte()..statement.end_byte()];
130
1057
                errors.push(RecoverableParseError::new(
131
1057
                    format!("'{raw_expr}' is not a valid expression"),
132
1057
                    Some(statement.range()),
133
414
                ));
134
471
            }
135
471
            _ => {
136
471
                let kind = statement.kind();
137
471
                errors.push(RecoverableParseError::new(
138
471
                    format!("Unrecognized top level statement kind: {kind}"),
139
                    Some(statement.range()),
140
643
                ));
141
334
            }
142
22
        }
143

            
144
        // check for errors (keyword as identifier)
145
1051
        keyword_as_identifier(root_node, &source_code, errors);
146
309
    }
147
651
    Ok(model)
148
1040
}
149

            
150
/// Reserved Essence words that may not be used as the name of a variable, identifier or
/// parameter.
///
/// `keyword_as_identifier` reports a recoverable error for any such node whose trimmed text
/// appears in this table.
const KEYWORDS: [&str; 21] = [
    // quantifiers
    "forall", "exists",
    // statement keywords
    "such", "that", "letting", "find", "minimise", "maximise", "subject", "to", "where",
    // logical operators
    "and", "or", "not",
    // conditionals
    "if", "then", "else",
    // miscellaneous
    "in", "sum", "product", "bool",
];
154

            
155
1051
fn keyword_as_identifier(
156
1051
    root: tree_sitter::Node,
157
1051
    src: &str,
158
1051
    errors: &mut Vec<RecoverableParseError>,
159
1051
) {
160
1051
    let mut stack = vec![root];
161
81286
    while let Some(node) = stack.pop() {
162
80235
        if (node.kind() == "variable" || node.kind() == "identifier" || node.kind() == "parameter")
163
5202
            && let Ok(text) = node.utf8_text(src.as_bytes())
164
        {
165
5202
            let ident = text.trim();
166
5202
            if KEYWORDS.contains(&ident) {
167
33
                let start_point = node.start_position();
168
33
                let end_point = node.end_position();
169
33
                errors.push(RecoverableParseError::new(
170
33
                    format!("Keyword '{ident}' used as identifier"),
171
33
                    Some(tree_sitter::Range {
172
33
                        start_byte: node.start_byte(),
173
33
                        end_byte: node.end_byte(),
174
33
                        start_point,
175
33
                        end_point,
176
33
                    }),
177
33
                ));
178
5506
            }
179
75033
        }
180

            
181
        // push children onto stack
182
80290
        for i in 0..node.child_count() {
183
79466
            if let Some(child) = u32::try_from(i).ok().and_then(|i| node.child(i)) {
184
79184
                stack.push(child);
185
79466
            }
186
658
        }
187
    }
188
1051
}
189

            
190
2
pub fn parse_essence(src: &str) -> Result<Model, Box<ParseErrorCollection>> {
191
2
    let context = Arc::new(RwLock::new(Context::default()));
192
2
    let mut errors = vec![];
193
339
    match parse_essence_with_context(src, context, &mut errors) {
194
339
        Ok(model) => {
195
22064
            if !errors.is_empty() {
196
21725
                Err(Box::new(ParseErrorCollection::multiple(
197
1558
                    errors,
198
                    Some(src.to_string()),
199
1558
                    None,
200
1558
                )))
201
33
            } else {
202
35
                Ok(model)
203
33
            }
204
33
        }
205
33
        Err(fatal) => Err(Box::new(ParseErrorCollection::fatal(fatal))),
206
33
    }
207
35
}
208
33

            
209
33
/// Unit tests for the top-level Essence parser; compiled only for test builds.
#[cfg(test)]
mod test {
    use crate::parse_essence;
    use conjure_cp_core::ast::{Atom, Expression, Metadata, Moo, Name};
    // NOTE(review): `range` is not referenced directly here; it is kept in case the
    // `domain_int!`/`matrix_expr!` expansions refer to it unqualified. TODO confirm and drop.
    #[allow(unused_imports)]
    use conjure_cp_core::{domain_int, matrix_expr, range};
    use std::ops::Deref;

    #[test]
    pub fn test_parse_xyz() {
        let src = "
        find x, y, z : int(1..4)
        such that x + y + z = 4
        such that x >= y
        ";

        let model = parse_essence(src).unwrap();

        let st = model.symbols();
        let x = st.lookup(&Name::user("x")).unwrap();
        let y = st.lookup(&Name::user("y")).unwrap();
        let z = st.lookup(&Name::user("z")).unwrap();
        assert_eq!(x.domain(), Some(domain_int!(1..4)));
        assert_eq!(y.domain(), Some(domain_int!(1..4)));
        assert_eq!(z.domain(), Some(domain_int!(1..4)));

        let constraints = model.constraints();
        assert_eq!(constraints.len(), 2);

        let c1 = constraints[0].clone();
        let x_e = Expression::Atomic(Metadata::new(), Atom::new_ref(x));
        let y_e = Expression::Atomic(Metadata::new(), Atom::new_ref(y));
        let z_e = Expression::Atomic(Metadata::new(), Atom::new_ref(z));
        assert_eq!(
            c1,
            Expression::Eq(
                Metadata::new(),
                Moo::new(Expression::Sum(
                    Metadata::new(),
                    Moo::new(matrix_expr!(
                        Expression::Sum(
                            Metadata::new(),
                            Moo::new(matrix_expr!(x_e.clone(), y_e.clone()))
                        ),
                        z_e
                    ))
                )),
                Moo::new(Expression::Atomic(Metadata::new(), 4.into()))
            )
        );

        let c2 = constraints[1].clone();
        assert_eq!(
            c2,
            Expression::Geq(Metadata::new(), Moo::new(x_e), Moo::new(y_e))
        );
    }

    #[test]
    pub fn test_parse_letting_index() {
        let src = "
        letting a be [ [ 1,2,3 ; int(1,2,4) ], [ 1,3,2 ; int(1,2,4) ], [ 3,2,1 ; int(1,2,4) ] ; int(-2..0) ]
        find b: int(1..5)
        such that
        b < a[-2,2],
        allDiff(a[-2,..])
        ";

        let model = parse_essence(src).unwrap();
        let st = model.symbols();
        let a_decl = st.lookup(&Name::user("a")).unwrap();
        let a = a_decl.as_value_letting().unwrap().deref().clone();
        assert_eq!(
            a,
            matrix_expr!(
                matrix_expr!(1.into(), 2.into(), 3.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(1.into(), 3.into(), 2.into() ; domain_int!(1, 2, 4)),
                matrix_expr!(3.into(), 2.into(), 1.into() ; domain_int!(1, 2, 4));
                domain_int!(-2..0)
            )
        )
    }
}