1
use crate::errors::RecoverableParseError;
2
use crate::parser::traversal::WalkDFS;
3
use capitalize::Capitalize;
4
use std::collections::HashSet;
5
use tree_sitter::Node;
6

            
7
/// Returns the byte offset at which line `row` (0-based, `\n`-delimited) begins in `source`.
///
/// If `source` contains fewer than `row` newlines, the start of the last line is returned
/// (0 when there are no newlines at all).
fn line_start_byte(source: &[u8], row: usize) -> usize {
    source
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'\n')
        .take(row)
        .last()
        .map_or(0, |(newline_idx, _)| newline_idx + 1)
}
22

            
23
40
/// Builds a zero-width range located at byte column `column` of line `row` in `source`.
///
/// The byte offsets are computed from the start of `row` (see `line_start_byte`) plus `column`.
fn point_range_at(source: &str, row: usize, column: usize) -> tree_sitter::Range {
    let point = tree_sitter::Point { row, column };
    let offset = line_start_byte(source.as_bytes(), row) + column;
    tree_sitter::Range {
        start_byte: offset,
        end_byte: offset,
        start_point: point,
        end_point: point,
    }
}
33

            
34
44
/// Returns `true` if `prefix`, ignoring trailing whitespace, ends with the standalone word `int`.
///
/// "Standalone" means the byte before `int` (if any) is not an ASCII letter, digit, or `_`, so
/// identifiers such as `mint` or `my_int` are rejected.
fn is_int_keyword_suffix(prefix: &str) -> bool {
    match prefix.trim_end().strip_suffix("int") {
        None => false,
        Some(before) => !matches!(
            before.as_bytes().last(),
            Some(byte) if byte.is_ascii_alphanumeric() || *byte == b'_'
        ),
    }
}
45

            
46
539
fn int_domain_missing_rparen_line(line: &str, start_col: usize, end_col: usize) -> bool {
47
539
    line.as_bytes().get(start_col) == Some(&b'(')
48
56
        && line[end_col..].trim().is_empty()
49
42
        && !line[start_col..].contains(')')
50
41
        && is_int_keyword_suffix(&line[..start_col])
51
539
}
52

            
53
/// tree-sitter `ERROR` node spans can overlap bytes during recovery.
54
/// Need to clamp to the end of the non-comment prefix so diagnostics don't include comment
55
/// contents.
56
667
fn clamp_range_before_line_comment(range: &mut tree_sitter::Range, source: &str) {
57
667
    let Some(line) = source.lines().nth(range.start_point.row) else {
58
        return;
59
    };
60
667
    let Some(dollar_idx) = line.find('$') else {
61
640
        return;
62
    };
63

            
64
27
    let prefix = &line[..dollar_idx];
65
27
    let clamped_col = prefix.trim_end().len();
66

            
67
27
    if range.start_point.column > clamped_col {
68
        range.start_point.column = clamped_col;
69
27
    }
70
27
    if range.end_point.row == range.start_point.row && range.end_point.column > clamped_col {
71
1
        range.end_point.column = clamped_col;
72
27
    }
73
27
    if range.end_point.row > range.start_point.row {
74
        range.end_point.row = range.start_point.row;
75
        range.end_point.column = clamped_col;
76
27
    }
77

            
78
27
    let line_start = line_start_byte(source.as_bytes(), range.start_point.row);
79
27
    range.start_byte = line_start + range.start_point.column;
80
27
    range.end_byte = line_start + range.end_point.column;
81
667
}
82

            
83
733
pub fn detect_syntactic_errors(
84
733
    source: &str,
85
733
    tree: &tree_sitter::Tree,
86
733
    errors: &mut Vec<RecoverableParseError>,
87
733
) {
88
733
    let mut malformed_lines_reported = HashSet::new();
89

            
90
733
    let root_node = tree.root_node();
91
18335
    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
92
18335
        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
93
18335
    };
