1
use crate::errors::RecoverableParseError;
2
use crate::parser::traversal::WalkDFS;
3
use capitalize::Capitalize;
4
use std::collections::HashSet;
5
use tree_sitter::Node;
6

            
7
1459
pub fn detect_syntactic_errors(
8
1459
    source: &str,
9
1459
    tree: &tree_sitter::Tree,
10
1459
    errors: &mut Vec<RecoverableParseError>,
11
1459
) {
12
1459
    let mut malformed_lines_reported = HashSet::new();
13

            
14
1459
    let root_node = tree.root_node();
15
32290
    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
16
32290
        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
17
32290
    };
18

            
19
32290
    for node in WalkDFS::with_retract(&root_node, &retract) {
20
32290
        if node.start_position() == node.end_position() {
21
411
            errors.push(classify_missing_token(node));
22
411
            continue;
23
31879
        }
24
31879
        if node.is_error() {
25
1624
            let line = node.start_position().row;
26
            // If this line has already been reported as malformed, skip all error nodes on this line
27
1624
            if malformed_lines_reported.contains(&line) {
28
508
                continue;
29
1116
            }
30
1116
            if is_malformed_line_error(&node, source) {
31
376
                malformed_lines_reported.insert(line);
32
376
                let start_byte = node.start_byte();
33
376
                let end_byte = node.end_byte();
34

            
35
376
                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
36
376
                errors.push(RecoverableParseError::new(
37
376
                    format!(
38
266
                        "Malformed line {}: '{}'",
39
376
                        line + 1,
40
376
                        source.lines().nth(line).unwrap_or("")
41
266
                    ),
42
376
                    Some(tree_sitter::Range {
43
376
                        start_byte,
44
376
                        end_byte,
45
376
                        start_point: tree_sitter::Point {
46
376
                            row: line,
47
376
                            column: 0,
48
376
                        },
49
376
                        end_point: tree_sitter::Point {
50
110
                            row: line,
51
376
                            column: last_char,
52
596
                        },
53
596
                    }),
54
486
                ));
55
596
                continue;
56
20912
            } else {
57
254
                errors.push(classify_unexpected_token_error(node, source));
58
1228
            }
59
254
            continue;
60
9597
        }
61
266
    }
62
751
}
63
266

            
64
/// Classifies a missing token node and generates a diagnostic with a context-aware message.
65
411
fn classify_missing_token(node: Node) -> RecoverableParseError {
66
411
    let start = node.start_position();
67
411
    let end = node.end_position();
68
222

            
69
145
    let message = if let Some(parent) = node.parent() {
70
145
        match parent.kind() {
71
145
            "letting_statement" => "Missing Expression or Domain".to_string(),
72
121
            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
73
        }
74
266
    } else {
75
266
        format!("Missing {}", user_friendly_token_name(node.kind(), false))
76
266
    };
77
266

            
78
411
    RecoverableParseError::new(
79
411
        message,
80
411
        Some(tree_sitter::Range {
81
411
            start_byte: node.start_byte(),
82
145
            end_byte: node.end_byte(),
83
411
            start_point: start,
84
145
            end_point: end,
85
145
        }),
86
486
    )
87
631
}
88
486

            
89
/// Classifies an unexpected token error node and generates a diagnostic.
90
740
fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
91
254
    let message = if let Some(parent) = node.parent() {
92
740
        let start_byte = node.start_byte().min(source_code.len());
93
254
        let end_byte = node.end_byte().min(source_code.len());
94
254
        let src_token = &source_code[start_byte..end_byte];
95

            
96
454
        if parent.kind() == "program"
97
        // ERROR node is the direct child of the root node
98
        {
99
            // A case where the unexpected token is at the end of a valid statement
100
396
            format!("Unexpected {}", src_token)
101
            // }
102
        } else {
103
            // Unexpected token inside a construct
104
144
            format!(
105
                "Unexpected {} inside {}",
106
                src_token,
107
144
                user_friendly_token_name(parent.kind(), true)
108
            )
109
        }
110
    } else {
111
        // Should never happen since an ERROR node would always have a parent.
112
486
        "Unexpected token".to_string()
113
486
    };
