1
use crate::errors::RecoverableParseError;
2
use crate::parser::traversal::WalkDFS;
3
use capitalize::Capitalize;
4
use std::collections::HashSet;
5
use tree_sitter::Node;
6

            
7
/// Returns the absolute byte offset of the start of `row` in `source`.
///
/// Rows are zero-based and delimited by `\n`. When `source` contains fewer than `row` line
/// breaks, the offset just past the last line break is returned (or `0` if there is none).
///
/// TODO: move to a separate module like utils, we don't want semantic tokens to depend on parser internals
pub fn line_start_byte(source: &[u8], row: usize) -> usize {
    source
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'\n')
        // Only the first `row` line breaks matter; the last one seen starts the target row.
        .take(row)
        .last()
        .map_or(0, |(newline_idx, _)| newline_idx + 1)
}
23

            
24
40
/// Builds a zero-length range located at (`row`, `column`) in `source`.
///
/// The byte offset is the start-of-line offset of `row` plus `column`, so `column` is
/// interpreted as a byte offset within the line.
fn point_range_at(source: &str, row: usize, column: usize) -> tree_sitter::Range {
    let point = tree_sitter::Point { row, column };
    let offset = line_start_byte(source.as_bytes(), row) + column;
    tree_sitter::Range {
        start_byte: offset,
        end_byte: offset,
        start_point: point,
        end_point: point,
    }
}
34

            
35
44
/// Returns true if `prefix` (ignoring trailing whitespace) ends with the standalone word `int`.
///
/// "Standalone" means the byte right before `int` is not an ASCII letter, digit or `_`
/// (or there is no such byte), so `mint` and `my_int` do not count.
fn is_int_keyword_suffix(prefix: &str) -> bool {
    let Some(head) = prefix.trim_end().strip_suffix("int") else {
        return false;
    };
    match head.as_bytes().last() {
        // `int` is the whole (trimmed) prefix.
        None => true,
        Some(&before) => !(before.is_ascii_alphanumeric() || before == b'_'),
    }
}
46

            
47
526
fn int_domain_missing_rparen_line(line: &str, start_col: usize, end_col: usize) -> bool {
48
526
    line.as_bytes().get(start_col) == Some(&b'(')
49
56
        && line[end_col..].trim().is_empty()
50
42
        && !line[start_col..].contains(')')
51
41
        && is_int_keyword_suffix(&line[..start_col])
52
526
}
53

            
54
/// tree-sitter `ERROR` node spans can overlap bytes during recovery.
55
/// Need to clamp to the end of the non-comment prefix so diagnostics don't include comment
56
/// contents.
57
654
fn clamp_range_before_line_comment(range: &mut tree_sitter::Range, source: &str) {
58
654
    let Some(line) = source.lines().nth(range.start_point.row) else {
59
        return;
60
    };
61
654
    let Some(dollar_idx) = line.find('$') else {
62
627
        return;
63
    };
64

            
65
27
    let prefix = &line[..dollar_idx];
66
27
    let clamped_col = prefix.trim_end().len();
67

            
68
27
    if range.start_point.column > clamped_col {
69
        range.start_point.column = clamped_col;
70
27
    }
71
27
    if range.end_point.row == range.start_point.row && range.end_point.column > clamped_col {
72
1
        range.end_point.column = clamped_col;
73
27
    }
74
27
    if range.end_point.row > range.start_point.row {
75
        range.end_point.row = range.start_point.row;
76
        range.end_point.column = clamped_col;
77
27
    }
78

            
79
27
    let line_start = line_start_byte(source.as_bytes(), range.start_point.row);
80
27
    range.start_byte = line_start + range.start_point.column;
81
27
    range.end_byte = line_start + range.end_point.column;
82
654
}
83

            
84
720
pub fn detect_syntactic_errors(
85
720
    source: &str,
86
720
    tree: &tree_sitter::Tree,
87
720
    errors: &mut Vec<RecoverableParseError>,
88
720
) {
89
720
    let mut malformed_lines_reported = HashSet::new();
90

            
91
720
    let root_node = tree.root_node();
92
18010
    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
93
18010
        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
94
18010
    };
