1use crate::errors::RecoverableParseError;
2use crate::parser::traversal::WalkDFS;
3use capitalize::Capitalize;
4use std::collections::HashSet;
5use tree_sitter::Node;
6
/// Returns the byte offset at which line `row` (0-based) starts in `source`.
///
/// Lines are delimited by `\n`. If `row` is larger than the number of
/// newlines in `source`, the offset just past the last newline is returned
/// (`0` when `source` contains no newline).
fn line_start_byte(source: &[u8], row: usize) -> usize {
    source
        .iter()
        .enumerate()
        .filter(|&(_, byte)| *byte == b'\n')
        // The start of line `row` sits right after the `row`-th newline.
        .take(row)
        .last()
        .map_or(0, |(newline_idx, _)| newline_idx + 1)
}
22
23fn point_range_at(source: &str, row: usize, column: usize) -> tree_sitter::Range {
24 let line_start = line_start_byte(source.as_bytes(), row);
25 let byte = line_start + column;
26 tree_sitter::Range {
27 start_byte: byte,
28 end_byte: byte,
29 start_point: tree_sitter::Point { row, column },
30 end_point: tree_sitter::Point { row, column },
31 }
32}
33
/// Reports whether `prefix` (ignoring trailing whitespace) ends with the
/// standalone word `int`.
///
/// The byte immediately before `int`, if any, must not be an ASCII
/// alphanumeric or `_`; this rejects identifiers such as `mint` or `my_int`.
fn is_int_keyword_suffix(prefix: &str) -> bool {
    let Some(head) = prefix.trim_end().strip_suffix("int") else {
        return false;
    };
    match head.as_bytes().last() {
        // Nothing precedes the keyword: the text is exactly `int`.
        None => true,
        Some(before) => !(before.is_ascii_alphanumeric() || *before == b'_'),
    }
}
45
46fn int_domain_missing_rparen_line(line: &str, start_col: usize, end_col: usize) -> bool {
47 line.as_bytes().get(start_col) == Some(&b'(')
48 && line[end_col..].trim().is_empty()
49 && !line[start_col..].contains(')')
50 && is_int_keyword_suffix(&line[..start_col])
51}
52
53fn clamp_range_before_line_comment(range: &mut tree_sitter::Range, source: &str) {
57 let Some(line) = source.lines().nth(range.start_point.row) else {
58 return;
59 };
60 let Some(dollar_idx) = line.find('$') else {
61 return;
62 };
63
64 let prefix = &line[..dollar_idx];
65 let clamped_col = prefix.trim_end().len();
66
67 if range.start_point.column > clamped_col {
68 range.start_point.column = clamped_col;
69 }
70 if range.end_point.row == range.start_point.row && range.end_point.column > clamped_col {
71 range.end_point.column = clamped_col;
72 }
73 if range.end_point.row > range.start_point.row {
74 range.end_point.row = range.start_point.row;
75 range.end_point.column = clamped_col;
76 }
77
78 let line_start = line_start_byte(source.as_bytes(), range.start_point.row);
79 range.start_byte = line_start + range.start_point.column;
80 range.end_byte = line_start + range.end_point.column;
81}
82
83pub fn detect_syntactic_errors(
84 source: &str,
85 tree: &tree_sitter::Tree,
86 errors: &mut Vec<RecoverableParseError>,
87) {
88 let mut malformed_lines_reported = HashSet::new();
89
90 let root_node = tree.root_node();
91 let retract: &dyn Fn(&tree_sitter::Node) -> bool = &|node: &tree_sitter::Node| {
92 node.is_missing() || node.is_error() || node.start_position() == node.end_position()
93 };
94
95 for node in WalkDFS::with_retract(&root_node, &retract) {
96 if node.start_position() == node.end_position() {
97 errors.push(classify_missing_token(node, source));
98 continue;
99 }
100 if node.is_error() {
101 let line = node.start_position().row;
102 if malformed_lines_reported.contains(&line) {
104 continue;
105 }
106 if let Some(line_str) = source.lines().nth(line)
108 && let Some(dollar_idx) = line_str.find('$')
109 && node.start_position().column >= dollar_idx
110 {
111 continue;
112 }
113
114 if is_malformed_line_error(&node, source) {
115 malformed_lines_reported.insert(line);
116 let start_byte = node.start_byte();
117 let end_byte = node.end_byte();
118
119 let last_char = source.lines().nth(line).map_or(0, |l| l.len());
120 errors.push(RecoverableParseError::new(
121 generate_malformed_line_message(line, source),
122 Some(tree_sitter::Range {
123 start_byte,
124 end_byte,
125 start_point: tree_sitter::Point {
126 row: line,
127 column: 0,
128 },
129 end_point: tree_sitter::Point {
130 row: line,
131 column: last_char,
132 },
133 }),
134 ));
135 continue;
136 } else {
137 if let Some(missing_rparen) = classify_int_domain_missing_rparen(&node, source) {
138 errors.push(missing_rparen);
139 continue;
140 }
141 errors.push(classify_unexpected_token_error(node, source));
142 }
143 continue;
144 }
145 }
146}
147
148fn classify_int_domain_missing_rparen(
152 node: &tree_sitter::Node,
153 source: &str,
154) -> Option<RecoverableParseError> {
155 let start = node.start_position();
156 let end = node.end_position();
157 let line = source.lines().nth(start.row)?;
158 let comment_col = line.find('$').unwrap_or(line.len());
159 let line = &line[..comment_col];
160 let start_col = start.column.min(line.len());
161 let end_col = end.column.min(line.len());
162 if !int_domain_missing_rparen_line(line, start_col, end_col) {
163 return None;
164 }
165 let insertion_col = line.trim_end().len();
166 Some(RecoverableParseError::new(
167 "Missing )".to_string(),
168 Some(point_range_at(source, start.row, insertion_col)),
169 ))
170}
171
172fn classify_missing_token(node: Node, source: &str) -> RecoverableParseError {
174 let mut range = tree_sitter::Range {
175 start_byte: node.start_byte(),
176 end_byte: node.end_byte(),
177 start_point: node.start_position(),
178 end_point: node.end_position(),
179 };
180 clamp_range_before_line_comment(&mut range, source);
181
182 let message = if let Some(parent) = node.parent() {
183 match parent.kind() {
184 "letting_variable_declaration" => "Missing Expression or Domain".to_string(),
185 _ => format!("Missing {}", user_friendly_token_name(node.kind(), false)),
186 }
187 } else {
188 format!("Missing {}", user_friendly_token_name(node.kind(), false))
189 };
190
191 RecoverableParseError::new(message, Some(range))
192}
193
194fn classify_unexpected_token_error(node: Node, source_code: &str) -> RecoverableParseError {
196 let mut range = tree_sitter::Range {
197 start_byte: node.start_byte().min(source_code.len()),
198 end_byte: node.end_byte().min(source_code.len()),
199 start_point: node.start_position(),
200 end_point: node.end_position(),
201 };
202 clamp_range_before_line_comment(&mut range, source_code);
203
204 let message = if let Some(parent) = node.parent() {
205 let src_token: std::borrow::Cow<'_, str> = source_code
208 .as_bytes()
209 .get(range.start_byte..range.end_byte)
210 .map(String::from_utf8_lossy)
211 .unwrap_or_else(|| std::borrow::Cow::Borrowed("<unknown>"));
212 let src_token = src_token.trim_end();
213
214 if parent.kind() == "program" {
215 format!("Unexpected {}", src_token)
216 } else {
217 format!(
218 "Unexpected {} inside {}",
219 src_token,
220 user_friendly_token_name(parent.kind(), true)
221 )
222 }
223 } else {
224 "Unexpected token".to_string()
225 };
226
227 RecoverableParseError::new(message, Some(range))
228}
229
230pub fn is_malformed_line_error(node: &tree_sitter::Node, source: &str) -> bool {
232 let parent = node.parent();
233 let grandparent = parent.and_then(|n| n.parent());
234 let root = grandparent.and_then(|n| n.parent());
235
236 if let (Some(parent), Some(grandparent), Some(root)) = (parent, grandparent, root)
237 && parent.kind() == "set_comparison"
238 && grandparent.kind() == "comparison_expr"
239 && root.kind() == "program"
240 {
241 return true;
242 }
243
244 let mut curr = node.parent();
246 while let Some(n) = curr {
247 let kind = n.kind();
248 if matches!(
249 kind,
250 "find_statement"
251 | "given_statement"
252 | "letting_statement"
253 | "dominance_relation"
254 | "bool_expr"
255 | "comparison_expr"
256 | "arithmetic_expr"
257 | "atom"
258 ) {
259 return false;
260 }
261 curr = n.parent();
262 }
263
264 let line = source.lines().nth(node.start_position().row).unwrap_or("");
266 let first_non_witespace = line
267 .as_bytes()
268 .iter()
269 .take_while(|b| b.is_ascii_whitespace())
270 .count();
271
272 if node.start_position().column <= first_non_witespace || error_node_out_of_range(node, source)
275 {
276 if first_non_witespace > 0 && is_constraint_continuation(source, node.start_position().row)
277 {
278 return false;
279 }
280 return true;
281 }
282 false
283}
284
/// Reports whether line `row` continues a constraint begun on an earlier
/// line.
///
/// Searching upwards from `row - 1`, lines that are blank once any `$`
/// comment is removed are skipped. The first remaining line decides the
/// answer: `true` if it starts (case-insensitively, ignoring indentation)
/// with `such that` or ends with `,`. Returns `false` for `row == 0` or when
/// no such line exists.
fn is_constraint_continuation(source: &str, row: usize) -> bool {
    let earlier: Vec<&str> = source.lines().take(row).collect();
    earlier
        .iter()
        .rev()
        .map(|raw| raw.split('$').next().unwrap_or("").trim_end())
        .find(|code| !code.trim().is_empty())
        .is_some_and(|code| {
            code.trim_start()
                .to_ascii_lowercase()
                .starts_with("such that")
                || code.ends_with(',')
        })
}
305
306fn user_friendly_token_name(token: &str, article: bool) -> String {
309 let capitalized = if token.contains("atom") {
310 "Expression".to_string()
311 } else if token == "COLON" {
312 ":".to_string()
313 } else {
314 let friendly_name = token
315 .replace("literal", "")
316 .replace("int", "Integer")
317 .replace("expr", "Expression")
318 .replace('_', " ");
319 friendly_name
320 .split_whitespace()
321 .map(|word| word.capitalize())
322 .collect::<Vec<_>>()
323 .join(" ")
324 };
325
326 if !article {
327 return capitalized;
328 }
329 let first_char = capitalized.chars().next().unwrap();
330 let article = match first_char.to_ascii_lowercase() {
331 'a' | 'e' | 'i' | 'o' | 'u' => "an",
332 _ => "a",
333 };
334 format!("{} {}", article, capitalized)
335}
336
/// Builds the `Expected ..., but got '...'` message for a malformed line.
///
/// The quoted text is line `line` of `source` with surrounding whitespace and
/// any `$` comment removed, and with double quotes backslash-escaped. The
/// expected construct is chosen from the first word (compared
/// case-insensitively); `such` only counts as a constraint when followed by
/// `that`. A missing line is treated as empty.
fn generate_malformed_line_message(line: usize, source: &str) -> String {
    let raw = source.lines().nth(line).unwrap_or("").trim();
    let code = raw.split('$').next().unwrap_or("").trim_end();
    let got = code.replace('"', "\\\"");

    let mut lowered = got.split_whitespace().map(str::to_ascii_lowercase);
    let first = lowered.next().unwrap_or_default();
    let second = lowered.next().unwrap_or_default();

    let expected = match (first.as_str(), second.as_str()) {
        ("find", _) => "a find declaration statement",
        ("letting", _) => "a letting declaration statement",
        ("given", _) => "a given declaration statement",
        ("where", _) => "an instantiation condition",
        ("minimising" | "maximising", _) => "an objective statement",
        ("such", "that") => "a constraint statement",
        _ => "a valid top-level statement",
    };
    format!("Expected {}, but got '{}'", expected, got)
}
368
369fn error_node_out_of_range(node: &tree_sitter::Node, source: &str) -> bool {
371 let lines: Vec<&str> = source.lines().collect();
372 let start = node.start_position();
373 let end = node.end_position();
374
375 let start_line_len = lines.get(start.row).map_or(0, |l| l.len());
376 let end_line_len = lines.get(end.row).map_or(0, |l| l.len());
377
378 (start.column > start_line_len) || (end.column > end_line_len)
379}
380
#[cfg(test)]
mod test {
    use super::{
        clamp_range_before_line_comment, detect_syntactic_errors, int_domain_missing_rparen_line,
        is_int_keyword_suffix, is_malformed_line_error, line_start_byte, point_range_at,
        user_friendly_token_name,
    };
    use crate::errors::RecoverableParseError;
    use crate::{parser::traversal::WalkDFS, util::get_tree};

