Skip to main content

conjure_cp_essence_parser/parser/
syntax_errors.rs

1use crate::errors::RecoverableParseError;
2use crate::parser::traversal::WalkDFS;
3use capitalize::Capitalize;
4use std::collections::HashSet;
5use tree_sitter::Node;
6
/// Computes the byte offset at which line `row` (zero-based) begins in `source`.
///
/// Lines are delimited by `\n`. When `source` contains fewer than `row` newlines, the offset
/// of the last line that does exist is returned (`0` for input without any newline).
fn line_start_byte(source: &[u8], row: usize) -> usize {
    source
        .iter()
        .enumerate()
        .filter(|(_, byte)| **byte == b'\n')
        // A line starts one byte after the newline that terminates the previous line.
        .map(|(newline_idx, _)| newline_idx + 1)
        .take(row)
        .last()
        .unwrap_or(0)
}
22
23fn point_range_at(source: &str, row: usize, column: usize) -> tree_sitter::Range {
24    let line_start = line_start_byte(source.as_bytes(), row);
25    let byte = line_start + column;
26    tree_sitter::Range {
27        start_byte: byte,
28        end_byte: byte,
29        start_point: tree_sitter::Point { row, column },
30        end_point: tree_sitter::Point { row, column },
31    }
32}
33
/// Reports whether `prefix`, ignoring trailing whitespace, ends with the standalone word `int`.
///
/// The word counts as standalone when it is the entire text or when the byte directly before
/// it is neither an ASCII alphanumeric character nor `_` (so `mint` and `x_int` do not match).
fn is_int_keyword_suffix(prefix: &str) -> bool {
    let Some(before_keyword) = prefix.trim_end().strip_suffix("int") else {
        return false;
    };
    match before_keyword.as_bytes().last() {
        None => true,
        Some(previous) => !(previous.is_ascii_alphanumeric() || *previous == b'_'),
    }
}
45
46fn int_domain_missing_rparen_line(line: &str, start_col: usize, end_col: usize) -> bool {
47    line.as_bytes().get(start_col) == Some(&b'(')
48        && line[end_col..].trim().is_empty()
49        && !line[start_col..].contains(')')
50        && is_int_keyword_suffix(&line[..start_col])
51}
52
53/// tree-sitter `ERROR` node spans can overlap bytes during recovery.
54/// Need to clamp to the end of the non-comment prefix so diagnostics don't include comment
55/// contents.
56fn clamp_range_before_line_comment(range: &mut tree_sitter::Range, source: &str) {
57    let Some(line) = source.lines().nth(range.start_point.row) else {
58        return;
59    };
60    let Some(dollar_idx) = line.find('$') else {
61        return;
62    };
63
64    let prefix = &line[..dollar_idx];
65    let clamped_col = prefix.trim_end().len();
66
67    if range.start_point.column > clamped_col {
68        range.start_point.column = clamped_col;
69    }
70    if range.end_point.row == range.start_point.row && range.end_point.column > clamped_col {
71        range.end_point.column = clamped_col;
72    }
73    if range.end_point.row > range.start_point.row {
74        range.end_point.row = range.start_point.row;
75        range.end_point.column = clamped_col;
76    }
77
78    let line_start = line_start_byte(source.as_bytes(), range.start_point.row);
79    range.start_byte = line_start + range.start_point.column;
80    range.end_byte = line_start + range.end_point.column;
81}
82
83pub fn detect_syntactic_errors(
84    source: &str,
85    tree: &tree_sitter::Tree,
86    errors: &mut Vec<RecoverableParseError>,
87) {
88    let mut malformed_lines_reported = HashSet::new();
89
90    let root_node = tree.root_node();
91    let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
92        node.is_missing() || node.is_error() || node.start_position() == node.end_position()
93    };
94
95    for node in WalkDFS::with_retract(&root_node, &retract) {
96        if node.start_position() == node.end_position() {
97            errors.push(classify_missing_token(node, source));
98            continue;
99        }
100        if node.is_error() {
101            let line = node.start_position().row;
102            // If this line has already been reported as malformed, skip all error nodes on this line
103            if malformed_lines_reported.contains(&line) {
104                continue;
105            }
106            // Ignore error nodes that start inside a single-line comment.
107            if let Some(line_str) = source.lines().nth(line)
108                && let Some(dollar_idx) = line_str.find('$')
109                && node.start_position().column >= dollar_idx
110            {
111                continue;
112            }
113
114            if is_malformed_line_error(&node, source) {
115                malformed_lines_reported.insert(line);
116                let start_byte = node.start_byte();
117                let end_byte = node.end_byte();
118
119                let last_char = source.lines().nth(line).map_or(0, |l| l.len());
120                errors.push(RecoverableParseError::new(
121                    generate_malformed_line_message(line, source),
122                    Some(tree_sitter::Range {
123                        start_byte,
124                        end_byte,
125                        start_point: tree_sitter::Point {
126                            row: line,
127                            column: 0,
128                        },
129                        end_point: tree_sitter::Point {
130                            row: line,
131                            column: last_char,
132                        },
133                    }),
134                ));
135                continue;
136            } else {
137                if let Some(missing_rparen) = classify_int_domain_missing_rparen(&node, source) {
138                    errors.push(missing_rparen);
139                    continue;
140                }
141                errors.push(classify_unexpected_token_error(node, source));
142            }
143            continue;
144        }
145    }
146}
147
148/// Tree-sitter recovery sometimes reduces `int_domain` to bare `int` and then wraps the following
149/// `(` and range text in an `ERROR` node (especially at EOF).
150/// This function detects this specific pattern and reports  "Missing )" error
151fn classify_int_domain_missing_rparen(
152    node: &tree_sitter::Node,
153    source: &str,
154) -> Option<RecoverableParseError> {
155    let start = node.start_position();
156    let end = node.end_position();
157    let line = source.lines().nth(start.row)?;
158    let comment_col = line.find('$').unwrap_or(line.len());
159    let line = &line[..comment_col];
160    let start_col = start.column.min(line.len());
161    let end_col = end.column.min(line.len());
162    if !int_domain_missing_rparen_line(line, start_col, end_col) {
163        return None;
164    }
165    let insertion_col = line.trim_end().len();
166    Some(RecoverableParseError::new(
167        "Missing )".to_string(),
168        Some(point_range_at(source, start.row, insertion_col)),
169    ))
170}
171
172/// Classifies a missing token node and generates a diagnostic with a context-aware message.
173fn classify_missing_token(node: Node, source: &str) -> RecoverableParseError {
174    let mut range = tree_sitter::Range {
175        start_byte: node.start_byte(),
176        end_byte: node.end_byte(),
177        start_point: node.start_position(),
178        end_point: node.end_position(),
179    };
180    clamp_range_before_line_comment(&mut range, source);
181
182    let message = if let Some(parent) = node.parent() {
183        match parent.kind() {
184            "letting_variable_declaration" => "Missing Expression or Domain".to_string(),
185            _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
186        }
187    } else {
188        format!("Missing {}", user_friendly_token_name(node.kind(), false))
189    };
190
191    RecoverableParseError::new(message, Some(range))
192}
193
194/// Classifies an unexpected token error node and generates a diagnostic.
195fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
196    let mut range = tree_sitter::Range {
197        start_byte: node.start_byte().min(source_code.len()),
198        end_byte: node.end_byte().min(source_code.len()),
199        start_point: node.start_position(),
200        end_point: node.end_position(),
201    };
202    clamp_range_before_line_comment(&mut range, source_code);
203
204    let message = if let Some(parent) = node.parent() {
205        // Extract the unexpected token text, handling out-of-range indices safely.
206        // NOTE: tree-sitter byte offsets can land inside UTF-8 codepoints; decoding lossily avoids panics.
207        let src_token: std::borrow::Cow<'_, str> = source_code
208            .as_bytes()
209            .get(range.start_byte..range.end_byte)
210            .map(String::from_utf8_lossy)
211            .unwrap_or_else(|| std::borrow::Cow::Borrowed("<unknown>"));
212        let src_token = src_token.trim_end();
213
214        if parent.kind() == "program" {
215            format!("Unexpected {}", src_token)
216        } else {
217            format!(
218                "Unexpected {} inside {}",
219                src_token,
220                user_friendly_token_name(parent.kind(), true)
221            )
222        }
223    } else {
224        "Unexpected token".to_string()
225    };
226
227    RecoverableParseError::new(message, Some(range))
228}
229
230/// Determines if an error node represents a malformed line error.
231pub fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
232    let parent = node.parent();
233    let grandparent = parent.and_then(|n| n.parent());
234    let root = grandparent.and_then(|n| n.parent());
235
236    if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root)
237        && parent.kind() == "set_comparison"
238        && grandparent.kind() == "comparison_expr"
239        && root.kind() == "program"
240    {
241        return true;
242    }
243
244    // check parent kinds to see if the error is a constraint continuation
245    let mut curr = node.parent();
246    while let Some(n) = curr {
247        let kind = n.kind();
248        if matches!(
249            kind,
250            "find_statement"
251                | "given_statement"
252                | "letting_statement"
253                | "dominance_relation"
254                | "bool_expr"
255                | "comparison_expr"
256                | "arithmetic_expr"
257                | "atom"
258        ) {
259            return false;
260        }
261        curr = n.parent();
262    }
263
264    // check for the first non-whitespace character on the line before the error node
265    let line = source.lines().nth(node.start_position().row).unwrap_or("");
266    let first_non_witespace = line
267        .as_bytes()
268        .iter()
269        .take_while(|b| b.is_ascii_whitespace())
270        .count();
271
272    // if the error node is before or at the first non-whitespace character, it's a malformed line error
273    // if the first non-whitespace character is after the error node, it could be a constraint continuation
274    if node.start_position().column <= first_non_witespace || error_node_out_of_range(node, source)
275    {
276        if first_non_witespace > 0 && is_constraint_continuation(source, node.start_position().row)
277        {
278            return false;
279        }
280        return true;
281    }
282    false
283}
284
/// Reports whether line `row` continues a constraint started on an earlier line.
///
/// The nearest preceding line that still has content after its `$` comment is stripped is
/// inspected: the result is `true` when that line begins with `such that` (compared
/// case-insensitively) or ends with a comma. The first line, or a line preceded only by blank
/// and comment-only lines, is never a continuation.
fn is_constraint_continuation(source: &str, row: usize) -> bool {
    source
        .lines()
        .take(row)
        .map(|text| text.split('$').next().unwrap_or("").trim_end())
        .filter(|code| !code.trim().is_empty())
        .last()
        .is_some_and(|previous| {
            previous
                .trim_start()
                .to_ascii_lowercase()
                .starts_with("such that")
                || previous.ends_with(',')
        })
}
305
306/// Coverts a token name into a more user-friendly format for error messages.
307/// Removes underscores, replaces certain keywords with more natural language, and adds appropriate articles.
308fn user_friendly_token_name(token: &str, article: bool) -> String {
309    let capitalized = if token.contains("atom") {
310        "Expression".to_string()
311    } else if token == "COLON" {
312        ":".to_string()
313    } else {
314        let friendly_name = token
315            .replace("literal", "")
316            .replace("int", "Integer")
317            .replace("expr", "Expression")
318            .replace('_', " ");
319        friendly_name
320            .split_whitespace()
321            .map(|word| word.capitalize())
322            .collect::<Vec<_>>()
323            .join(" ")
324    };
325
326    if !article {
327        return capitalized;
328    }
329    let first_char = capitalized.chars().next().unwrap();
330    let article = match first_char.to_ascii_lowercase() {
331        'a' | 'e' | 'i' | 'o' | 'u' => "an",
332        _ => "a",
333    };
334    format!("{} {}", article, capitalized)
335}
336
/// Produces the `Expected …, but got '…'` message for a malformed line.
///
/// The quoted text is line `line` of `source` with surrounding whitespace and any `$` comment
/// removed and with double quotes backslash-escaped. The expected construct is chosen from the
/// first word of that text (and the second word for `such that`), compared case-insensitively.
/// A missing line is treated as empty.
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let raw_line = source.lines().nth(line).unwrap_or("");
    let code = raw_line.trim().split('$').next().unwrap_or("").trim_end();
    let shown = code.replace('"', "\\\"");

    let mut leading_words = shown.split_whitespace().map(str::to_ascii_lowercase);
    let first = leading_words.next().unwrap_or_default();
    let second = leading_words.next().unwrap_or_default();

    let expected = match (first.as_str(), second.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising" | "maximising", _) => "an objective statement",
        ("such", "that") => "a constraint statement",
        // `such` without `that`, and every unrecognised leading word.
        _ => "a valid top-level statement",
    };

    format!("Expected {}, but got '{}'", expected, shown)
}
368
369/// Returns true if the node's start or end column is out of range for its line in the source.
370fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
371    let lines: Vec<&str> = source.lines().collect();
372    let start = node.start_position();
373    let end = node.end_position();
374
375    let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
376    let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
377
378    (start.column > start_line_len) || (end.column > end_line_len)
379}
380
#[cfg(test)]
mod test {

