conjure_cp_essence_parser/parser/
letting.rs1#![allow(clippy::legacy_numeric_constants)]
2use std::collections::BTreeSet;
3
4use tree_sitter::Node;
5
6use super::domain::parse_domain;
7use super::util::named_children;
8use crate::errors::{FatalParseError, RecoverableParseError};
9use crate::expression::parse_expression;
10use conjure_cp_core::ast::DeclarationPtr;
11use conjure_cp_core::ast::{Name, SymbolTable, SymbolTablePtr};
12
13pub fn parse_letting_statement(
15 letting_statement: Node,
16 source_code: &str,
17 existing_symbols_ptr: Option<SymbolTablePtr>,
18 errors: &mut Vec<RecoverableParseError>,
19) -> Result<SymbolTable, FatalParseError> {
20 let mut symbol_table = SymbolTable::new();
21
22 let mut temp_symbols = BTreeSet::new();
23
24 let variable_list = letting_statement
25 .child_by_field_name("variable_list")
26 .expect("No variable list found");
27 for variable in named_children(&variable_list) {
28 let variable_name = &source_code[variable.start_byte()..variable.end_byte()];
29 temp_symbols.insert(variable_name);
30 }
31
32 let expr_or_domain = letting_statement
33 .child_by_field_name("expr_or_domain")
34 .expect("No domain or expression found for letting statement");
35 match expr_or_domain.kind() {
36 "bool_expr" | "arithmetic_expr" | "atom" => {
37 for name in temp_symbols {
38 symbol_table.insert(DeclarationPtr::new_value_letting(
39 Name::user(name),
40 parse_expression(
41 expr_or_domain,
42 source_code,
43 &letting_statement,
44 existing_symbols_ptr.clone(),
45 errors,
46 )?,
47 ));
48 }
49 }
50 "domain" => {
51 for name in temp_symbols {
52 let domain = parse_domain(
53 expr_or_domain,
54 source_code,
55 existing_symbols_ptr.clone(),
56 errors,
57 )?;
58
59 if let Some(entries) = domain.as_record() {
61 for entry in entries {
62 symbol_table.insert(DeclarationPtr::new_record_field(entry.clone()));
64 }
65 }
66
67 symbol_table.insert(DeclarationPtr::new_domain_letting(Name::user(name), domain));
68 }
69 }
70 _ => {
71 return Err(FatalParseError::syntax_error(
72 format!(
73 "Expected letting expression, got '{}'",
74 expr_or_domain.kind()
75 ),
76 Some(expr_or_domain.range()),
77 ));
78 }
79 }
80
81 Ok(symbol_table)
82}