94

            
95
18335
    for node in WalkDFS::with_retract(&root_node, &retract) {
96
18335
        if node.start_position() == node.end_position() {
97
170
            errors.push(classify_missing_token(node, source));
98
170
            continue;
99
18165
        }
100
18165
        if node.is_error() {
101
719
            let line = node.start_position().row;
102
            // If this line has already been reported as malformed, skip all error nodes on this line
103
719
            if malformed_lines_reported.contains(&line) {
104
                continue;
105
719
            }
106
            // Ignore error nodes that start inside a single-line comment.
107
719
            if let Some(line_str) = source.lines().nth(line)
108
719
                && let Some(dollar_idx) = line_str.find('$')
109
52
                && node.start_position().column >= dollar_idx
110
            {
111
                continue;
112
719
            }
113

            
114
719
            if is_malformed_line_error(&node, source) {
115
184
                malformed_lines_reported.insert(line);
116
184
                let start_byte = node.start_byte();
117
184
                let end_byte = node.end_byte();
118

            
119
184
                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
120
184
                errors.push(RecoverableParseError::new(
121
184
                    generate_malformed_line_message(line, source),
122
184
                    Some(tree_sitter::Range {
123
184
                        start_byte,
124
184
                        end_byte,
125
184
                        start_point: tree_sitter::Point {
126
184
                            row: line,
127
184
                            column: 0,
128
184
                        },
129
184
                        end_point: tree_sitter::Point {
130
184
                            row: line,
131
184
                            column: last_char,
132
184
                        },
133
184
                    }),
134
                ));
135
184
                continue;
136
            } else {
137
535
                if let Some(missing_rparen) = classify_int_domain_missing_rparen(&node, source) {
138
39
                    errors.push(missing_rparen);
139
39
                    continue;
140
496
                }
141
496
                errors.push(classify_unexpected_token_error(node, source));
142
            }
143
496
            continue;
144
17446
        }
145
    }
146
733
}
147

            
148
/// Tree-sitter recovery sometimes reduces `int_domain` to bare `int` and then wraps the following
149
/// `(` and range text in an `ERROR` node (especially at EOF).
150
/// This function detects this specific pattern and reports  "Missing )" error
151
535
fn classify_int_domain_missing_rparen(
152
535
    node: &tree_sitter::Node,
153
535
    source: &str,
154
535
) -> Option<RecoverableParseError> {
155
535
    let start = node.start_position();
156
535
    let end = node.end_position();
157
535
    let line = source.lines().nth(start.row)?;
158
535
    let comment_col = line.find('$').unwrap_or(line.len());
159
535
    let line = &line[..comment_col];
160
535
    let start_col = start.column.min(line.len());
161
535
    let end_col = end.column.min(line.len());
162
535
    if !int_domain_missing_rparen_line(line, start_col, end_col) {
163
496
        return None;
164
39
    }
165
39
    let insertion_col = line.trim_end().len();
166
39
    Some(RecoverableParseError::new(
167
39
        "Missing )".to_string(),
168
39
        Some(point_range_at(source, start.row, insertion_col)),
169
39
    ))
170
535
}
171

            
172
/// Classifies a missing token node and generates a diagnostic with a context-aware message.
173
170
fn classify_missing_token(node: Node, source: &str) -> RecoverableParseError {
174
170
    let mut range = tree_sitter::Range {
175
170
        start_byte: node.start_byte(),
176
170
        end_byte: node.end_byte(),
177
170
        start_point: node.start_position(),
178
170
        end_point: node.end_position(),
179
170
    };
180
170
    clamp_range_before_line_comment(&mut range, source);
181

            
182
170
    let message = if let Some(parent) = node.parent() {
183
170
        match parent.kind() {
184
170
            "letting_variable_declaration" => "Missing Expression or Domain".to_string(),
185
144
            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
186
        }
187
    } else {
188
        format!("Missing {}", user_friendly_token_name(node.kind(), false))
189
    };
190

            
191
170
    RecoverableParseError::new(message, Some(range))
192
170
}
193

            
194
/// Classifies an unexpected token error node and generates a diagnostic.
195
496
fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
196
496
    let mut range = tree_sitter::Range {
197
496
        start_byte: node.start_byte().min(source_code.len()),
198
496
        end_byte: node.end_byte().min(source_code.len()),
199
496
        start_point: node.start_position(),
200
496
        end_point: node.end_position(),
201
496
    };
