1
use crate::errors::RecoverableParseError;
2
use crate::parser::traversal::WalkDFS;
3
use capitalize::Capitalize;
4
use std::collections::HashSet;
5
use tree_sitter::Node;
6

            
7
555
pub fn detect_syntactic_errors(
8
555
    source: &str,
9
555
    tree: &tree_sitter::Tree,
10
555
    errors: &mut Vec<RecoverableParseError>,
11
555
) {
12
555
    let mut malformed_lines_reported = HashSet::new();
13

            
14
555
    let root_node = tree.root_node();
15
14848
    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
16
14848
        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
17
14848
    };
18

            
19
14848
    for node in WalkDFS::with_retract(&root_node, &retract) {
20
14848
        if node.start_position() == node.end_position() {
21
144
            errors.push(classify_missing_token(node));
22
144
            continue;
23
14704
        }
24
14704
        if node.is_error() {
25
697
            let line = node.start_position().row;
26
            // If this line has already been reported as malformed, skip all error nodes on this line
27
697
            if malformed_lines_reported.contains(&line) {
28
176
                continue;
29
521
            }
30
521
            if is_malformed_line_error(&node, source) {
31
145
                malformed_lines_reported.insert(line);
32
145
                let start_byte = node.start_byte();
33
145
                let end_byte = node.end_byte();
34

            
35
145
                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
36
145
                errors.push(RecoverableParseError::new(
37
145
                    generate_malformed_line_message(line, source),
38
145
                    Some(tree_sitter::Range {
39
145
                        start_byte,
40
145
                        end_byte,
41
145
                        start_point: tree_sitter::Point {
42
145
                            row: line,
43
145
                            column: 0,
44
145
                        },
45
145
                        end_point: tree_sitter::Point {
46
145
                            row: line,
47
145
                            column: last_char,
48
145
                        },
49
145
                    }),
50
                ));
51
145
                continue;
52
376
            } else {
53
376
                errors.push(classify_unexpected_token_error(node, source));
54
376
            }
55
376
            continue;
56
14007
        }
57
    }
58
555
}
59

            
60
/// Classifies a missing token node and generates a diagnostic with a context-aware message.
61
144
fn classify_missing_token(node: Node) -> RecoverableParseError {
62
144
    let start = node.start_position();
63
144
    let end = node.end_position();
64

            
65
144
    let message = if let Some(parent) = node.parent() {
66
144
        match parent.kind() {
67
144
            "letting_statement" => "Missing Expression or Domain".to_string(),
68
122
            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
69
        }
70
    } else {
71
        format!("Missing {}", user_friendly_token_name(node.kind(), false))
72
    };
73

            
74
144
    RecoverableParseError::new(
75
144
        message,
76
144
        Some(tree_sitter::Range {
77
144
            start_byte: node.start_byte(),
78
144
            end_byte: node.end_byte(),
79
144
            start_point: start,
80
144
            end_point: end,
81
144
        }),
82
    )
83
144
}
84

            
85
/// Classifies an unexpected token error node and generates a diagnostic.
86
376
fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
87
376
    let message = if let Some(parent) = node.parent() {
88
376
        let start_byte = node.start_byte().min(source_code.len());
89
376
        let end_byte = node.end_byte().min(source_code.len());
90
376
        let src_token = &source_code[start_byte..end_byte];
91

            
92
376
        if parent.kind() == "program"
93
        // ERROR node is the direct child of the root node
94
        {
95
            // A case where the unexpected token is at the end of a valid statement
96
112
            format!("Unexpected {}", src_token)
97
            // }
98
        } else {
99
            // Unexpected token inside a construct
100
264
            format!(
101
                "Unexpected {} inside {}",
102
                src_token,
103
264
                user_friendly_token_name(parent.kind(), true)
104
            )
105
        }
106
    } else {
107
        // Should never happen since an ERROR node would always have a parent.
108
        "Unexpected token".to_string()
109
    };
110

            
111
376
    RecoverableParseError::new(
112
376
        message,
113
376
        Some(tree_sitter::Range {
114
376
            start_byte: node.start_byte(),
115
376
            end_byte: node.end_byte(),
116
376
            start_point: node.start_position(),
117
376
            end_point: node.end_position(),
118
376
        }),
119
    )
120
376
}
121

            
122
/// Determines if an error node represents a malformed line error.
123
522
fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
124
522
    if node.start_position().column == 0 || error_node_out_of_range(node, source) {
125
146
        return true;
126
376
    }
127
376
    let parent = node.parent();
128
376
    let grandparent = parent.and_then(|n| n.parent());
129
376
    let root = grandparent.and_then(|n| n.parent());
130

            
131
376
    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root) {
132
253
        parent.kind() == "set_comparison"
133
            && grandparent.kind() == "comparison_expr"
134
            && root.kind() == "program"
135
    } else {
136
123
        false
137
    }
