1use crate::errors::RecoverableParseError;
2use crate::parser::traversal::WalkDFS;
3use capitalize::Capitalize;
4use std::collections::HashSet;
5use tree_sitter::Node;
6
/// Returns the byte offset at which the 0-based row `row` begins in `source`.
///
/// Rows are separated by `\n`. Row 0 always starts at offset 0. A row beyond the last
/// newline resolves to the offset just past the final `\n` (or 0 if there is none).
pub fn line_start_byte(source: &[u8], row: usize) -> usize {
    source
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'\n')
        .map(|(newline_idx, _)| newline_idx + 1)
        .take(row)
        .last()
        .unwrap_or(0)
}
23
24fn point_range_at(source: &str, row: usize, column: usize) -> tree_sitter::Range {
25 let line_start = line_start_byte(source.as_bytes(), row);
26 let byte = line_start + column;
27 tree_sitter::Range {
28 start_byte: byte,
29 end_byte: byte,
30 start_point: tree_sitter::Point { row, column },
31 end_point: tree_sitter::Point { row, column },
32 }
33}
34
/// Returns `true` when `prefix`, ignoring trailing whitespace, ends with `int` as a whole word.
///
/// "Whole word" means the byte preceding `int`, if there is one, is neither ASCII
/// alphanumeric nor `_`. For example, `": int"` matches, but `"mint"` and `"x_int"` do not.
fn is_int_keyword_suffix(prefix: &str) -> bool {
    let Some(before) = prefix.trim_end().strip_suffix("int") else {
        return false;
    };
    !matches!(
        before.bytes().next_back(),
        Some(b) if b.is_ascii_alphanumeric() || b == b'_'
    )
}
46
47fn int_domain_missing_rparen_line(line: &str, start_col: usize, end_col: usize) -> bool {
48 line.as_bytes().get(start_col) == Some(&b'(')
49 && line[end_col..].trim().is_empty()
50 && !line[start_col..].contains(')')
51 && is_int_keyword_suffix(&line[..start_col])
52}
53
54fn clamp_range_before_line_comment(range: &mut tree_sitter::Range, source: &str) {
58 let Some(line) = source.lines().nth(range.start_point.row) else {
59 return;
60 };
61 let Some(dollar_idx) = line.find('$') else {
62 return;
63 };
64
65 let prefix = &line[..dollar_idx];
66 let clamped_col = prefix.trim_end().len();
67
68 if range.start_point.column > clamped_col {
69 range.start_point.column = clamped_col;
70 }
71 if range.end_point.row == range.start_point.row && range.end_point.column > clamped_col {
72 range.end_point.column = clamped_col;
73 }
74 if range.end_point.row > range.start_point.row {
75 range.end_point.row = range.start_point.row;
76 range.end_point.column = clamped_col;
77 }
78
79 let line_start = line_start_byte(source.as_bytes(), range.start_point.row);
80 range.start_byte = line_start + range.start_point.column;
81 range.end_byte = line_start + range.end_point.column;
82}
83
84pub fn detect_syntactic_errors(
85 source: &str,
86 tree: &tree_sitter::Tree,
87 errors: &mut Vec<RecoverableParseError>,
88) {
89 let mut malformed_lines_reported = HashSet::new();
90
91 let root_node = tree.root_node();
92 let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
93 node.is_missing() || node.is_error() || node.start_position() == node.end_position()
94 };
95
96 for node in WalkDFS::with_retract(&root_node, &retract) {
97 if node.start_position() == node.end_position() {
98 errors.push(classify_missing_token(node, source));
99 continue;
100 }
101 if node.is_error() {
102 let line = node.start_position().row;
103 if malformed_lines_reported.contains(&line) {
105 continue;
106 }
107 if let Some(line_str) = source.lines().nth(line)
109 && let Some(dollar_idx) = line_str.find('$')
110 && node.start_position().column >= dollar_idx
111 {
112 continue;
113 }
114
115 if is_malformed_line_error(&node, source) {
116 malformed_lines_reported.insert(line);
117 let start_byte = node.start_byte();
118 let end_byte = node.end_byte();
119
120 let last_char = source.lines().nth(line).map_or(0, |l| l.len());
121 errors.push(RecoverableParseError::new(
122 generate_malformed_line_message(line, source),
123 Some(tree_sitter::Range {
124 start_byte,
125 end_byte,
126 start_point: tree_sitter::Point {
127 row: line,
128 column: 0,
129 },
130 end_point: tree_sitter::Point {
131 row: line,
132 column: last_char,
133 },
134 }),
135 ));
136 continue;
137 } else {
138 if let Some(missing_rparen) = classify_int_domain_missing_rparen(&node, source) {
139 errors.push(missing_rparen);
140 continue;
141 }
142 errors.push(classify_unexpected_token_error(node, source));
143 }
144 continue;
145 }
146 }
147}
148
149fn classify_int_domain_missing_rparen(
153 node: &tree_sitter::Node,
154 source: &str,
155) -> Option<RecoverableParseError> {
156 let start = node.start_position();
157 let end = node.end_position();
158 let line = source.lines().nth(start.row)?;
159 let comment_col = line.find('$').unwrap_or(line.len());
160 let line = &line[..comment_col];
161 let start_col = start.column.min(line.len());
162 let end_col = end.column.min(line.len());
163 if !int_domain_missing_rparen_line(line, start_col, end_col) {
164 return None;
165 }
166 let insertion_col = line.trim_end().len();
167 Some(RecoverableParseError::new(
168 "Missing )".to_string(),
169 Some(point_range_at(source, start.row, insertion_col)),
170 ))
171}
172
173fn classify_missing_token(node: Node, source: &str) -> RecoverableParseError {
175 let mut range = tree_sitter::Range {
176 start_byte: node.start_byte(),
177 end_byte: node.end_byte(),
178 start_point: node.start_position(),
179 end_point: node.end_position(),
180 };
181 clamp_range_before_line_comment(&mut range, source);
182
183 let message = if let Some(parent) = node.parent() {
184 match parent.kind() {
185 "letting_variable_declaration" => "Missing Expression or Domain".to_string(),
186 _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
187 }
188 } else {
189 format!("Missing {}", user_friendly_token_name(node.kind(), false))
190 };
191
192 RecoverableParseError::new(message, Some(range))
193}
194
195fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
197 let mut range = tree_sitter::Range {
198 start_byte: node.start_byte().min(source_code.len()),
199 end_byte: node.end_byte().min(source_code.len()),
200 start_point: node.start_position(),
201 end_point: node.end_position(),
202 };
203 clamp_range_before_line_comment(&mut range, source_code);
204
205 let message = if let Some(parent) = node.parent() {
206 let src_token: std::borrow::Cow<'_, str> = source_code
209 .as_bytes()
210 .get(range.start_byte..range.end_byte)
211 .map(String::from_utf8_lossy)
212 .unwrap_or_else(|| std::borrow::Cow::Borrowed("<unknown>"));
213 let src_token = src_token.trim_end();
214
215 if parent.kind() == "program" {
216 format!("Unexpected {}", src_token)
217 } else {
218 format!(
219 "Unexpected {} inside {}",
220 src_token,
221 user_friendly_token_name(parent.kind(), true)
222 )
223 }
224 } else {
225 "Unexpected token".to_string()
226 };
227
228 RecoverableParseError::new(message, Some(range))
229}
230
231pub fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
233 let parent = node.parent();
234 let grandparent = parent.and_then(|n| n.parent());
235 let root = grandparent.and_then(|n| n.parent());
236
237 if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root)
238 && parent.kind() == "set_comparison"
239 && grandparent.kind() == "comparison_expr"
240 && root.kind() == "program"
241 {
242 return true;
243 }
244
245 let mut curr = node.parent();
247 while let Some(n) = curr {
248 let kind = n.kind();
249 if matches!(
250 kind,
251 "find_statement"
252 | "given_statement"
253 | "letting_statement"
254 | "dominance_relation"
255 | "bool_expr"
256 | "comparison_expr"
257 | "arithmetic_expr"
258 | "atom"
259 ) {
260 return false;
261 }
262 curr = n.parent();
263 }
264
265 let line = source.lines().nth(node.start_position().row).unwrap_or("");
267 let first_non_witespace = line
268 .as_bytes()
269 .iter()
270 .take_while(|b| b.is_ascii_whitespace())
271 .count();
272
273 if node.start_position().column <= first_non_witespace || error_node_out_of_range(node, source)
276 {
277 if first_non_witespace > 0 && is_constraint_continuation(source, node.start_position().row)
278 {
279 return false;
280 }
281 return true;
282 }
283 false
284}
285
/// Reports whether row `row` continues a constraint begun on an earlier line.
///
/// Looks for the nearest preceding row that still contains code once its `$` comment is
/// removed (blank rows are skipped). Returns `true` if that row starts with `such that`
/// (ASCII case-insensitive) or ends with `,`. Row 0, or a row with no such predecessor,
/// yields `false`.
fn is_constraint_continuation(source: &str, row: usize) -> bool {
    let previous_code_line = source
        .lines()
        .take(row)
        .map(|raw| raw.split('$').next().unwrap_or("").trim_end())
        .filter(|code| !code.trim().is_empty())
        .last();

    previous_code_line.is_some_and(|code| {
        code.trim_start()
            .to_ascii_lowercase()
            .starts_with("such that")
            || code.ends_with(',')
    })
}
306
307fn user_friendly_token_name(token: &str, article: bool) -> String {
310 let capitalized = if token.contains("atom") {
311 "Expression".to_string()
312 } else if token == "COLON" {
313 ":".to_string()
314 } else {
315 let friendly_name = token
316 .replace("literal", "")
317 .replace("int", "Integer")
318 .replace("expr", "Expression")
319 .replace('_', " ");
320 friendly_name
321 .split_whitespace()
322 .map(|word| word.capitalize())
323 .collect::<Vec<_>>()
324 .join(" ")
325 };
326
327 if !article {
328 return capitalized;
329 }
330 let first_char = capitalized.chars().next().unwrap();
331 let article = match first_char.to_ascii_lowercase() {
332 'a' | 'e' | 'i' | 'o' | 'u' => "an",
333 _ => "a",
334 };
335 format!("{} {}", article, capitalized)
336}
337
/// Builds the "Expected …, but got '…'" message for a malformed source row.
///
/// The row is trimmed, cut at any `$` comment, and has `"` escaped as `\"`. Its first one
/// or two words (lowercased) select the kind of statement the user most likely intended.
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let text = source.lines().nth(line).unwrap_or("").trim();
    let code = text.split('$').next().unwrap_or("").trim_end();
    let got = code.replace('"', "\\\"");

    let mut keywords = got.split_whitespace().map(str::to_ascii_lowercase);
    let first = keywords.next().unwrap_or_default();
    let second = keywords.next().unwrap_or_default();

    let expected = match (first.as_str(), second.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising" | "maximising", _) => "an objective statement",
        ("such", "that") => "a constraint statement",
        // Includes a lone `such` that is not followed by `that`.
        _ => "a valid top-level statement",
    };
    format!("Expected {expected}, but got '{got}'")
}
363
364fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
366 let lines: Vec<&str> = source.lines().collect();
367 let start = node.start_position();
368 let end = node.end_position();
369
370 let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
371 let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
372
373 (start.column > start_line_len) || (end.column > end_line_len)
374}
375
#[cfg(test)]
mod test {

