1
use crate::diagnostics::diagnostics_api::{Diagnostic, Position, Range, Severity};
2
use crate::parser::traversal::WalkDFS;
3
use crate::parser::util::get_tree;
4
use capitalize::Capitalize;
5
use std::collections::HashSet;
6
use tree_sitter::Node;
7
/// Helper function
8
pub fn print_diagnostics(diags: &[Diagnostic]) {
9
    for (i, diag) in diags.iter().enumerate() {
10
        println!(
11
            "Diagnostic {}:\n  Range: ({}:{}) - ({}:{})\n  Severity: {:?}\n  Message: {}\n  Source: {}\n",
12
            i + 1,
13
            diag.range.start.line,
14
            diag.range.start.character,
15
            diag.range.end.line,
16
            diag.range.end.character,
17
            diag.severity,
18
            diag.message,
19
            diag.source
20
        );
21
    }
22
}
23

            
24
/// Returns true if the node's start or end column is out of range for its line in the source.
25
254
fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
26
254
    let lines: Vec<&str> = source.lines().collect();
27
254
    let start = node.start_position();
28
254
    let end = node.end_position();
29

            
30
254
    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
31
254
    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
32

            
33
254
    (start.column > start_line_len) || (end.column > end_line_len)
34
254
}
35

            
36
/// Detects syntactic issues in the essence source text and returns a vector of Diagnostics.
37
///
38
/// This function traverses the parse tree, looking for missing or error nodes, and generates
39
/// diagnostics for each. It uses a DFS and skips children of error/missing nodes
40
/// to avoid duplicate diagnostics. If the source cannot be parsed, a diagnostic is returned for that.
41
///
42
/// # Arguments
43
/// * `source` - The source code to analyze.
44
///
45
/// # Returns
46
/// * `Vec<Diagnostic>` - A vector of diagnostics describing syntactic issues found in the source.
47
320
pub fn detect_syntactic_errors(source: &str) -> Vec<Diagnostic> {
48
320
    let mut diagnostics = Vec::new();
49
320
    let mut malformed_lines_reported = HashSet::new();
50

            
51
320
    let (tree, _) = match get_tree(source) {
52
319
        Some(tree) => tree,
53
        None => {
54
1
            let last_line = source.lines().count().saturating_sub(1);
55
1
            let last_char = source.lines().last().map(|l| l.len()).unwrap_or(0);
56

            
57
1
            diagnostics.push(Diagnostic {
58
1
                range: Range {
59
1
                    start: Position {
60
1
                        line: 0,
61
1
                        character: 0,
62
1
                    },
63
1
                    end: Position {
64
1
                        line: last_line as u32,
65
1
                        character: last_char as u32,
66
1
                    },
67
1
                },
68
1
                severity: Severity::Error,
69
1
                message: "Failed to read the source code".to_string(),
70
1
                source: "Tree-Sitter-Parse-Error",
71
1
            });
72
1
            return diagnostics;
73
        }
74
    };
75

            
76
319
    let root_node = tree.root_node();
77
7216
    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
78
7216
        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
79
7216
    };
80

            
81
7216
    for node in WalkDFS::with_retract(&root_node, &retract) {
82
7216
        if node.start_position() == node.end_position() {
83
66
            diagnostics.push(classify_missing_token(node));
84
66
            continue;
85
7150
        }
86
7150
        if node.is_error() {
87
396
            let line = node.start_position().row;
88
            // If this line has already been reported as malformed, skip all error nodes on this line
89
396
            if malformed_lines_reported.contains(&line) {
90
121
                continue;
91
275
            }
92
275
            if is_malformed_line_error(&node, source) {
93
77
                malformed_lines_reported.insert(line);
94

            
95
77
                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
96
77
                diagnostics.push(generate_a_syntax_err_diagnostic(
97
77
                    line as u32,
98
                    0,
99
77
                    line as u32,
100
77
                    last_char as u32,
101
77
                    &format!(
102
77
                        "Malformed line {}: '{}'",
103
77
                        line + 1,
104
77
                        source.lines().nth(line).unwrap_or("")
105
77
                    ),
106
                ));
107
77
                continue;
108
198
            } else {
109
198
                diagnostics.push(classify_unexpected_token_error(node, source));
110
198
            }
111
198
            continue;
112
6754
        }
113
    }
114

            
115
319
    diagnostics
116
320
}
117

            
118
/// Classifies a missing token node and generates a diagnostic with a context-aware message.
119
66
fn classify_missing_token(node: Node) -> Diagnostic {
120
66
    let start = node.start_position();
121
66
    let end = node.end_position();
122

            
123
66
    let message = if let Some(parent) = node.parent() {
124
66
        match parent.kind() {
125
66
            "letting_statement" => "Missing Expression or Domain".to_string(),
126
55
            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
127
        }
128
    } else {
129
        format!("Missing {}", user_friendly_token_name(node.kind(), false))
130
    };
131

            
132
66
    generate_a_syntax_err_diagnostic(
133
66
        start.row as u32,
134
66
        start.column as u32,
135
66
        end.row as u32,
136
66
        end.column as u32,
137
66
        &message,
138
    )
139
66
}
140

            
141
/// Classifies an unexpected token error node and generates a diagnostic.
142
198
fn classify_unexpected_token_error(node: Node, source_code: &str) -> Diagnostic {
143
198
    let message = if let Some(parent) = node.parent() {
144
198
        let start_byte = node.start_byte().min(source_code.len());
145
198
        let end_byte = node.end_byte().min(source_code.len());
146
198
        let src_token = &source_code[start_byte..end_byte];
147

            
148
198
        if parent.kind() == "program"
149
        // ERROR node is the direct child of the root node
150
        {
151
            // A case where the unexpected token is at the end of a valid statement
152
77
            format!("Unexpected {}", src_token)
153
            // }
154
        } else {
155
            // Unexpected token inside a construct
156
121
            format!(
157
                "Unexpected {} inside {}",
158
                src_token,
159
121
                user_friendly_token_name(parent.kind(), true)
160
            )
161
        }
162
    } else {
163
        // Should never happen since an ERROR node would always have a parent.
164
        "Unexpected token".to_string()
165
    };
166

            
167
198
    generate_a_syntax_err_diagnostic(
168
198
        node.start_position().row as u32,
169
198
        node.start_position().column as u32,
170
198
        node.end_position().row as u32,
171
198
        node.end_position().column as u32,
172
198
        &message,
173
    )
174
198
}
175

            
176
/// Determines if an error node represents a malformed line error.
177
276
fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
178
276
    if node.start_position().column == 0 || error_node_out_of_range(node, source) {
179
78
        return true;
180
198
    }
