1
#![allow(clippy::legacy_numeric_constants)]
2
use std::cell::RefCell;
3
use std::collections::BTreeSet;
4
use std::rc::Rc;
5

            
6
use tree_sitter::Node;
7

            
8
use super::domain::parse_domain;
9
use super::util::named_children;
10
use crate::errors::EssenceParseError;
11
use crate::expression::parse_expression;
12
use conjure_cp_core::ast::DeclarationPtr;
13
use conjure_cp_core::ast::{Name, SymbolTable};
14

            
15
/// Parse a letting statement into a SymbolTable containing the declared symbols
16
pub fn parse_letting_statement(
17
    letting_statement: Node,
18
    source_code: &str,
19
    existing_symbols_ptr: Option<Rc<RefCell<SymbolTable>>>,
20
) -> Result<SymbolTable, EssenceParseError> {
21
    let mut symbol_table = SymbolTable::new();
22

            
23
    let mut temp_symbols = BTreeSet::new();
24

            
25
    let variable_list = letting_statement
26
        .child_by_field_name("variable_list")
27
        .expect("No variable list found");
28
    for variable in named_children(&variable_list) {
29
        let variable_name = &source_code[variable.start_byte()..variable.end_byte()];
30
        temp_symbols.insert(variable_name);
31
    }
32

            
33
    let expr_or_domain = letting_statement
34
        .child_by_field_name("expr_or_domain")
35
        .expect("No domain or expression found for letting statement");
36
    match expr_or_domain.kind() {
37
        "bool_expr" | "arithmetic_expr" => {
38
            for name in temp_symbols {
39
                symbol_table.insert(DeclarationPtr::new_value_letting(
40
                    Name::user(name),
41
                    parse_expression(
42
                        expr_or_domain,
43
                        source_code,
44
                        &letting_statement,
45
                        existing_symbols_ptr.clone(),
46
                    )?,
47
                ));
48
            }
49
        }
50
        "domain" => {
51
            for name in temp_symbols {
52
                let domain =
53
                    parse_domain(expr_or_domain, source_code, existing_symbols_ptr.clone())?;
54

            
55
                // If it's a record domain, add the field names to the symbol table
56
                if let Some(entries) = domain.as_record() {
57
                    for entry in entries {
58
                        // Add each field name as a record field declaration
59
                        symbol_table.insert(DeclarationPtr::new_record_field(entry.clone()));
60
                    }
61
                }
62

            
63
                symbol_table.insert(DeclarationPtr::new_domain_letting(Name::user(name), domain));
64
            }
65
        }
66
        _ => {
67
            return Err(EssenceParseError::syntax_error(
68
                format!(
69
                    "Expected letting expression, got '{}'",
70
                    expr_or_domain.kind()
71
                ),
72
                Some(expr_or_domain.range()),
73
            ));
74
        }
75
    }
76

            
77
    Ok(symbol_table)
78
}