1
#![allow(clippy::legacy_numeric_constants)]
2
use std::collections::BTreeSet;
3

            
4
use tree_sitter::Node;
5

            
6
use super::domain::parse_domain;
7
use super::util::named_children;
8
use crate::errors::EssenceParseError;
9
use crate::expression::parse_expression;
10
use conjure_cp_core::ast::DeclarationPtr;
11
use conjure_cp_core::ast::{Name, SymbolTable, SymbolTablePtr};
12

            
13
/// Parse a letting statement into a SymbolTable containing the declared symbols
14
276
pub fn parse_letting_statement(
15
276
    letting_statement: Node,
16
276
    source_code: &str,
17
276
    existing_symbols_ptr: Option<SymbolTablePtr>,
18
276
) -> Result<SymbolTable, EssenceParseError> {
19
276
    let mut symbol_table = SymbolTable::new();
20

            
21
276
    let mut temp_symbols = BTreeSet::new();
22

            
23
276
    let variable_list = letting_statement
24
276
        .child_by_field_name("variable_list")
25
276
        .expect("No variable list found");
26
276
    for variable in named_children(&variable_list) {
27
276
        let variable_name = &source_code[variable.start_byte()..variable.end_byte()];
28
276
        temp_symbols.insert(variable_name);
29
276
    }
30

            
31
276
    let expr_or_domain = letting_statement
32
276
        .child_by_field_name("expr_or_domain")
33
276
        .expect("No domain or expression found for letting statement");
34
276
    match expr_or_domain.kind() {
35
276
        "bool_expr" | "arithmetic_expr" | "atom" => {
36
276
            for name in temp_symbols {
37
276
                symbol_table.insert(DeclarationPtr::new_value_letting(
38
276
                    Name::user(name),
39
276
                    parse_expression(
40
276
                        expr_or_domain,
41
276
                        source_code,
42
276
                        &letting_statement,
43
276
                        existing_symbols_ptr.clone(),
44
                    )?,
45
                ));
46
            }
47
        }
48
        "domain" => {
49
            for name in temp_symbols {
50
                let domain =
51
                    parse_domain(expr_or_domain, source_code, existing_symbols_ptr.clone())?;
52

            
53
                // If it's a record domain, add the field names to the symbol table
54
                if let Some(entries) = domain.as_record() {
55
                    for entry in entries {
56
                        // Add each field name as a record field declaration
57
                        symbol_table.insert(DeclarationPtr::new_record_field(entry.clone()));
58
                    }
59
                }
60

            
61
                symbol_table.insert(DeclarationPtr::new_domain_letting(Name::user(name), domain));
62
            }
63
        }
64
        _ => {
65
            return Err(EssenceParseError::syntax_error(
66
                format!(
67
                    "Expected letting expression, got '{}'",
68
                    expr_or_domain.kind()
69
                ),
70
                Some(expr_or_domain.range()),
71
            ));
72
        }
73
    }
74

            
75
276
    Ok(symbol_table)
76
276
}