181
198
    let parent = node.parent();
182
198
    let grandparent = parent.and_then(|n| n.parent());
183
198
    let root = grandparent.and_then(|n| n.parent());
184

            
185
198
    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root) {
186
110
        parent.kind() == "set_operation_bool"
187
            && grandparent.kind() == "bool_expr"
188
            && root.kind() == "program"
189
    } else {
190
88
        false
191
    }
192
276
}
193

            
194
/// Helper function for tests to compare the actual diagnostic with the expected one.
195
397
pub fn check_diagnostic(
196
397
    diag: &Diagnostic,
197
397
    line_start: u32,
198
397
    char_start: u32,
199
397
    line_end: u32,
200
397
    char_end: u32,
201
397
    msg: &str,
202
397
) {
203
    // Checking range
204
397
    assert_eq!(diag.range.start.line, line_start);
205
397
    assert_eq!(diag.range.start.character, char_start);
206
397
    assert_eq!(diag.range.end.line, line_end);
207
397
    assert_eq!(diag.range.end.character, char_end);
208

            
209
    // Check the message
210
397
    assert_eq!(diag.message, msg);
211
397
}
212

            
213
/// Coverts a token name into a more user-friendly format for error messages.
214
/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
215
179
fn user_friendly_token_name(token: &str, article: bool) -> String {
216
179
    let capitalized = if token.contains("atom") {
217
11
        "Expression".to_string()
218
168
    } else if token == "COLON" {
219
12
        ":".to_string()
220
    } else {
221
156
        let friendly_name = token
222
156
            .replace("literal", "")
223
156
            .replace("int", "Integer")
224
156
            .replace("expr", "Expression")
225
156
            .replace('_', " ");
226
156
        friendly_name
227
156
            .split_whitespace()
228
279
            .map(|word| word.capitalize())
229
156
            .collect::<Vec<_>>()
230
156
            .join(" ")
231
    };
232

            
233
179
    if !article {
234
57
        return capitalized;
235
122
    }
236
122
    let first_char = capitalized.chars().next().unwrap();
237
122
    let article = match first_char.to_ascii_lowercase() {
238
34
        'a' | 'e' | 'i' | 'o' | 'u' => "an",
239
88
        _ => "a",
240
    };
241
122
    format!("{} {}", article, capitalized)
242
179
}
243

            
244
341
fn generate_a_syntax_err_diagnostic(
245
341
    line_start: u32,
246
341
    char_start: u32,
247
341
    line_end: u32,
248
341
    char_end: u32,
249
341
    msg: &str,
250
341
) -> Diagnostic {
251
341
    Diagnostic {
252
341
        range: Range {
253
341
            start: Position {
254
341
                line: line_start,
255
341
                character: char_start,
256
341
            },
257
341
            end: Position {
258
341
                line: line_end,
259
341
                character: char_end,
260
341
            },
261
341
        },
262
341
        severity: Severity::Error,
263
341
        message: msg.to_string(),
264
341
        source: "syntactic-error-detector",
265
341
    }
266
341
}
267

            
268
#[test]
fn error_at_start() {
    // Expect the whole-document parse-failure diagnostic, ending at the end of the single line.
    let source = "; find x: int(1..3)";
    let diagnostics = detect_syntactic_errors(source);
    let Some(first) = diagnostics.first() else {
        panic!("Expected at least one diagnostic");
    };
    check_diagnostic(first, 0, 0, 0, 19, "Failed to read the source code");
}
276

            
277
#[test]
fn user_friendly_token_name_article() {
    assert_eq!(
        user_friendly_token_name("int_domain", false),
        "Integer Domain"
    );
    assert_eq!(
        user_friendly_token_name("int_domain", true),
        "an Integer Domain"
    );
    // Kinds containing "atom" are reported as an Expression, which takes "an".
    assert_eq!(user_friendly_token_name("atom", true), "an Expression");
    assert_eq!(user_friendly_token_name("COLON", false), ":");
}
290
#[test]
fn malformed_line() {
    // A line starting with a stray space and a comma-separated name list yields an ERROR node
    // that should be classified as a malformed line.
    let source = " a,a,b: int(1..3)";
    let (tree, _) = get_tree(source).expect("Should parse");
    let root = tree.root_node();

    let first_error = WalkDFS::with_retract(&root, &|_node| false)
        .find(|candidate| candidate.is_error())
        .expect("Should find an error node");

    assert!(is_malformed_line_error(&first_error, source));
}