Skip to main content

conjure_cp_essence_parser/parser/
syntax_errors.rs

1use crate::errors::RecoverableParseError;
2use crate::parser::traversal::WalkDFS;
3use capitalize::Capitalize;
4use std::collections::HashSet;
5use tree_sitter::Node;
6
/// Returns the absolute byte offset at which line `row` (0-based) begins in `source`.
///
/// Lines are delimited by `\n`. If `source` contains fewer than `row` newlines, the offset just
/// past the last newline (i.e. the start of the final line) is returned; for `row == 0` or a
/// source without newlines the result is `0`.
/// TODO: move to a separate module like utils, we don't want semantic tokens to depend on parser internals
pub fn line_start_byte(source: &[u8], row: usize) -> usize {
    source
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'\n')
        .take(row)
        .last()
        .map_or(0, |(newline_idx, _)| newline_idx + 1)
}
23
24fn point_range_at(source: &str, row: usize, column: usize) -> tree_sitter::Range {
25    let line_start = line_start_byte(source.as_bytes(), row);
26    let byte = line_start + column;
27    tree_sitter::Range {
28        start_byte: byte,
29        end_byte: byte,
30        start_point: tree_sitter::Point { row, column },
31        end_point: tree_sitter::Point { row, column },
32    }
33}
34
/// Reports whether `prefix`, ignoring trailing whitespace, ends with the standalone word `int`.
///
/// The byte before `int` (if any) must not be an ASCII alphanumeric or `_`, so `"x: int"` matches
/// while `"mint"` and `"my_int"` do not.
fn is_int_keyword_suffix(prefix: &str) -> bool {
    match prefix.trim_end().strip_suffix("int") {
        None => false,
        Some(head) => !head
            .as_bytes()
            .last()
            .is_some_and(|&prev| prev.is_ascii_alphanumeric() || prev == b'_'),
    }
}
46
47fn int_domain_missing_rparen_line(line: &str, start_col: usize, end_col: usize) -> bool {
48    line.as_bytes().get(start_col) == Some(&b'(')
49        && line[end_col..].trim().is_empty()
50        && !line[start_col..].contains(')')
51        && is_int_keyword_suffix(&line[..start_col])
52}
53
54/// tree-sitter `ERROR` node spans can overlap bytes during recovery.
55/// Need to clamp to the end of the non-comment prefix so diagnostics don't include comment
56/// contents.
57fn clamp_range_before_line_comment(range: &mut tree_sitter::Range, source: &str) {
58    let Some(line) = source.lines().nth(range.start_point.row) else {
59        return;
60    };
61    let Some(dollar_idx) = line.find('$') else {
62        return;
63    };
64
65    let prefix = &line[..dollar_idx];
66    let clamped_col = prefix.trim_end().len();
67
68    if range.start_point.column > clamped_col {
69        range.start_point.column = clamped_col;
70    }
71    if range.end_point.row == range.start_point.row && range.end_point.column > clamped_col {
72        range.end_point.column = clamped_col;
73    }
74    if range.end_point.row > range.start_point.row {
75        range.end_point.row = range.start_point.row;
76        range.end_point.column = clamped_col;
77    }
78
79    let line_start = line_start_byte(source.as_bytes(), range.start_point.row);
80    range.start_byte = line_start + range.start_point.column;
81    range.end_byte = line_start + range.end_point.column;
82}
83
84pub fn detect_syntactic_errors(
85    source: &str,
86    tree: &tree_sitter::Tree,
87    errors: &mut Vec<RecoverableParseError>,
88) {
89    let mut malformed_lines_reported = HashSet::new();
90
91    let root_node = tree.root_node();
92    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
93        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
94    };
95
96    for node in WalkDFS::with_retract(&root_node, &retract) {
97        if node.start_position() == node.end_position() {
98            errors.push(classify_missing_token(node, source));
99            continue;
100        }
101        if node.is_error() {
102            let line = node.start_position().row;
103            // If this line has already been reported as malformed, skip all error nodes on this line
104            if malformed_lines_reported.contains(&line) {
105                continue;
106            }
107            // Ignore error nodes that start inside a single-line comment.
108            if let Some(line_str) = source.lines().nth(line)
109                && let Some(dollar_idx) = line_str.find('$')
110                && node.start_position().column >= dollar_idx
111            {
112                continue;
113            }
114
115            if is_malformed_line_error(&node, source) {
116                malformed_lines_reported.insert(line);
117                let start_byte = node.start_byte();
118                let end_byte = node.end_byte();
119
120                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
121                errors.push(RecoverableParseError::new(
122                    generate_malformed_line_message(line, source),
123                    Some(tree_sitter::Range {
124                        start_byte,
125                        end_byte,
126                        start_point: tree_sitter::Point {
127                            row: line,
128                            column: 0,
129                        },
130                        end_point: tree_sitter::Point {
131                            row: line,
132                            column: last_char,
133                        },
134                    }),
135                ));
136                continue;
137            } else {
138                if let Some(missing_rparen) = classify_int_domain_missing_rparen(&node, source) {
139                    errors.push(missing_rparen);
140                    continue;
141                }
142                errors.push(classify_unexpected_token_error(node, source));
143            }
144            continue;
145        }
146    }
147}
148
149/// Tree-sitter recovery sometimes reduces `int_domain` to bare `int` and then wraps the following
150/// `(` and range text in an `ERROR` node (especially at EOF).
151/// This function detects this specific pattern and reports  "Missing )" error
152fn classify_int_domain_missing_rparen(
153    node: &tree_sitter::Node,
154    source: &str,
155) -> Option<RecoverableParseError> {
156    let start = node.start_position();
157    let end = node.end_position();
158    let line = source.lines().nth(start.row)?;
159    let comment_col = line.find('$').unwrap_or(line.len());
160    let line = &line[..comment_col];
161    let start_col = start.column.min(line.len());
162    let end_col = end.column.min(line.len());
163    if !int_domain_missing_rparen_line(line, start_col, end_col) {
164        return None;
165    }
166    let insertion_col = line.trim_end().len();
167    Some(RecoverableParseError::new(
168        "Missing )".to_string(),
169        Some(point_range_at(source, start.row, insertion_col)),
170    ))
171}
172
173/// Classifies a missing token node and generates a diagnostic with a context-aware message.
174fn classify_missing_token(node: Node, source: &str) -> RecoverableParseError {
175    let mut range = tree_sitter::Range {
176        start_byte: node.start_byte(),
177        end_byte: node.end_byte(),
178        start_point: node.start_position(),
179        end_point: node.end_position(),
180    };
181    clamp_range_before_line_comment(&mut range, source);
182
183    let message = if let Some(parent) = node.parent() {
184        match parent.kind() {
185            "letting_variable_declaration" => "Missing Expression or Domain".to_string(),
186            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
187        }
188    } else {
189        format!("Missing {}", user_friendly_token_name(node.kind(), false))
190    };
191
192    RecoverableParseError::new(message, Some(range))
193}
194
195/// Classifies an unexpected token error node and generates a diagnostic.
196fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
197    let mut range = tree_sitter::Range {
198        start_byte: node.start_byte().min(source_code.len()),
199        end_byte: node.end_byte().min(source_code.len()),
200        start_point: node.start_position(),
201        end_point: node.end_position(),
202    };
203    clamp_range_before_line_comment(&mut range, source_code);
204
205    let message = if let Some(parent) = node.parent() {
206        // Extract the unexpected token text, handling out-of-range indices safely.
207        // NOTE: tree-sitter byte offsets can land inside UTF-8 codepoints; decoding lossily avoids panics.
208        let src_token: std::borrow::Cow<'_, str> = source_code
209            .as_bytes()
210            .get(range.start_byte..range.end_byte)
211            .map(String::from_utf8_lossy)
212            .unwrap_or_else(|| std::borrow::Cow::Borrowed("<unknown>"));
213        let src_token = src_token.trim_end();
214
215        if parent.kind() == "program" {
216            format!("Unexpected {}", src_token)
217        } else {
218            format!(
219                "Unexpected {} inside {}",
220                src_token,
221                user_friendly_token_name(parent.kind(), true)
222            )
223        }
224    } else {
225        "Unexpected token".to_string()
226    };
227
228    RecoverableParseError::new(message, Some(range))
229}
230
231/// Determines if an error node represents a malformed line error.
232pub fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
233    let parent = node.parent();
234    let grandparent = parent.and_then(|n| n.parent());
235    let root = grandparent.and_then(|n| n.parent());
236
237    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root)
238        && parent.kind() == "set_comparison"
239        && grandparent.kind() == "comparison_expr"
240        && root.kind() == "program"
241    {
242        return true;
243    }
244
245    // check parent kinds to see if the error is a constraint continuation
246    let mut curr = node.parent();
247    while let Some(n) = curr {
248        let kind = n.kind();
249        if matches!(
250            kind,
251            "find_statement"
252                | "given_statement"
253                | "letting_statement"
254                | "dominance_relation"
255                | "bool_expr"
256                | "comparison_expr"
257                | "arithmetic_expr"
258                | "atom"
259        ) {
260            return false;
261        }
262        curr = n.parent();
263    }
264
265    // check for the first non-whitespace character on the line before the error node
266    let line = source.lines().nth(node.start_position().row).unwrap_or("");
267    let first_non_witespace = line
268        .as_bytes()
269        .iter()
270        .take_while(|b| b.is_ascii_whitespace())
271        .count();
272
273    // if the error node is before or at the first non-whitespace character, it's a malformed line error
274    // if the first non-whitespace character is after the error node, it could be a constraint continuation
275    if node.start_position().column <= first_non_witespace || error_node_out_of_range(node, source)
276    {
277        if first_non_witespace > 0 && is_constraint_continuation(source, node.start_position().row)
278        {
279            return false;
280        }
281        return true;
282    }
283    false
284}
285
/// Reports whether line `row` continues a constraint started on an earlier line.
///
/// Examines the nearest preceding line that is non-blank once any `$` comment is stripped.
/// Returns `true` if that line starts with `such that` (ASCII case-insensitive, after leading
/// whitespace) or ends with a comma. Returns `false` when `row` is 0 or no such line exists.
fn is_constraint_continuation(source: &str, row: usize) -> bool {
    let previous_code_line = source
        .lines()
        .take(row)
        .map(|text| text.split('$').next().unwrap_or("").trim_end())
        .filter(|code| !code.trim().is_empty())
        .last();

    previous_code_line.is_some_and(|code| {
        code.trim_start()
            .to_ascii_lowercase()
            .starts_with("such that")
            || code.ends_with(',')
    })
}
306
307/// Coverts a token name into a more user-friendly format for error messages.
308/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
309fn user_friendly_token_name(token: &str, article: bool) -> String {
310    let capitalized = if token.contains("atom") {
311        "Expression".to_string()
312    } else if token == "COLON" {
313        ":".to_string()
314    } else {
315        let friendly_name = token
316            .replace("literal", "")
317            .replace("int", "Integer")
318            .replace("expr", "Expression")
319            .replace('_', " ");
320        friendly_name
321            .split_whitespace()
322            .map(|word| word.capitalize())
323            .collect::<Vec<_>>()
324            .join(" ")
325    };
326
327    if !article {
328        return capitalized;
329    }
330    let first_char = capitalized.chars().next().unwrap();
331    let article = match first_char.to_ascii_lowercase() {
332        'a' | 'e' | 'i' | 'o' | 'u' => "an",
333        _ => "a",
334    };
335    format!("{} {}", article, capitalized)
336}
337
/// Builds the "Expected ..., but got '...'" message for malformed line `line` of `source`.
///
/// The quoted text is the line with surrounding whitespace and any `$` comment removed, and with
/// double quotes backslash-escaped. The expectation is chosen from the first word or two
/// (ASCII case-insensitive). A row that does not exist is treated as an empty line.
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let raw = source.lines().nth(line).unwrap_or("").trim();
    let shown = raw
        .split('$')
        .next()
        .unwrap_or("")
        .trim_end()
        .replace('"', "\\\"");

    let mut keywords = shown.split_whitespace().map(str::to_ascii_lowercase);
    let first = keywords.next().unwrap_or_default();
    let second = keywords.next().unwrap_or_default();

    let expected = match (first.as_str(), second.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising" | "maximising", _) => "an objective statement",
        ("such", "that") => "a constraint statement",
        // Anything else, including a bare `such`, is not a recognised statement opener.
        _ => "a valid top-level statement",
    };
    format!("Expected {}, but got '{}'", expected, shown)
}
363
364/// Returns true if the node's start or end column is out of range for its line in the source.
365fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
366    let lines: Vec<&str> = source.lines().collect();
367    let start = node.start_position();
368    let end = node.end_position();
369
370    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
371    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
372
373    (start.column > start_line_len) || (end.column > end_line_len)
374}
375
/// Unit tests for syntax-error detection and its helper functions.
#[cfg(test)]
mod test {