    use super::{
        clamp_range_before_line_comment, detect_syntactic_errors, int_domain_missing_rparen_line,
        is_constraint_continuation, is_int_keyword_suffix, is_malformed_line_error,
        line_start_byte, point_range_at, user_friendly_token_name,
    };
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Asserts that two diagnostics carry the same message and the same source range.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        let error_node = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    #[test]
    fn malformed_find_message() {
        let source = "find >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_top_level_message() {
        let source = "a >=lex,b,c: int(1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'"
        );
    }

    #[test]
    fn malformed_objective_message() {
        let source = "maximising %x";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an objective statement, but got 'maximising %x'"
        );
    }

    #[test]
    fn malformed_letting_message() {
        let source = "letting % A be 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a letting declaration statement, but got 'letting % A be 3'"
        );
    }

    #[test]
    fn malformed_constraint_message() {
        let source = "such that % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a constraint statement, but got 'such that % A > 3'"
        );
    }

    #[test]
    fn malformed_top_level_message_2() {
        let source = "such % A > 3";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a valid top-level statement, but got 'such % A > 3'"
        );
    }

    #[test]
    fn malformed_given_message() {
        let source = "given 1..3)";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected a given declaration statement, but got 'given 1..3)'"
        );
    }

    #[test]
    fn malformed_where_message() {
        let source = "where x>6";
        let message = super::generate_malformed_line_message(0, source);
        assert_eq!(
            message,
            "Expected an instantiation condition, but got 'where x>6'"
        );
    }

    /// The echoed text excludes any `$` comment and escapes double quotes.
    #[test]
    fn malformed_message_strips_comment_and_escapes_quotes() {
        let source = "find x: int\n  letting \"a\" be 1 $ note";
        let message = super::generate_malformed_line_message(1, source);
        assert_eq!(
            message,
            r#"Expected a letting declaration statement, but got 'letting \"a\" be 1'"#
        );
    }

    #[test]
    fn user_friendly_token_name_article() {
        assert_eq!(
            user_friendly_token_name("int_domain", false),
            "Integer Domain"
        );
        assert_eq!(
            user_friendly_token_name("int_domain", true),
            "an Integer Domain"
        );
        assert_eq!(user_friendly_token_name("COLON", false), ":");
    }

    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");
        let mut errors = vec![];
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        let error = &errors[0];

        assert_essence_parse_error_eq(
            error,
            &RecoverableParseError::new(
                "Missing Domain".to_string(),
                Some(tree_sitter::Range {
                    start_byte: 7,
                    end_byte: 7,
                    start_point: tree_sitter::Point { row: 0, column: 7 },
                    end_point: tree_sitter::Point { row: 0, column: 7 },
                }),
            ),
        );
    }

    #[test]
    fn line_start_byte_returns_correct_offsets() {
        let source = "a\nbc\ndef";
        let bytes = source.as_bytes();
        assert_eq!(line_start_byte(bytes, 0), 0);
        assert_eq!(line_start_byte(bytes, 1), 2);
        assert_eq!(line_start_byte(bytes, 2), 5);
    }

    /// Rows past the last newline resolve to the start of the final line.
    #[test]
    fn line_start_byte_past_last_row_returns_last_line_start() {
        let bytes = "a\nbc\ndef".as_bytes();
        assert_eq!(line_start_byte(bytes, 10), 5);
        assert_eq!(line_start_byte(b"", 3), 0);
    }

    #[test]
    fn point_range_at_returns_correct_zero_length_range() {
        let source = "a\nbc\ndef";
        let range = point_range_at(source, 1, 1);
        assert_eq!(range.start_point.row, 1);
        assert_eq!(range.start_point.column, 1);
        assert_eq!(range.end_point, range.start_point);
        assert_eq!(range.start_byte, 3);
        assert_eq!(range.end_byte, 3);
    }

    #[test]
    fn clamp_range_before_line_comment_clamps_end_to_before_dollar() {
        let source = "find x: int(1..3 $comment";
        let mut range = tree_sitter::Range {
            start_byte: 0,
            end_byte: source.len(),
            start_point: tree_sitter::Point { row: 0, column: 0 },
            end_point: tree_sitter::Point {
                row: 0,
                column: source.len(),
            },
        };

        clamp_range_before_line_comment(&mut range, source);

        assert_eq!(range.end_point.row, 0);
        assert_eq!(range.end_point.column, 16);
        assert_eq!(range.end_byte, 16);
    }

    #[test]
    fn int_keyword_suffix_checks_word_boundary() {
        assert!(is_int_keyword_suffix("find x: int"));
        assert!(is_int_keyword_suffix("int  "));
        assert!(!is_int_keyword_suffix("foo"));
        assert!(!is_int_keyword_suffix("mint"));
        assert!(!is_int_keyword_suffix("x_int"));
        assert!(!is_int_keyword_suffix(""));
    }

    #[test]
    fn int_domain_missing_rparen_line_positive_and_negative_cases() {
        let ok = "find x: int(1..2";
        let start = ok.find('(').unwrap();
        assert!(int_domain_missing_rparen_line(ok, start, ok.len()));

        let has_rparen = "find x: int(1..2)";
        let start = has_rparen.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            has_rparen,
            start,
            has_rparen.len()
        ));

        let trailing = "find x: int(1..2 foo";
        let start = trailing.find('(').unwrap();
        let end = trailing.find(" foo").unwrap();
        assert!(!int_domain_missing_rparen_line(trailing, start, end));

        let print_like = "find x: print(1..2";
        let start = print_like.find('(').unwrap();
        assert!(!int_domain_missing_rparen_line(
            print_like,
            start,
            print_like.len()
        ));
    }

    /// The nearest preceding non-blank code line decides whether a row continues a constraint.
    #[test]
    fn constraint_continuation_looks_at_previous_code_line() {
        assert!(is_constraint_continuation("such that\n  x > 1", 1));
        assert!(is_constraint_continuation("such that a,\n\n  $ note\n  b", 3));
        assert!(is_constraint_continuation("find a: int,\n  b", 1));
        assert!(!is_constraint_continuation("find x: int\n  y", 1));
        assert!(!is_constraint_continuation("such that x", 0));
    }
}