202
496
    clamp_range_before_line_comment(&mut range, source_code);
203

            
204
496
    let message = if let Some(parent) = node.parent() {
205
        // Extract the unexpected token text, handling out-of-range indices safely.
206
        // NOTE: tree-sitter byte offsets can land inside UTF-8 codepoints; decoding lossily avoids panics.
207
496
        let src_token: std::borrow::Cow<'_, str> = source_code
208
496
            .as_bytes()
209
496
            .get(range.start_byte..range.end_byte)
210
496
            .map(String::from_utf8_lossy)
211
496
            .unwrap_or_else(|| std::borrow::Cow::Borrowed("<unknown>"));
212
496
        let src_token = src_token.trim_end();
213

            
214
496
        if parent.kind() == "program" {
215
145
            format!("Unexpected {}", src_token)
216
        } else {
217
351
            format!(
218
                "Unexpected {} inside {}",
219
                src_token,
220
351
                user_friendly_token_name(parent.kind(), true)
221
            )
222
        }
223
    } else {
224
        "Unexpected token".to_string()
225
    };
226

            
227
496
    RecoverableParseError::new(message, Some(range))
228
496
}
229

            
230
/// Determines if an error node represents a malformed line error.
231
1625
pub fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
232
1625
    let parent = node.parent();
233
1625
    let grandparent = parent.and_then(|n| n.parent());
234
1625
    let root = grandparent.and_then(|n| n.parent());
235

            
236
1625
    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root)
237
849
        && parent.kind() == "set_comparison"
238
        && grandparent.kind() == "comparison_expr"
239
        && root.kind() == "program"
240
    {
241
        return true;
242
1625
    }
243

            
244
    // check parent kinds to see if the error is a constraint continuation
245
1625
    let mut curr = node.parent();
246
3759
    while let Some(n) = curr {
247
3009
        let kind = n.kind();
248
875
        if matches!(
249
3009
            kind,
250
3009
            "find_statement"
251
2723
                | "given_statement"
252
2723
                | "letting_statement"
253
2723
                | "dominance_relation"
254
2723
                | "bool_expr"
255
2671
                | "comparison_expr"
256
2498
                | "arithmetic_expr"
257
2316
                | "atom"
258
        ) {
259
875
            return false;
260
2134
        }
261
2134
        curr = n.parent();
262
    }
263

            
264
    // check for the first non-whitespace character on the line before the error node
265
750
    let line = source.lines().nth(node.start_position().row).unwrap_or("");
266
750
    let first_non_witespace = line
267
750
        .as_bytes()
268
750
        .iter()
269
881
        .take_while(|b| b.is_ascii_whitespace())
270
750
        .count();
271

            
272
    // if the error node is before or at the first non-whitespace character, it's a malformed line error
273
    // if the first non-whitespace character is after the error node, it could be a constraint continuation
274
750
    if node.start_position().column <= first_non_witespace || error_node_out_of_range(node, source)
275
    {
276
395
        if first_non_witespace > 0 && is_constraint_continuation(source, node.start_position().row)
277
        {
278
26
            return false;
279
369
        }
280
369
        return true;
281
355
    }
282
355
    false