95

            
96
18010
    for node in WalkDFS::with_retract(&root_node, &retract) {
97
18010
        if node.start_position() == node.end_position() {
98
170
            errors.push(classify_missing_token(node, source));
99
170
            continue;
100
17840
        }
101
17840
        if node.is_error() {
102
706
            let line = node.start_position().row;
103
            // If this line has already been reported as malformed, skip all error nodes on this line
104
706
            if malformed_lines_reported.contains(&line) {
105
                continue;
106
706
            }
107
            // Ignore error nodes that start inside a single-line comment.
108
706
            if let Some(line_str) = source.lines().nth(line)
109
706
                && let Some(dollar_idx) = line_str.find('$')
110
52
                && node.start_position().column >= dollar_idx
111
            {
112
                continue;
113
706
            }
114

            
115
706
            if is_malformed_line_error(&node, source) {
116
184
                malformed_lines_reported.insert(line);
117
184
                let start_byte = node.start_byte();
118
184
                let end_byte = node.end_byte();
119

            
120
184
                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
121
184
                errors.push(RecoverableParseError::new(
122
184
                    generate_malformed_line_message(line, source),
123
184
                    Some(tree_sitter::Range {
124
184
                        start_byte,
125
184
                        end_byte,
126
184
                        start_point: tree_sitter::Point {
127
184
                            row: line,
128
184
                            column: 0,
129
184
                        },
130
184
                        end_point: tree_sitter::Point {
131
184
                            row: line,
132
184
                            column: last_char,
133
184
                        },
134
184
                    }),
135
                ));
136
184
                continue;
137
            } else {
138
522
                if let Some(missing_rparen) = classify_int_domain_missing_rparen(&node, source) {
139
39
                    errors.push(missing_rparen);
140
39
                    continue;
141
483
                }
142
483
                errors.push(classify_unexpected_token_error(node, source));
143
            }
144
483
            continue;
145
17134
        }
146
    }
147
720
}
148

            
149
/// Tree-sitter recovery sometimes reduces `int_domain` to bare `int` and then wraps the following
150
/// `(` and range text in an `ERROR` node (especially at EOF).
151
/// This function detects this specific pattern and reports  "Missing )" error
152
522
fn classify_int_domain_missing_rparen(
153
522
    node: &tree_sitter::Node,
154
522
    source: &str,
155
522
) -> Option<RecoverableParseError> {
156
522
    let start = node.start_position();
157
522
    let end = node.end_position();
158
522
    let line = source.lines().nth(start.row)?;
159
522
    let comment_col = line.find('$').unwrap_or(line.len());
160
522
    let line = &line[..comment_col];
161
522
    let start_col = start.column.min(line.len());
162
522
    let end_col = end.column.min(line.len());
163
522
    if !int_domain_missing_rparen_line(line, start_col, end_col) {
164
483
        return None;
165
39
    }
166
39
    let insertion_col = line.trim_end().len();
167
39
    Some(RecoverableParseError::new(
168
39
        "Missing )".to_string(),
169
39
        Some(point_range_at(source, start.row, insertion_col)),
170
39
    ))
171
522
}
172

            
173
/// Classifies a missing token node and generates a diagnostic with a context-aware message.
174
170
fn classify_missing_token(node: Node, source: &str) -> RecoverableParseError {
175
170
    let mut range = tree_sitter::Range {
176
170
        start_byte: node.start_byte(),
177
170
        end_byte: node.end_byte(),
178
170
        start_point: node.start_position(),
179
170
        end_point: node.end_position(),
180
170
    };
181
170
    clamp_range_before_line_comment(&mut range, source);
182

            
183
170
    let message = if let Some(parent) = node.parent() {
184
170
        match parent.kind() {
185
170
            "letting_variable_declaration" => "Missing Expression or Domain".to_string(),
186
144
            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
187
        }
188
    } else {
189
        format!("Missing {}", user_friendly_token_name(node.kind(), false))
190
    };
191

            
192
170
    RecoverableParseError::new(message, Some(range))
193
170
}
194

            
195
/// Classifies an unexpected token error node and generates a diagnostic.
196
483
fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
197
483
    let mut range = tree_sitter::Range {
198
483
        start_byte: node.start_byte().min(source_code.len()),
199
483
        end_byte: node.end_byte().min(source_code.len()),
200
483
        start_point: node.start_position(),
201
483
        end_point: node.end_position(),
202
483
    };
