Skip to main content

conjure_cp_essence_parser/parser/
syntax_errors.rs

1use crate::errors::RecoverableParseError;
2use crate::parser::traversal::WalkDFS;
3use capitalize::Capitalize;
4use std::collections::HashSet;
5use tree_sitter::Node;
6
7pub fn detect_syntactic_errors(
8    source: &str,
9    tree: &tree_sitter::Tree,
10    errors: &mut Vec<RecoverableParseError>,
11) {
12    let mut malformed_lines_reported = HashSet::new();
13
14    let root_node = tree.root_node();
15    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
16        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
17    };
18
19    for node in WalkDFS::with_retract(&root_node, &retract) {
20        if node.start_position() == node.end_position() {
21            errors.push(classify_missing_token(node));
22            continue;
23        }
24        if node.is_error() {
25            let line = node.start_position().row;
26            // If this line has already been reported as malformed, skip all error nodes on this line
27            if malformed_lines_reported.contains(&line) {
28                continue;
29            }
30            if is_malformed_line_error(&node, source) {
31                malformed_lines_reported.insert(line);
32                let start_byte = node.start_byte();
33                let end_byte = node.end_byte();
34
35                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
36                errors.push(RecoverableParseError::new(
37                    generate_malformed_line_message(line, source),
38                    Some(tree_sitter::Range {
39                        start_byte,
40                        end_byte,
41                        start_point: tree_sitter::Point {
42                            row: line,
43                            column: 0,
44                        },
45                        end_point: tree_sitter::Point {
46                            row: line,
47                            column: last_char,
48                        },
49                    }),
50                ));
51                continue;
52            } else {
53                errors.push(classify_unexpected_token_error(node, source));
54            }
55            continue;
56        }
57    }
58}
59
60/// Classifies a missing token node and generates a diagnostic with a context-aware message.
61fn classify_missing_token(node: Node) -> RecoverableParseError {
62    let start = node.start_position();
63    let end = node.end_position();
64
65    let message = if let Some(parent) = node.parent() {
66        match parent.kind() {
67            "letting_statement" => "Missing Expression or Domain".to_string(),
68            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
69        }
70    } else {
71        format!("Missing {}", user_friendly_token_name(node.kind(), false))
72    };
73
74    RecoverableParseError::new(
75        message,
76        Some(tree_sitter::Range {
77            start_byte: node.start_byte(),
78            end_byte: node.end_byte(),
79            start_point: start,
80            end_point: end,
81        }),
82    )
83}
84
85/// Classifies an unexpected token error node and generates a diagnostic.
86fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
87    let message = if let Some(parent) = node.parent() {
88        let start_byte = node.start_byte().min(source_code.len());
89        let end_byte = node.end_byte().min(source_code.len());
90        let src_token = &source_code[start_byte..end_byte];
91
92        if parent.kind() == "program"
93        // ERROR node is the direct child of the root node
94        {
95            // A case where the unexpected token is at the end of a valid statement
96            format!("Unexpected {}", src_token)
97            // }
98        } else {
99            // Unexpected token inside a construct
100            format!(
101                "Unexpected {} inside {}",
102                src_token,
103                user_friendly_token_name(parent.kind(), true)
104            )
105        }
106    } else {
107        // Should never happen since an ERROR node would always have a parent.
108        "Unexpected token".to_string()
109    };
110
111    RecoverableParseError::new(
112        message,
113        Some(tree_sitter::Range {
114            start_byte: node.start_byte(),
115            end_byte: node.end_byte(),
116            start_point: node.start_position(),
117            end_point: node.end_position(),
118        }),
119    )
120}
121
122/// Determines if an error node represents a malformed line error.
123fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
124    if node.start_position().column == 0 || error_node_out_of_range(node, source) {
125        return true;
126    }
127    let parent = node.parent();
128    let grandparent = parent.and_then(|n| n.parent());
129    let root = grandparent.and_then(|n| n.parent());
130
131    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root) {
132        parent.kind() == "set_comparison"
133            && grandparent.kind() == "comparison_expr"
134            && root.kind() == "program"
135    } else {
136        false
137    }
138}
139
140/// Coverts a token name into a more user-friendly format for error messages.
141/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
142fn user_friendly_token_name(token: &str, article: bool) -> String {
143    let capitalized = if token.contains("atom") {
144        "Expression".to_string()
145    } else if token == "COLON" {
146        ":".to_string()
147    } else {
148        let friendly_name = token
149            .replace("literal", "")
150            .replace("int", "Integer")
151            .replace("expr", "Expression")
152            .replace('_', " ");
153        friendly_name
154            .split_whitespace()
155            .map(|word| word.capitalize())
156            .collect::<Vec<_>>()
157            .join(" ")
158    };
159
160    if !article {
161        return capitalized;
162    }
163    let first_char = capitalized.chars().next().unwrap();
164    let article = match first_char.to_ascii_lowercase() {
165        'a' | 'e' | 'i' | 'o' | 'u' => "an",
166        _ => "a",
167    };
168    format!("{} {}", article, capitalized)
169}
170
/// Builds the "Expected ..., but got '...'" message for a malformed source line.
///
/// The line at index `line` (0-based) is trimmed and its double quotes are
/// backslash-escaped before being quoted in the message; a line index past the end of
/// `source` is treated as an empty line. The expected construct is chosen from the
/// first one or two whitespace-separated words of that line, compared
/// ASCII-case-insensitively.
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let got = source
        .lines()
        .nth(line)
        .unwrap_or_default()
        .trim()
        .replace('"', "\\\"");

    // Only the first two words matter for choosing the expected construct.
    let mut leading_words = got.split_whitespace().map(str::to_ascii_lowercase);
    let first = leading_words.next().unwrap_or_default();
    let second = leading_words.next().unwrap_or_default();

    let expected = match (first.as_str(), second.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising" | "maximising", _) => "an objective statement",
        // `such` only introduces a constraint when followed by `that`.
        ("such", "that") => "a constraint statement",
        // Unrecognised starting tokens, including a lone `such`.
        _ => "a valid top-level statement",
    };

    format!("Expected {}, but got '{}'", expected, got)
}
201
202/// Returns true if the node's start or end column is out of range for its line in the source.
203fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
204    let lines: Vec<&str> = source.lines().collect();
205    let start = node.start_position();
206    let end = node.end_position();
207
208    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
209    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
210
211    (start.column > start_line_len) || (end.column > end_line_len)
212}
213
#[cfg(test)]
mod test {
    use super::{detect_syntactic_errors, is_malformed_line_error, user_friendly_token_name};
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Asserts that two diagnostics carry the same message and the same range.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    /// Asserts the malformed-line message produced for the first line of `source`.
    fn assert_malformed_message(source: &str, expected: &str) {
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(message, expected);
    }

    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        // Visit every node (never retract) and pick the first ERROR node.
        let first_error = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&first_error, source));
    }

    #[test]
    fn malformed_find_message() {
        assert_malformed_message(
            "find >=lex,b,c: int(1..3)",
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_top_level_message() {
        assert_malformed_message(
            "a >=lex,b,c: int(1..3)",
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_objective_message() {
        assert_malformed_message(
            "maximising %x",
            "Expected an objective statement, but got 'maximising %x'",
        );
    }

    #[test]
    fn malformed_letting_message() {
        assert_malformed_message(
            "letting % A be 3",
            "Expected a letting declaration statement, but got 'letting % A be 3'",
        );
    }

    #[test]
    fn malformed_constraint_message() {
        assert_malformed_message(
            "such that % A > 3",
            "Expected a constraint statement, but got 'such that % A > 3'",
        );
    }

    #[test]
    fn malformed_top_level_message_2() {
        // `such` without `that` is not recognised as a constraint statement.
        assert_malformed_message(
            "such % A > 3",
            "Expected a valid top-level statement, but got 'such % A > 3'",
        );
    }

    #[test]
    fn malformed_given_message() {
        assert_malformed_message(
            "given 1..3)",
            "Expected a given declaration statement, but got 'given 1..3)'",
        );
    }

    #[test]
    fn malformed_where_message() {
        assert_malformed_message(
            "where x>6",
            "Expected an instantiation condition, but got 'where x>6'",
        );
    }

    #[test]
    fn user_friendly_token_name_article() {
        let without_article = user_friendly_token_name("int_domain", false);
        assert_eq!(without_article, "Integer Domain");

        let with_article = user_friendly_token_name("int_domain", true);
        assert_eq!(with_article, "an Integer Domain");

        // Disabled in the original source, kept for reference:
        // assert_eq!(user_friendly_token_name("atom", true), "an Expression");
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");

        let mut errors = Vec::new();
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        // The missing domain is a zero-width range right after the colon.
        let expected = RecoverableParseError::new(
            "Missing Domain".to_string(),
            Some(tree_sitter::Range {
                start_byte: 7,
                end_byte: 7,
                start_point: tree_sitter::Point { row: 0, column: 7 },
                end_point: tree_sitter::Point { row: 0, column: 7 },
            }),
        );

        assert_essence_parse_error_eq(&errors[0], &expected);
    }
}