283
1625
}
284

            
285
/// Returns `true` if row `row` appears to continue a constraint begun on an earlier row.
///
/// Looks at the nearest preceding row that still has code after removing any `$` comment
/// (blank rows are skipped). The answer is `true` when that row starts with `such that`
/// (case-insensitive) or ends with `,`. Row 0, or no such preceding row, gives `false`.
fn is_constraint_continuation(source: &str, row: usize) -> bool {
    let preceding: Vec<&str> = source.lines().take(row).collect();
    preceding
        .iter()
        .rev()
        .map(|raw| raw.split('$').next().unwrap_or("").trim_end())
        .find(|code| !code.trim().is_empty())
        .is_some_and(|code| {
            code.trim_start()
                .to_ascii_lowercase()
                .starts_with("such that")
                || code.ends_with(',')
        })
}
305

            
306
/// Coverts a token name into a more user-friendly format for error messages.
307
/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
308
498
fn user_friendly_token_name(token: &str, article: bool) -> String {
309
498
    let capitalized = if token.contains("atom") {
310
26
        "Expression".to_string()
311
472
    } else if token == "COLON" {
312
27
        ":".to_string()
313
    } else {
314
445
        let friendly_name = token
315
445
            .replace("literal", "")
316
445
            .replace("int", "Integer")
317
445
            .replace("expr", "Expression")
318
445
            .replace('_', " ");
319
445
        friendly_name
320
445
            .split_whitespace()
321
824
            .map(|word| word.capitalize())
322
445
            .collect::<Vec<_>>()
323
445
            .join(" ")
324
    };
325

            
326
498
    if !article {
327
146
        return capitalized;
328
352
    }
329
352
    let first_char = capitalized.chars().next().unwrap();
330
352
    let article = match first_char.to_ascii_lowercase() {
331
118
        'a' | 'e' | 'i' | 'o' | 'u' => "an",
332
234
        _ => "a",
333
    };
334
352
    format!("{} {}", article, capitalized)
335
498
}
336

            
337
/// Builds the message for a line that could not be parsed as any top-level statement.
///
/// Row `line` of `source` is trimmed, cut at any `$` comment, and has `"` escaped as `\"`. It is
/// then quoted in the message, and the expected statement kind is guessed from its first one or
/// two words (case-insensitive).
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let raw = source.lines().nth(line).unwrap_or("").trim();
    let code = match raw.find('$') {
        Some(comment_col) => &raw[..comment_col],
        None => raw,
    }
    .trim_end();
    let got = code.replace('"', "\\\"");

    let mut words = got.split_whitespace().map(str::to_ascii_lowercase);
    let first = words.next().unwrap_or_default();
    let second = words.next().unwrap_or_default();

    let expected = match (first.as_str(), second.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising" | "maximising", _) => "an objective statement",
        ("such", "that") => "a constraint statement",
        // Unrecognised leading words, including `such` without `that`.
        _ => "a valid top-level statement",
    };
    format!("Expected {}, but got '{}'", expected, got)
}
368

            
369
/// Returns true if the node's start or end column is out of range for its line in the source.
370
355
fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
371
355
    let lines: Vec<&str> = source.lines().collect();
372
355
    let start = node.start_position();
373
355
    let end = node.end_position();
374

            
375
355
    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
376
355
    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
377

            
378
355
    (start.column > start_line_len) || (end.column > end_line_len)
379
355
}
380

            
381
#[cfg(test)]
mod test {