114
486

            
115
740
    RecoverableParseError::new(
116
740
        message,
117
740
        Some(tree_sitter::Range {
118
740
            start_byte: node.start_byte(),
119
254
            end_byte: node.end_byte(),
120
740
            start_point: node.start_position(),
121
254
            end_point: node.end_position(),
122
254
        }),
123
754
    )
124
1008
}
125
268

            
126
/// Determines if an error node represents a malformed line error.
127
851
fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
128
851
    if node.start_position().column == 0 || error_node_out_of_range(node, source) {
129
597
        return true;
130
254
    }
131
740
    let parent = node.parent();
132
518
    let grandparent = parent.and_then(|n| n.parent());
133
254
    let root = grandparent.and_then(|n| n.parent());
134

            
135
254
    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root) {
136
354
        parent.kind() == "set_comparison"
137
            && grandparent.kind() == "comparison_expr"
138
754
            && root.kind() == "program"
139
    } else {
140
122
        false
141
    }
142
879
}
143
514

            
144
/// Coverts a token name into a more user-friendly format for error messages.
145
/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
146
314
fn user_friendly_token_name(token: &str, article: bool) -> String {
147
268
    let capitalized = if token.contains("atom") {
148
458
        "Expression".to_string()
149
702
    } else if token == "COLON" {
150
471
        ":".to_string()
151
446
    } else {
152
677
        let friendly_name = token
153
677
            .replace("literal", "")
154
677
            .replace("int", "Integer")
155
989
            .replace("expr", "Expression")
156
677
            .replace('_', " ");
157
677
        friendly_name
158
231
            .split_whitespace()
159
401
            .map(|word| word.capitalize())
160
745
            .collect::<Vec<_>>()
161
457
            .join(" ")
162
288
    };
163
288

            
164
556
    if !article {
165
213
        return capitalized;
166
343
    }
167
145
    let first_char = capitalized.chars().next().unwrap();
168
433
    let article = match first_char.to_ascii_lowercase() {
169
563
        'a' | 'e' | 'i' | 'o' | 'u' => "an",
170
96
        _ => "a",
171
    };
172
426
    format!("{} {}", article, capitalized)
173
549
}
174
281

            
175
/// Returns true if the node's start or end column is out of range for its line in the source.
176
610
fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
177
610
    let lines: Vec<&str> = source.lines().collect();
178
329
    let start = node.start_position();
179
610
    let end = node.end_position();
180
281

            
181
474
    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
182
428
    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
183
97

            
184
424
    (start.column > start_line_len) || (end.column > end_line_len)
185
423
}
186

            
187
26
#[cfg(test)]
mod test {
    use super::{detect_syntactic_errors, is_malformed_line_error, user_friendly_token_name};
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Asserts that two diagnostics carry the same message and the same range.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    /// The first ERROR node produced for ` a,a,b: int(1..3)` is classified as a malformed line.
    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        let error_node = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    /// Token kinds are turned into readable names, optionally with an indefinite article.
    #[test]
    fn user_friendly_token_name_article() {
        assert_eq!(
            user_friendly_token_name("int_domain", false),
            "Integer Domain"
        );
        assert_eq!(
            user_friendly_token_name("int_domain", true),
            "an Integer Domain"
        );
        assert_eq!(user_friendly_token_name("atom", true), "an Expression");
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    /// `find x:` with nothing after the colon reports a single missing domain at the end of input.
    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");
        let mut errors = vec![];
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        let error = &errors[0];

        assert_essence_parse_error_eq(
            error,
            &RecoverableParseError::new(
                "Missing Domain".to_string(),
                Some(tree_sitter::Range {
                    start_byte: 7,
                    end_byte: 7,
                    start_point: tree_sitter::Point { row: 0, column: 7 },
                    end_point: tree_sitter::Point { row: 0, column: 7 },
                }),
            ),
        );
    }
}