138
522
}
139

            
140
/// Coverts a token name into a more user-friendly format for error messages.
141
/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
142
389
fn user_friendly_token_name(token: &str, article: bool) -> String {
143
389
    let capitalized = if token.contains("atom") {
144
22
        "Expression".to_string()
145
367
    } else if token == "COLON" {
146
23
        ":".to_string()
147
    } else {
148
344
        let friendly_name = token
149
344
            .replace("literal", "")
150
344
            .replace("int", "Integer")
151
344
            .replace("expr", "Expression")
152
344
            .replace('_', " ");
153
344
        friendly_name
154
344
            .split_whitespace()
155
621
            .map(|word| word.capitalize())
156
344
            .collect::<Vec<_>>()
157
344
            .join(" ")
158
    };
159

            
160
389
    if !article {
161
124
        return capitalized;
162
265
    }
163
265
    let first_char = capitalized.chars().next().unwrap();
164
265
    let article = match first_char.to_ascii_lowercase() {
165
89
        'a' | 'e' | 'i' | 'o' | 'u' => "an",
166
176
        _ => "a",
167
    };
168
265
    format!("{} {}", article, capitalized)
169
389
}
170

            
171
/// Produces the "Expected …, but got '…'" message for a malformed source line.
///
/// The line at index `line` of `source` is trimmed and its double quotes are
/// backslash-escaped (an out-of-range index yields an empty line). The first
/// word of the line, compared case-insensitively, selects which kind of
/// statement was expected; `such` only counts as a constraint when it is
/// followed by `that`.
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let offending = source
        .lines()
        .nth(line)
        .map_or("", str::trim)
        .replace('"', "\\\"");

    let mut lowered_words = offending.split_whitespace().map(str::to_ascii_lowercase);
    let keyword = lowered_words.next().unwrap_or_default();
    let follower = lowered_words.next().unwrap_or_default();

    let expected = match (keyword.as_str(), follower.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising", _) | ("maximising", _) => "an objective statement",
        ("such", "that") => "a constraint statement",
        // A lone `such`, or any unrecognized leading word.
        _ => "a valid top-level statement",
    };

    format!("Expected {}, but got '{}'", expected, offending)
}
201

            
202
/// Returns true if the node's start or end column is out of range for its line in the source.
203
478
fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
204
478
    let lines: Vec<&str> = source.lines().collect();
205
478
    let start = node.start_position();
206
478
    let end = node.end_position();
207

            
208
478
    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
209
478
    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
210

            
211
478
    (start.column > start_line_len) || (end.column > end_line_len)
212
478
}
213

            
214
#[cfg(test)]
mod test {

    use super::{detect_syntactic_errors, is_malformed_line_error, user_friendly_token_name};
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Checks that two errors agree on both their message and their source range.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    /// Asserts the malformed-line message generated for the first line of `source`.
    fn assert_malformed_message(source: &str, expected: &str) {
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(message, expected);
    }

    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        // Visit every node (the retract predicate never fires) and take the first error node.
        let mut walker = WalkDFS::with_retract(&root_node, &|_node| false);
        let error_node = walker
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    #[test]
    fn malformed_find_message() {
        assert_malformed_message(
            "find >=lex,b,c: int(1..3)",
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_top_level_message() {
        assert_malformed_message(
            "a >=lex,b,c: int(1..3)",
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_objective_message() {
        assert_malformed_message(
            "maximising %x",
            "Expected an objective statement, but got 'maximising %x'",
        );
    }

    #[test]
    fn malformed_letting_message() {
        assert_malformed_message(
            "letting % A be 3",
            "Expected a letting declaration statement, but got 'letting % A be 3'",
        );
    }

    #[test]
    fn malformed_constraint_message() {
        assert_malformed_message(
            "such that % A > 3",
            "Expected a constraint statement, but got 'such that % A > 3'",
        );
    }

    #[test]
    fn malformed_top_level_message_2() {
        assert_malformed_message(
            "such % A > 3",
            "Expected a valid top-level statement, but got 'such % A > 3'",
        );
    }

    #[test]
    fn malformed_given_message() {
        assert_malformed_message(
            "given 1..3)",
            "Expected a given declaration statement, but got 'given 1..3)'",
        );
    }

    #[test]
    fn malformed_where_message() {
        assert_malformed_message(
            "where x>6",
            "Expected an instantiation condition, but got 'where x>6'",
        );
    }

    #[test]
    fn user_friendly_token_name_article() {
        let without_article = user_friendly_token_name("int_domain", false);
        assert_eq!(without_article, "Integer Domain");

        let with_article = user_friendly_token_name("int_domain", true);
        assert_eq!(with_article, "an Integer Domain");

        // assert_eq!(user_friendly_token_name("atom", true), "an Expression");
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");
        let mut errors = vec![];
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        // The diagnostic is a zero-width range right after the colon.
        let after_colon = tree_sitter::Point { row: 0, column: 7 };
        let expected = RecoverableParseError::new(
            "Missing Domain".to_string(),
            Some(tree_sitter::Range {
                start_byte: 7,
                end_byte: 7,
                start_point: after_colon,
                end_point: after_colon,
            }),
        );

        assert_essence_parse_error_eq(&errors[0], &expected);
    }
}