    use super::{
        clamp_range_before_line_comment, detect_syntactic_errors, int_domain_missing_rparen_line,
        is_constraint_continuation, is_int_keyword_suffix, is_malformed_line_error,
        line_start_byte, point_range_at, user_friendly_token_name,
    };
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Helper function for tests to compare the actual error with the expected one.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        let error_node = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    #[test]
    fn malformed_find_message() {
        let source = "find >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_top_level_message() {
        let source = "a >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_objective_message() {
        let source = "maximising %x";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an objective statement, but got 'maximising %x'"
        );
    }

    #[test]
    fn malformed_letting_message() {
        let source = "letting % A be 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a letting declaration statement, but got 'letting % A be 3'"
        );
    }

    #[test]
    fn malformed_constraint_message() {
        let source = "such that % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a constraint statement, but got 'such that % A > 3'"
        );
    }

    #[test]
    fn malformed_top_level_message_2() {
        let source = "such % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'such % A > 3'"
        );
    }

    #[test]
    fn malformed_given_message() {
        let source = "given 1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a given declaration statement, but got 'given 1..3)'"
        );
    }

    #[test]
    fn malformed_where_message() {
        let source = "where x>6";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an instantiation condition, but got 'where x>6'"
        );
    }

    #[test]
    fn malformed_message_strips_line_comment() {
        let source = "where x>6 $ trailing note";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an instantiation condition, but got 'where x>6'"
        );
    }

    #[test]
    fn user_friendly_token_name_article() {
        assert_eq!(
            user_friendly_token_name("int_domain", false),
            "Integer Domain"
        );
        assert_eq!(
            user_friendly_token_name("int_domain", true),
            "an Integer Domain"
        );
        assert_eq!(user_friendly_token_name("atom", true), "an Expression");
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    #[test]
    fn constraint_continuation_uses_previous_code_line() {
        assert!(!is_constraint_continuation("such that x > 1,\n  y", 0));
        assert!(is_constraint_continuation("such that x > 1\n  y", 1));
        assert!(is_constraint_continuation("x > 1,\n\n  $ note\n  y", 3));
        assert!(!is_constraint_continuation("find x: bool\n  y", 1));
    }

    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");
        let mut errors = vec![];
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        let error = &errors[0];

        assert_essence_parse_error_eq(
            error,
            &RecoverableParseError::new(
                "Missing Domain".to_string(),
                Some(tree_sitter::Range {
                    start_byte: 7,
                    end_byte: 7,
                    start_point: tree_sitter::Point { row: 0, column: 7 },
                    end_point: tree_sitter::Point { row: 0, column: 7 },
                }),
            ),
        );
    }

    #[test]
    fn line_start_byte_returns_correct_offsets() {
        let source = "a\nbc\ndef";
        let bytes = source.as_bytes();
        assert_eq!(line_start_byte(bytes, 0), 0);
        assert_eq!(line_start_byte(bytes, 1), 2);
        assert_eq!(line_start_byte(bytes, 2), 5);
    }

    #[test]
    fn point_range_at_returns_correct_zero_length_range() {
        let source = "a\nbc\ndef";
        let range = point_range_at(source, 1, 1); // points to 'c'
        assert_eq!(range.start_point.row, 1);
        assert_eq!(range.start_point.column, 1);
        assert_eq!(range.end_point, range.start_point);
        assert_eq!(range.start_byte, 3);
        assert_eq!(range.end_byte, 3);
    }

    #[test]
    fn clamp_range_before_line_comment_clamps_end_to_before_dollar() {
        let source = "find x: int(1..3 $comment";
        let mut range = tree_sitter::Range {
            start_byte: 0,
            end_byte: source.len(),
            start_point: tree_sitter::Point { row: 0, column: 0 },
            end_point: tree_sitter::Point {
                row: 0,
                column: source.len(),
            },
        };

        clamp_range_before_line_comment(&mut range, source);

        // "find x: int(1..3" ends at byte/column 16; the `$comment` suffix must be excluded.
        assert_eq!(range.end_point.row, 0);
        assert_eq!(range.end_point.column, 16);
        assert_eq!(range.end_byte, 16);
    }

    #[test]
    fn int_keyword_suffix_checks_word_boundary() {
        assert!(is_int_keyword_suffix("find x: int"));
        assert!(!is_int_keyword_suffix("foo"));
        assert!(!is_int_keyword_suffix("mint"));
    }

    #[test]
    fn int_domain_missing_rparen_line_positive_and_negative_cases() {
        let ok = "find x: int(1..2";
        let start = ok.find('(').unwrap();
        assert!(int_domain_missing_rparen_line(ok, start, ok.len()));

        let has_rparen = "find x: int(1..2)";
        let start = has_rparen.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            has_rparen,
            start,
            has_rparen.len()
        ));

        let trailing = "find x: int(1..2 foo";
        let start = trailing.find('(').unwrap();
        let end = trailing.find(" foo").unwrap();
        assert!(!int_domain_missing_rparen_line(trailing, start, end));

        let print_like = "find x: print(1..2";
        let start = print_like.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            print_like,
            start,
            print_like.len()
        ));
    }
}