    use super::{
        clamp_range_before_line_comment, detect_syntactic_errors, int_domain_missing_rparen_line,
        is_int_keyword_suffix, is_malformed_line_error, line_start_byte, point_range_at,
        user_friendly_token_name,
    };
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Helper function for tests to compare the actual error with the expected one.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    // An indented line with no statement keyword should be classified as a malformed line.
    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        // A retract predicate that always returns false visits the whole tree.
        let error_node = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    // The following tests check the expectation chosen from the line's leading keyword(s).
    #[test]
    fn malformed_find_message() {
        let source = "find >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_top_level_message() {
        let source = "a >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_objective_message() {
        let source = "maximising %x";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an objective statement, but got 'maximising %x'"
        );
    }

    #[test]
    fn malformed_letting_message() {
        let source = "letting % A be 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a letting declaration statement, but got 'letting % A be 3'"
        );
    }

    #[test]
    fn malformed_constraint_message() {
        let source = "such that % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a constraint statement, but got 'such that % A > 3'"
        );
    }

    // `such` without a following `that` is not treated as a constraint opener.
    #[test]
    fn malformed_top_level_message_2() {
        let source = "such % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'such % A > 3'"
        );
    }

    #[test]
    fn malformed_given_message() {
        let source = "given 1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a given declaration statement, but got 'given 1..3)'"
        );
    }

    #[test]
    fn malformed_where_message() {
        let source = "where x>6";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an instantiation condition, but got 'where x>6'"
        );
    }

    // Friendly names with and without an indefinite article.
    #[test]
    fn user_friendly_token_name_article() {
        assert_eq!(
            user_friendly_token_name("int_domain", false),
            "Integer Domain"
        );
        assert_eq!(
            user_friendly_token_name("int_domain", true),
            "an Integer Domain"
        );
        // assert_eq!(user_friendly_token_name("atom", true), "an Expression");
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    // End-to-end: a `find` with no domain yields one zero-width "Missing Domain" diagnostic at EOF.
    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");
        let mut errors = vec![];
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        let error = &errors[0];

        assert_essence_parse_error_eq(
            error,
            &RecoverableParseError::new(
                "Missing Domain".to_string(),
                Some(tree_sitter::Range {
                    start_byte: 7,
                    end_byte: 7,
                    start_point: tree_sitter::Point { row: 0, column: 7 },
                    end_point: tree_sitter::Point { row: 0, column: 7 },
                }),
            ),
        );
    }

    #[test]
    fn line_start_byte_returns_correct_offsets() {
        let source = "a\nbc\ndef";
        let bytes = source.as_bytes();
        assert_eq!(line_start_byte(bytes, 0), 0);
        assert_eq!(line_start_byte(bytes, 1), 2);
        assert_eq!(line_start_byte(bytes, 2), 5);
    }

    #[test]
    fn point_range_at_returns_correct_zero_length_range() {
        let source = "a\nbc\ndef";
        let range = point_range_at(source, 1, 1); // points to 'c'
        assert_eq!(range.start_point.row, 1);
        assert_eq!(range.start_point.column, 1);
        assert_eq!(range.end_point, range.start_point);
        assert_eq!(range.start_byte, 3);
        assert_eq!(range.end_byte, 3);
    }

    #[test]
    fn clamp_range_before_line_comment_clamps_end_to_before_dollar() {
        let source = "find x: int(1..3 $comment";
        let mut range = tree_sitter::Range {
            start_byte: 0,
            end_byte: source.len(),
            start_point: tree_sitter::Point { row: 0, column: 0 },
            end_point: tree_sitter::Point {
                row: 0,
                column: source.len(),
            },
        };

        clamp_range_before_line_comment(&mut range, source);

        // "find x: int(1..3" ends at byte/column 16; the `$comment` suffix must be excluded.
        assert_eq!(range.end_point.row, 0);
        assert_eq!(range.end_point.column, 16);
        assert_eq!(range.end_byte, 16);
    }

    // `int` must be a standalone word: a preceding identifier character disqualifies it.
    #[test]
    fn int_keyword_suffix_checks_word_boundary() {
        assert!(is_int_keyword_suffix("find x: int"));
        assert!(!is_int_keyword_suffix("foo"));
        assert!(!is_int_keyword_suffix("mint"));
    }

    // Positive case plus three negatives: a `)` is present, text follows the node, and the
    // keyword before `(` is not `int`.
    #[test]
    fn int_domain_missing_rparen_line_positive_and_negative_cases() {
        let ok = "find x: int(1..2";
        let start = ok.find('(').unwrap();
        assert!(int_domain_missing_rparen_line(ok, start, ok.len()));

        let has_rparen = "find x: int(1..2)";
        let start = has_rparen.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            has_rparen,
            start,
            has_rparen.len()
        ));

        let trailing = "find x: int(1..2 foo";
        let start = trailing.find('(').unwrap();
        let end = trailing.find(" foo").unwrap();
        assert!(!int_domain_missing_rparen_line(trailing, start, end));

        let print_like = "find x: print(1..2";
        let start = print_like.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            print_like,
            start,
            print_like.len()
        ));
    }
}