    /// Asserts that two diagnostics carry the same message and range.
    fn assert_essence_parse_error_eq(a: &RecoverableParseError, b: &RecoverableParseError) {
        assert_eq!(a.msg, b.msg, "error messages differ");
        assert_eq!(a.range, b.range, "error ranges differ");
    }

    /// Asserts the malformed-line message generated for the first line of
    /// `source`.
    fn assert_malformed_message(source: &str, expected: &str) {
        assert_eq!(super::generate_malformed_line_message(0, source), expected);
    }

    /// Byte column of the first `(` in `line`.
    fn lparen_col(line: &str) -> usize {
        line.find('(').unwrap()
    }

    #[test]
    fn malformed_line() {
        let source = " a,a,b: int(1..3)";
        let (tree, _) = get_tree(source).expect("Should parse");
        let root_node = tree.root_node();

        // Visit every node (never retract) and pick the first error node.
        let error_node = WalkDFS::with_retract(&root_node, &|_node| false)
            .find(|node| node.is_error())
            .expect("Should find an error node");

        assert!(is_malformed_line_error(&error_node, source));
    }

    #[test]
    fn malformed_find_message() {
        assert_malformed_message(
            "find >=lex,b,c: int(1..3)",
            "Expected a find declaration statement, but got 'find >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_top_level_message() {
        assert_malformed_message(
            "a >=lex,b,c: int(1..3)",
            "Expected a valid top-level statement, but got 'a >=lex,b,c: int(1..3)'",
        );
    }

    #[test]
    fn malformed_objective_message() {
        assert_malformed_message(
            "maximising %x",
            "Expected an objective statement, but got 'maximising %x'",
        );
    }

    #[test]
    fn malformed_letting_message() {
        assert_malformed_message(
            "letting % A be 3",
            "Expected a letting declaration statement, but got 'letting % A be 3'",
        );
    }

    #[test]
    fn malformed_constraint_message() {
        assert_malformed_message(
            "such that % A > 3",
            "Expected a constraint statement, but got 'such that % A > 3'",
        );
    }

    #[test]
    fn malformed_top_level_message_2() {
        assert_malformed_message(
            "such % A > 3",
            "Expected a valid top-level statement, but got 'such % A > 3'",
        );
    }

    #[test]
    fn malformed_given_message() {
        assert_malformed_message(
            "given 1..3)",
            "Expected a given declaration statement, but got 'given 1..3)'",
        );
    }

    #[test]
    fn malformed_where_message() {
        assert_malformed_message(
            "where x>6",
            "Expected an instantiation condition, but got 'where x>6'",
        );
    }

    #[test]
    fn user_friendly_token_name_article() {
        let without_article = user_friendly_token_name("int_domain", false);
        let with_article = user_friendly_token_name("int_domain", true);
        let colon = user_friendly_token_name("COLON", false);

        assert_eq!(without_article, "Integer Domain");
        assert_eq!(with_article, "an Integer Domain");
        assert_eq!(colon, ":");
    }

    #[test]
    fn missing_domain() {
        let source = "find x:";
        let (tree, _) = get_tree(source).expect("Should parse");

        let mut errors = Vec::new();
        detect_syntactic_errors(source, &tree, &mut errors);
        assert_eq!(errors.len(), 1, "Expected exactly one diagnostic");

        // The missing domain is expected right after the colon.
        let after_colon = tree_sitter::Point { row: 0, column: 7 };
        let expected = RecoverableParseError::new(
            "Missing Domain".to_string(),
            Some(tree_sitter::Range {
                start_byte: 7,
                end_byte: 7,
                start_point: after_colon,
                end_point: after_colon,
            }),
        );
        assert_essence_parse_error_eq(&errors[0], &expected);
    }

    #[test]
    fn line_start_byte_returns_correct_offsets() {
        let bytes = "a\nbc\ndef".as_bytes();
        for (row, expected) in [(0, 0), (1, 2), (2, 5)] {
            assert_eq!(line_start_byte(bytes, row), expected);
        }
    }

    #[test]
    fn point_range_at_returns_correct_zero_length_range() {
        let source = "a\nbc\ndef";
        let range = point_range_at(source, 1, 1);

        assert_eq!(range.start_point, tree_sitter::Point { row: 1, column: 1 });
        assert_eq!(range.end_point, range.start_point);
        assert_eq!(range.start_byte, 3);
        assert_eq!(range.end_byte, 3);
    }

    #[test]
    fn clamp_range_before_line_comment_clamps_end_to_before_dollar() {
        let source = "find x: int(1..3 $comment";
        let whole_len = source.len();
        let mut range = tree_sitter::Range {
            start_byte: 0,
            end_byte: whole_len,
            start_point: tree_sitter::Point { row: 0, column: 0 },
            end_point: tree_sitter::Point {
                row: 0,
                column: whole_len,
            },
        };

        clamp_range_before_line_comment(&mut range, source);

        // Column 16 is the end of `find x: int(1..3`, before the space and `$`.
        assert_eq!(range.end_point.row, 0);
        assert_eq!(range.end_point.column, 16);
        assert_eq!(range.end_byte, 16);
    }

    #[test]
    fn int_keyword_suffix_checks_word_boundary() {
        assert!(is_int_keyword_suffix("find x: int"));
        assert!(!is_int_keyword_suffix("foo"));
        assert!(!is_int_keyword_suffix("mint"));
    }

    #[test]
    fn int_domain_missing_rparen_line_positive_and_negative_cases() {
        // Unclosed `int(` with nothing after the error: accepted.
        let ok = "find x: int(1..2";
        assert!(int_domain_missing_rparen_line(ok, lparen_col(ok), ok.len()));

        // A closing parenthesis is present: rejected.
        let has_rparen = "find x: int(1..2)";
        assert!(!int_domain_missing_rparen_line(
            has_rparen,
            lparen_col(has_rparen),
            has_rparen.len()
        ));

        // Non-whitespace text follows the error span: rejected.
        let trailing = "find x: int(1..2 foo";
        let end = trailing.find(" foo").unwrap();
        assert!(!int_domain_missing_rparen_line(
            trailing,
            lparen_col(trailing),
            end
        ));

        // `print` merely ends in `int`; it is not the keyword: rejected.
        let print_like = "find x: print(1..2";
        assert!(!int_domain_missing_rparen_line(
            print_like,
            lparen_col(print_like),
            print_like.len()
        ));
    }
}