203
483
    clamp_range_before_line_comment(&mut range, source_code);
204

            
205
483
    let message = if let Some(parent) = node.parent() {
206
        // Extract the unexpected token text, handling out-of-range indices safely.
207
        // NOTE: tree-sitter byte offsets can land inside UTF-8 codepoints; decoding lossily avoids panics.
208
483
        let src_token: std::borrow::Cow<'_, str> = source_code
209
483
            .as_bytes()
210
483
            .get(range.start_byte..range.end_byte)
211
483
            .map(String::from_utf8_lossy)
212
483
            .unwrap_or_else(|| std::borrow::Cow::Borrowed("<unknown>"));
213
483
        let src_token = src_token.trim_end();
214

            
215
483
        if parent.kind() == "program" {
216
145
            format!("Unexpected {}", src_token)
217
        } else {
218
338
            format!(
219
                "Unexpected {} inside {}",
220
                src_token,
221
338
                user_friendly_token_name(parent.kind(), true)
222
            )
223
        }
224
    } else {
225
        "Unexpected token".to_string()
226
    };
227

            
228
483
    RecoverableParseError::new(message, Some(range))
229
483
}
230

            
231
/// Determines if an error node represents a malformed line error.
232
1599
pub fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
233
1599
    let parent = node.parent();
234
1599
    let grandparent = parent.and_then(|n| n.parent());
235
1599
    let root = grandparent.and_then(|n| n.parent());
236

            
237
1599
    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root)
238
823
        && parent.kind() == "set_comparison"
239
        && grandparent.kind() == "comparison_expr"
240
        && root.kind() == "program"
241
    {
242
        return true;
243
1599
    }
244

            
245
    // check parent kinds to see if the error is a constraint continuation
246
1599
    let mut curr = node.parent();
247
3707
    while let Some(n) = curr {
248
2957
        let kind = n.kind();
249
849
        if matches!(
250
2957
            kind,
251
2957
            "find_statement"
252
2671
                | "given_statement"
253
2671
                | "letting_statement"
254
2671
                | "dominance_relation"
255
2671
                | "bool_expr"
256
2645
                | "comparison_expr"
257
2472
                | "arithmetic_expr"
258
2290
                | "atom"
259
        ) {
260
849
            return false;
261
2108
        }
262
2108
        curr = n.parent();
263
    }
264

            
265
    // check for the first non-whitespace character on the line before the error node
266
750
    let line = source.lines().nth(node.start_position().row).unwrap_or("");
267
750
    let first_non_witespace = line
268
750
        .as_bytes()
269
750
        .iter()
270
881
        .take_while(|b| b.is_ascii_whitespace())
271
750
        .count();
272

            
273
    // if the error node is before or at the first non-whitespace character, it's a malformed line error
274
    // if the first non-whitespace character is after the error node, it could be a constraint continuation
275
750
    if node.start_position().column <= first_non_witespace || error_node_out_of_range(node, source)
276
    {
277
395
        if first_non_witespace > 0 && is_constraint_continuation(source, node.start_position().row)
278
        {
279
26
            return false;
280
369
        }
281
369
        return true;
282
355
    }
283
355
    false