    use super::{
        clamp_range_before_line_comment, detect_syntactic_errors, int_domain_missing_rparen_line,
        is_int_keyword_suffix, is_malformed_line_error, line_start_byte, point_range_at,
        user_friendly_token_name,
    };
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Helper function for tests to compare the actual error with the expected one.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    /// An indented line that does not continue a constraint is a malformed line.
    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        let error_node = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    #[test]
    fn malformed_find_message() {
        let source = "find >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_top_level_message() {
        let source = "a >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_objective_message() {
        let source = "maximising %x";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an objective statement, but got 'maximising %x'"
        );
    }

    #[test]
    fn malformed_letting_message() {
        let source = "letting % A be 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a letting declaration statement, but got 'letting % A be 3'"
        );
    }

    #[test]
    fn malformed_constraint_message() {
        let source = "such that % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a constraint statement, but got 'such that % A > 3'"
        );
    }

    /// `such` without a following `that` falls back to the generic expectation.
    #[test]
    fn malformed_top_level_message_2() {
        let source = "such % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'such % A > 3'"
        );
    }

    #[test]
    fn malformed_given_message() {
        let source = "given 1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a given declaration statement, but got 'given 1..3)'"
        );
    }

    #[test]
    fn malformed_where_message() {
        let source = "where x>6";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an instantiation condition, but got 'where x>6'"
        );
    }

    #[test]
    fn user_friendly_token_name_article() {
        assert_eq!(
            user_friendly_token_name("int_domain", false),
            "Integer Domain"
        );
        assert_eq!(
            user_friendly_token_name("int_domain", true),
            "an Integer Domain"
        );
        assert_eq!(user_friendly_token_name("atom", true), "an Expression");
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");
        let mut errors = vec![];
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        let error = &errors[0];

        assert_essence_parse_error_eq(
            error,
            &RecoverableParseError::new(
                "Missing Domain".to_string(),
                Some(tree_sitter::Range {
                    start_byte: 7,
                    end_byte: 7,
                    start_point: tree_sitter::Point { row: 0, column: 7 },
                    end_point: tree_sitter::Point { row: 0, column: 7 },
                }),
            ),
        );
    }

    #[test]
    fn line_start_byte_returns_correct_offsets() {
        let source = "a\nbc\ndef";
        let bytes = source.as_bytes();
        assert_eq!(line_start_byte(bytes, 0), 0);
        assert_eq!(line_start_byte(bytes, 1), 2);
        assert_eq!(line_start_byte(bytes, 2), 5);
    }

    #[test]
    fn point_range_at_returns_correct_zero_length_range() {
        let source = "a\nbc\ndef";
        let range = point_range_at(source, 1, 1); // points to 'c'
        assert_eq!(range.start_point.row, 1);
        assert_eq!(range.start_point.column, 1);
        assert_eq!(range.end_point, range.start_point);
        assert_eq!(range.start_byte, 3);
        assert_eq!(range.end_byte, 3);
    }

    #[test]
    fn clamp_range_before_line_comment_clamps_end_to_before_dollar() {
        let source = "find x: int(1..3 $comment";
        let mut range = tree_sitter::Range {
            start_byte: 0,
            end_byte: source.len(),
            start_point: tree_sitter::Point { row: 0, column: 0 },
            end_point: tree_sitter::Point {
                row: 0,
                column: source.len(),
            },
        };

        clamp_range_before_line_comment(&mut range, source);

        // "find x: int(1..3" ends at byte/column 16; the `$comment` suffix must be excluded.
        assert_eq!(range.end_point.row, 0);
        assert_eq!(range.end_point.column, 16);
        assert_eq!(range.end_byte, 16);
    }

    #[test]
    fn int_keyword_suffix_checks_word_boundary() {
        assert!(is_int_keyword_suffix("find x: int"));
        // The bare keyword and trailing whitespace are both accepted.
        assert!(is_int_keyword_suffix("int"));
        assert!(is_int_keyword_suffix("find x: int  "));
        assert!(!is_int_keyword_suffix("foo"));
        assert!(!is_int_keyword_suffix("mint"));
        // An underscore before `int` makes it part of a longer identifier.
        assert!(!is_int_keyword_suffix("x_int"));
    }

    #[test]
    fn int_domain_missing_rparen_line_positive_and_negative_cases() {
        let ok = "find x: int(1..2";
        let start = ok.find('(').unwrap();
        assert!(int_domain_missing_rparen_line(ok, start, ok.len()));

        let has_rparen = "find x: int(1..2)";
        let start = has_rparen.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            has_rparen,
            start,
            has_rparen.len()
        ));

        let trailing = "find x: int(1..2 foo";
        let start = trailing.find('(').unwrap();
        let end = trailing.find(" foo").unwrap();
        assert!(!int_domain_missing_rparen_line(trailing, start, end));

        let print_like = "find x: print(1..2";
        let start = print_like.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            print_like,
            start,
            print_like.len()
        ));
    }
}
600            start,
601            print_like.len()
602        ));
603    }
604}