1
use crate::expression::{parse_binary_expression, parse_expression};
2
use crate::parser::abstract_literal::parse_abstract;
3
use crate::parser::comprehension::parse_comprehension;
4
use crate::util::named_children;
5
use crate::{EssenceParseError, field, named_child};
6
use conjure_cp_core::ast::{Atom, Expression, Literal, Metadata, Moo, Name, SymbolTablePtr};
7
use tree_sitter::Node;
8
use ustr::Ustr;
9

            
10
6724
pub fn parse_atom(
11
6724
    node: &Node,
12
6724
    source_code: &str,
13
6724
    root: &Node,
14
6724
    symbols_ptr: Option<SymbolTablePtr>,
15
6724
) -> Result<Expression, EssenceParseError> {
16
6724
    match node.kind() {
17
6724
        "atom" | "sub_atom_expr" => parse_atom(&named_child!(node), source_code, root, symbols_ptr),
18
3353
        "metavar" => {
19
403
            let ident = field!(node, "identifier");
20
403
            let name_str = &source_code[ident.start_byte()..ident.end_byte()];
21
403
            Ok(Expression::Metavar(Metadata::new(), Ustr::from(name_str)))
22
        }
23
2950
        "identifier" => parse_variable(node, source_code, symbols_ptr)
24
632
            .map(|var| Expression::Atomic(Metadata::new(), var)),
25
2318
        "from_solution" => {
26
            if root.kind() != "dominance_relation" {
27
                return Err(EssenceParseError::syntax_error(
28
                    "fromSolution only allowed inside dominance relations".to_string(),
29
                    Some(node.range()),
30
                ));
31
            }
32

            
33
            let inner = parse_variable(&field!(node, "variable"), source_code, symbols_ptr)?;
34
            Ok(Expression::FromSolution(Metadata::new(), Moo::new(inner)))
35
        }
36
2318
        "constant" => {
37
1652
            let lit = parse_constant(node, source_code)?;
38
1652
            Ok(Expression::Atomic(Metadata::new(), Atom::Literal(lit)))
39
        }
40
666
        "matrix" | "record" | "tuple" | "set_literal" => {
41
598
            parse_abstract(node, source_code, symbols_ptr)
42
598
                .map(|l| Expression::AbstractLiteral(Metadata::new(), l))
43
        }
44
68
        "flatten" => parse_flatten(node, source_code, root, symbols_ptr),
45
68
        "index_or_slice" => parse_index_or_slice(node, source_code, root, symbols_ptr),
46
        // for now, assume is binary since powerset isn't implemented
47
        // TODO: add powerset support under "set_operation"
48
66
        "set_operation" => parse_binary_expression(node, source_code, root, symbols_ptr),
49
        "comprehension" => parse_comprehension(node, source_code, root, symbols_ptr),
50
        _ => Err(EssenceParseError::syntax_error(
51
            format!("Expected atom, got: {}", node.kind()),
52
            Some(node.range()),
53
        )),
54
    }
55
6724
}
56

            
57
fn parse_flatten(
58
    node: &Node,
59
    source_code: &str,
60
    root: &Node,
61
    symbols_ptr: Option<SymbolTablePtr>,
62
) -> Result<Expression, EssenceParseError> {
63
    let expr_node = field!(node, "expression");
64
    let expr = parse_atom(&expr_node, source_code, root, symbols_ptr)?;
65

            
66
    if node.child_by_field_name("depth").is_some() {
67
        let depth_node = field!(node, "depth");
68
        let depth = parse_int(&depth_node, source_code)?;
69
        let depth_expression =
70
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
71
        Ok(Expression::Flatten(
72
            Metadata::new(),
73
            Some(Moo::new(depth_expression)),
74
            Moo::new(expr),
75
        ))
76
    } else {
77
        Ok(Expression::Flatten(Metadata::new(), None, Moo::new(expr)))
78
    }
79
}
80

            
81
2
fn parse_index_or_slice(
82
2
    node: &Node,
83
2
    source_code: &str,
84
2
    root: &Node,
85
2
    symbols_ptr: Option<SymbolTablePtr>,
86
2
) -> Result<Expression, EssenceParseError> {
87
2
    let collection = parse_atom(
88
2
        &field!(node, "collection"),
89
2
        source_code,
90
2
        root,
91
2
        symbols_ptr.clone(),
92
    )?;
93
2
    let mut indices = Vec::new();
94
4
    for idx_node in named_children(&field!(node, "indices")) {
95
4
        indices.push(parse_index(&idx_node, source_code, symbols_ptr.clone())?);
96
    }
97

            
98
4
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
99
    // TODO: We could check whether the slice/index is safe here
100
2
    if has_null_idx {
101
        // It's a slice
102
1
        Ok(Expression::UnsafeSlice(
103
1
            Metadata::new(),
104
1
            Moo::new(collection),
105
1
            indices,
106
1
        ))
107
    } else {
108
        // It's an index
109
2
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
110
1
        Ok(Expression::UnsafeIndex(
111
1
            Metadata::new(),
112
1
            Moo::new(collection),
113
1
            idx_exprs,
114
1
        ))
115
    }
116
2
}
117

            
118
4
fn parse_index(
119
4
    node: &Node,
120
4
    source_code: &str,
121
4
    symbols_ptr: Option<SymbolTablePtr>,
122
4
) -> Result<Option<Expression>, EssenceParseError> {
123
4
    match node.kind() {
124
4
        "arithmetic_expr" | "atom" => Ok(Some(parse_expression(
125
3
            *node,
126
3
            source_code,
127
3
            node,
128
3
            symbols_ptr,
129
        )?)),
130
1
        "null_index" => Ok(None),
131
        _ => Err(EssenceParseError::syntax_error(
132
            format!("Expected an index, got: '{}'", node.kind()),
133
            Some(node.range()),
134
        )),
135
    }
136
4
}
137

            
138
632
fn parse_variable(
139
632
    node: &Node,
140
632
    source_code: &str,
141
632
    symbols_ptr: Option<SymbolTablePtr>,
142
632
) -> Result<Atom, EssenceParseError> {
143
632
    let raw_name = &source_code[node.start_byte()..node.end_byte()];
144
632
    let name = Name::user(raw_name.trim());
145
632
    if let Some(symbols) = symbols_ptr {
146
632
        if let Some(decl) = symbols.read().lookup(&name) {
147
596
            Ok(Atom::Reference(conjure_cp_core::ast::Reference::new(decl)))
148
        } else {
149
36
            Err(EssenceParseError::syntax_error(
150
36
                format!("Undefined variable: '{}'", raw_name),
151
36
                Some(node.range()),
152
36
            ))
153
        }
154
    } else {
155
        Err(EssenceParseError::syntax_error(
156
            format!(
157
                "Found variable '{raw_name}'; Did you mean to pass a meta-variable '&{raw_name}'?\n\
158
            A symbol table is needed to resolve variable names, but none exists in this context."
159
            ),
160
            Some(node.range()),
161
        ))
162
    }
163
632
}
164

            
165
1652
fn parse_constant(node: &Node, source_code: &str) -> Result<Literal, EssenceParseError> {
166
1652
    let inner = named_child!(node);
167
1652
    let raw_value = &source_code[inner.start_byte()..inner.end_byte()];
168
1652
    match inner.kind() {
169
1652
        "integer" => {
170
1609
            let value = parse_int(&inner, source_code)?;
171
1609
            Ok(Literal::Int(value))
172
        }
173
43
        "TRUE" => Ok(Literal::Bool(true)),
174
19
        "FALSE" => Ok(Literal::Bool(false)),
175
        _ => Err(EssenceParseError::syntax_error(
176
            format!(
177
                "'{}' (kind: '{}') is not a valid constant",
178
                raw_value,
179
                inner.kind()
180
            ),
181
            Some(inner.range()),
182
        )),
183
    }
184
1652
}
185

            
186
1609
fn parse_int(node: &Node, source_code: &str) -> Result<i32, EssenceParseError> {
187
1609
    let raw_value = &source_code[node.start_byte()..node.end_byte()];
188
1609
    raw_value.parse::<i32>().map_err(|_e| {
189
        if raw_value.is_empty() {
190
            EssenceParseError::syntax_error(
191
                "Expected an integer here".to_string(),
192
                Some(node.range()),
193
            )
194
        } else {
195
            EssenceParseError::syntax_error(
196
                format!("'{raw_value}' is not a valid integer"),
197
                Some(node.range()),
198
            )
199
        }
200
    })
201
1609
}