284
1599
}
285

            
286
/// Checks if the line at `row` continues a constraint started on an earlier line.
///
/// The nearest earlier line that still has content after removing its `$` comment and
/// surrounding whitespace decides the answer: the row is a continuation if that line ends with
/// a comma or starts with "such that" (ASCII case-insensitive). Returns false when no such
/// line exists, which is always the case for row 0.
fn is_constraint_continuation(source: &str, row: usize) -> bool {
    let previous_code_line = source
        .lines()
        .take(row)
        .map(|raw| raw.split('$').next().unwrap_or("").trim())
        .filter(|code| !code.is_empty())
        .last();

    match previous_code_line {
        Some(code) => {
            code.to_ascii_lowercase().starts_with("such that") || code.ends_with(',')
        }
        None => false,
    }
}
306

            
307
/// Coverts a token name into a more user-friendly format for error messages.
308
/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
309
485
fn user_friendly_token_name(token: &str, article: bool) -> String {
310
485
    let capitalized = if token.contains("atom") {
311
26
        "Expression".to_string()
312
459
    } else if token == "COLON" {
313
27
        ":".to_string()
314
    } else {
315
432
        let friendly_name = token
316
432
            .replace("literal", "")
317
432
            .replace("int", "Integer")
318
432
            .replace("expr", "Expression")
319
432
            .replace('_', " ");
320
432
        friendly_name
321
432
            .split_whitespace()
322
772
            .map(|word| word.capitalize())
323
432
            .collect::<Vec<_>>()
324
432
            .join(" ")
325
    };
326

            
327
485
    if !article {
328
146
        return capitalized;
329
339
    }
330
339
    let first_char = capitalized.chars().next().unwrap();
331
339
    let article = match first_char.to_ascii_lowercase() {
332
118
        'a' | 'e' | 'i' | 'o' | 'u' => "an",
333
221
        _ => "a",
334
    };
335
339
    format!("{} {}", article, capitalized)
336
485
}
337

            
338
/// Generates an informative error message for a malformed line.
///
/// The quoted text is the line at index `line` with surrounding whitespace and any `$` comment
/// removed and double quotes backslash-escaped; a missing line is quoted as empty. The expected
/// construct is chosen from the first word of that text (and the second, for `such that`),
/// compared ASCII case-insensitively.
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let trimmed = source.lines().nth(line).unwrap_or("").trim();
    let code = trimmed.split('$').next().unwrap_or("").trim_end();
    let got = code.replace('"', "\\\"");

    let mut leading_words = got
        .split_whitespace()
        .map(|word| word.to_ascii_lowercase());
    let first = leading_words.next().unwrap_or_default();
    let second = leading_words.next().unwrap_or_default();

    let expected = match (first.as_str(), second.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising" | "maximising", _) => "an objective statement",
        // Only the full `such that` keyword pair introduces a constraint.
        ("such", "that") => "a constraint statement",
        // Default case for unrecognized starting tokens (including a lone `such`).
        _ => "a valid top-level statement",
    };
    format!("Expected {}, but got '{}'", expected, got)
}
363

            
364
/// Returns true if the node's start or end column is out of range for its line in the source.
365
355
fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
366
355
    let lines: Vec<&str> = source.lines().collect();
367
355
    let start = node.start_position();
368
355
    let end = node.end_position();
369

            
370
355
    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
371
355
    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
372

            
373
355
    (start.column > start_line_len) || (end.column > end_line_len)
374
355
}
375

            
376
#[cfg(test)]
mod test {

    use super::{
        clamp_range_before_line_comment, detect_syntactic_errors, int_domain_missing_rparen_line,
        is_int_keyword_suffix, is_malformed_line_error, line_start_byte, point_range_at,
        user_friendly_token_name,
    };
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Helper function for tests to compare the actual error with the expected one.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    /// Asserts the malformed-line message generated for the first line of `source`.
    fn assert_malformed_message(source: &str, expected: &str) {
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(message, expected);
    }

    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        let error_node = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    #[test]
    fn malformed_find_message() {
        assert_malformed_message(
            "find >=lex,b,c: int(1..3)",
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_top_level_message() {
        assert_malformed_message(
            "a >=lex,b,c: int(1..3)",
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_objective_message() {
        assert_malformed_message(
            "maximising %x",
            "Expected an objective statement, but got 'maximising %x'",
        );
    }

    #[test]
    fn malformed_letting_message() {
        assert_malformed_message(
            "letting % A be 3",
            "Expected a letting declaration statement, but got 'letting % A be 3'",
        );
    }

    #[test]
    fn malformed_constraint_message() {
        assert_malformed_message(
            "such that % A > 3",
            "Expected a constraint statement, but got 'such that % A > 3'",
        );
    }

    #[test]
    fn malformed_top_level_message_2() {
        assert_malformed_message(
            "such % A > 3",
            "Expected a valid top-level statement, but got 'such % A > 3'",
        );
    }

    #[test]
    fn malformed_given_message() {
        assert_malformed_message(
            "given 1..3)",
            "Expected a given declaration statement, but got 'given 1..3)'",
        );
    }

    #[test]
    fn malformed_where_message() {
        assert_malformed_message(
            "where x>6",
            "Expected an instantiation condition, but got 'where x>6'",
        );
    }

    #[test]
    fn user_friendly_token_name_article() {
        assert_eq!(
            user_friendly_token_name("int_domain", false),
            "Integer Domain"
        );
        assert_eq!(
            user_friendly_token_name("int_domain", true),
            "an Integer Domain"
        );
        // assert_eq!(user_friendly_token_name("atom", true), "an Expression");
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");
        let mut errors = vec![];
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        // The missing domain is reported as a zero-length range right after the colon.
        let after_colon = tree_sitter::Point { row: 0, column: 7 };
        let expected = RecoverableParseError::new(
            "Missing Domain".to_string(),
            Some(tree_sitter::Range {
                start_byte: 7,
                end_byte: 7,
                start_point: after_colon,
                end_point: after_colon,
            }),
        );

        assert_essence_parse_error_eq(&errors[0], &expected);
    }

    #[test]
    fn line_start_byte_returns_correct_offsets() {
        let bytes = "a\nbc\ndef".as_bytes();
        assert_eq!(line_start_byte(bytes, 0), 0);
        assert_eq!(line_start_byte(bytes, 1), 2);
        assert_eq!(line_start_byte(bytes, 2), 5);
    }

    #[test]
    fn point_range_at_returns_correct_zero_length_range() {
        let source = "a\nbc\ndef";
        let range = point_range_at(source, 1, 1); // points to 'c'
        assert_eq!(range.start_point.row, 1);
        assert_eq!(range.start_point.column, 1);
        assert_eq!(range.end_point, range.start_point);
        assert_eq!(range.start_byte, 3);
        assert_eq!(range.end_byte, 3);
    }

    #[test]
    fn clamp_range_before_line_comment_clamps_end_to_before_dollar() {
        let source = "find x: int(1..3 $comment";
        let whole_line_end = source.len();
        let mut range = tree_sitter::Range {
            start_byte: 0,
            end_byte: whole_line_end,
            start_point: tree_sitter::Point { row: 0, column: 0 },
            end_point: tree_sitter::Point {
                row: 0,
                column: whole_line_end,
            },
        };

        clamp_range_before_line_comment(&mut range, source);

        // "find x: int(1..3" ends at byte/column 16; the `$comment` suffix must be excluded.
        assert_eq!(range.end_point.row, 0);
        assert_eq!(range.end_point.column, 16);
        assert_eq!(range.end_byte, 16);
    }

    #[test]
    fn int_keyword_suffix_checks_word_boundary() {
        assert!(is_int_keyword_suffix("find x: int"));
        assert!(!is_int_keyword_suffix("foo"));
        assert!(!is_int_keyword_suffix("mint"));
    }

    #[test]
    fn int_domain_missing_rparen_line_positive_and_negative_cases() {
        let open_paren = |line: &str| line.find('(').unwrap();

        // Unclosed `int(` running to the end of the line.
        let ok = "find x: int(1..2";
        assert!(int_domain_missing_rparen_line(ok, open_paren(ok), ok.len()));

        // The closing parenthesis is present.
        let has_rparen = "find x: int(1..2)";
        assert!(!int_domain_missing_rparen_line(
            has_rparen,
            open_paren(has_rparen),
            has_rparen.len()
        ));

        // Extra text follows the error span.
        let trailing = "find x: int(1..2 foo";
        let end = trailing.find(" foo").unwrap();
        assert!(!int_domain_missing_rparen_line(
            trailing,
            open_paren(trailing),
            end
        ));

        // `int` is only the tail of a longer identifier.
        let print_like = "find x: print(1..2";
        assert!(!int_domain_missing_rparen_line(
            print_like,
            open_paren(print_like),
            print_like.len()
        ));
    }
}