1
use crate::errors::{FatalParseError, RecoverableParseError};
2
use crate::expression::{parse_binary_expression, parse_expression};
3
use crate::parser::abstract_literal::parse_abstract;
4
use crate::parser::comprehension::parse_comprehension;
5
use crate::util::named_children;
6
use crate::{field, named_child};
7
use conjure_cp_core::ast::{Atom, Expression, Literal, Metadata, Moo, Name, SymbolTablePtr};
8
use tree_sitter::Node;
9
use ustr::Ustr;
10

            
11
6690
pub fn parse_atom(
12
6690
    node: &Node,
13
6690
    source_code: &str,
14
13932
    root: &Node,
15
13932
    symbols_ptr: Option<SymbolTablePtr>,
16
13932
    errors: &mut Vec<RecoverableParseError>,
17
13932
) -> Result<Expression, FatalParseError> {
18
13932
    match node.kind() {
19
13932
        "atom" | "sub_atom_expr" => {
20
6962
            parse_atom(&named_child!(node), source_code, root, symbols_ptr, errors)
21
563
        }
22
3899
        "metavar" => {
23
966
            let ident = field!(node, "identifier");
24
966
            let name_str = &source_code[ident.start_byte()..ident.end_byte()];
25
966
            Ok(Expression::Metavar(Metadata::new(), Ustr::from(name_str)))
26
563
        }
27
2933
        "identifier" => parse_variable(node, source_code, symbols_ptr, errors)
28
3662
            .map(|var| Expression::Atomic(Metadata::new(), var)),
29
2977
        "from_solution" => {
30
23
            if root.kind() != "dominance_relation" {
31
                return Err(FatalParseError::syntax_error(
32
638
                    "fromSolution only allowed inside dominance relations".to_string(),
33
                    Some(node.range()),
34
2384
                ));
35
            }
36

            
37
            let inner =
38
                parse_variable(&field!(node, "variable"), source_code, symbols_ptr, errors)?;
39
            Ok(Expression::FromSolution(Metadata::new(), Moo::new(inner)))
40
        }
41
2316
        "constant" => {
42
1650
            let lit = parse_constant(node, source_code, errors)?;
43
1650
            Ok(Expression::Atomic(Metadata::new(), Atom::Literal(lit)))
44
        }
45
666
        "matrix" | "record" | "tuple" | "set_literal" => {
46
598
            parse_abstract(node, source_code, symbols_ptr, errors)
47
598
                .map(|l| Expression::AbstractLiteral(Metadata::new(), l))
48
        }
49
68
        "flatten" => parse_flatten(node, source_code, root, symbols_ptr, errors),
50
68
        "index_or_slice" => parse_index_or_slice(node, source_code, root, symbols_ptr, errors),
51
        // for now, assume is binary since powerset isn't implemented
52
        // TODO: add powerset support under "set_operation"
53
1784
        "set_operation" => parse_binary_expression(node, source_code, root, symbols_ptr, errors),
54
1718
        "comprehension" => parse_comprehension(node, source_code, root, symbols_ptr, errors),
55
1718
        _ => Err(FatalParseError::syntax_error(
56
1718
            format!("Expected atom, got: {}", node.kind()),
57
            Some(node.range()),
58
666
        )),
59
598
    }
60
6690
}
61

            
62
598
fn parse_flatten(
63
    node: &Node,
64
68
    source_code: &str,
65
68
    root: &Node,
66
    symbols_ptr: Option<SymbolTablePtr>,
67
    errors: &mut Vec<RecoverableParseError>,
68
66
) -> Result<Expression, FatalParseError> {
69
    let expr_node = field!(node, "expression");
70
    let expr = parse_atom(&expr_node, source_code, root, symbols_ptr, errors)?;
71

            
72
    if node.child_by_field_name("depth").is_some() {
73
        let depth_node = field!(node, "depth");
74
        let depth = parse_int(&depth_node, source_code, errors)?;
75
7242
        let depth_expression =
76
            Expression::Atomic(Metadata::new(), Atom::Literal(Literal::Int(depth)));
77
        Ok(Expression::Flatten(
78
            Metadata::new(),
79
            Some(Moo::new(depth_expression)),
80
            Moo::new(expr),
81
        ))
82
    } else {
83
        Ok(Expression::Flatten(Metadata::new(), None, Moo::new(expr)))
84
    }
85
}
86

            
87
2
fn parse_index_or_slice(
88
2
    node: &Node,
89
2
    source_code: &str,
90
2
    root: &Node,
91
2
    symbols_ptr: Option<SymbolTablePtr>,
92
2
    errors: &mut Vec<RecoverableParseError>,
93
2
) -> Result<Expression, FatalParseError> {
94
2
    let collection = parse_atom(
95
2
        &field!(node, "collection"),
96
2
        source_code,
97
2
        root,
98
2
        symbols_ptr.clone(),
99
2
        errors,
100
    )?;
101
2
    let mut indices = Vec::new();
102
4
    for idx_node in named_children(&field!(node, "indices")) {
103
4
        indices.push(parse_index(
104
4
            &idx_node,
105
6
            source_code,
106
6
            symbols_ptr.clone(),
107
6
            errors,
108
2
        )?);
109
2
    }
110

            
111
4
    let has_null_idx = indices.iter().any(|idx| idx.is_none());
112
    // TODO: We could check whether the slice/index is safe here
113
6
    if has_null_idx {
114
        // It's a slice
115
1
        Ok(Expression::UnsafeSlice(
116
1
            Metadata::new(),
117
5
            Moo::new(collection),
118
1
            indices,
119
3
        ))
120
    } else {
121
        // It's an index
122
3
        let idx_exprs: Vec<Expression> = indices.into_iter().map(|idx| idx.unwrap()).collect();
123
2
        Ok(Expression::UnsafeIndex(
124
2
            Metadata::new(),
125
2
            Moo::new(collection),
126
1
            idx_exprs,
127
1
        ))
128
2
    }
129
3
}
130
1

            
131
5
fn parse_index(
132
5
    node: &Node,
133
5
    source_code: &str,
134
4
    symbols_ptr: Option<SymbolTablePtr>,
135
6
    errors: &mut Vec<RecoverableParseError>,
136
4
) -> Result<Option<Expression>, FatalParseError> {
137
8
    match node.kind() {
138
8
        "arithmetic_expr" | "atom" => Ok(Some(parse_expression(
139
7
            *node,
140
6
            source_code,
141
3
            node,
142
3
            symbols_ptr,
143
6
            errors,
144
        )?)),
145
2
        "null_index" => Ok(None),
146
        _ => Err(FatalParseError::syntax_error(
147
            format!("Expected an index, got: '{}'", node.kind()),
148
            Some(node.range()),
149
        )),
150
    }
151
8
}
152

            
153
1278
fn parse_variable(
154
1278
    node: &Node,
155
1278
    source_code: &str,
156
1278
    symbols_ptr: Option<SymbolTablePtr>,
157
1278
    _errors: &mut Vec<RecoverableParseError>,
158
1255
) -> Result<Atom, FatalParseError> {
159
1255
    let raw_name = &source_code[node.start_byte()..node.end_byte()];
160
1255
    let name = Name::user(raw_name.trim());
161
1255
    if let Some(symbols) = symbols_ptr {
162
1255
        if let Some(decl) = symbols.read().lookup(&name) {
163
594
            Ok(Atom::Reference(conjure_cp_core::ast::Reference::new(decl)))
164
638
        } else {
165
661
            Err(FatalParseError::syntax_error(
166
661
                format!("Undefined variable: '{}'", raw_name),
167
661
                Some(node.range()),
168
23
            ))
169
23
        }
170
23
    } else {
171
23
        Err(FatalParseError::syntax_error(
172
            format!(
173
23
                "Found variable '{raw_name}'; Did you mean to pass a meta-variable '&{raw_name}'?\n\
174
            A symbol table is needed to resolve variable names, but none exists in this context."
175
            ),
176
            Some(node.range()),
177
        ))
178
    }
179
617
}
180

            
181
2311
fn parse_constant(
182
1650
    node: &Node,
183
3368
    source_code: &str,
184
3368
    errors: &mut Vec<RecoverableParseError>,
185
3368
) -> Result<Literal, FatalParseError> {
186
3368
    let inner = named_child!(node);
187
3368
    let raw_value = &source_code[inner.start_byte()..inner.end_byte()];
188
3321
    match inner.kind() {
189
3321
        "integer" => {
190
1607
            let value = parse_int(&inner, source_code, errors)?;
191
1654
            Ok(Literal::Int(value))
192
27
        }
193
70
        "TRUE" => Ok(Literal::Bool(true)),
194
46
        "FALSE" => Ok(Literal::Bool(false)),
195
27
        _ => Err(FatalParseError::syntax_error(
196
27
            format!(
197
27
                "'{}' (kind: '{}') is not a valid constant",
198
27
                raw_value,
199
27
                inner.kind()
200
            ),
201
20
            Some(inner.range()),
202
20
        )),
203
20
    }
204
1670
}
205
20

            
206
1627
fn parse_int(
207
1627
    node: &Node,
208
1627
    source_code: &str,
209
1627
    _errors: &mut Vec<RecoverableParseError>,
210
1607
) -> Result<i32, FatalParseError> {
211
1607
    let raw_value = &source_code[node.start_byte()..node.end_byte()];
212
1607
    raw_value.parse::<i32>().map_err(|_e| {
213
        if raw_value.is_empty() {
214
            FatalParseError::syntax_error(
215
                "Expected an integer here".to_string(),
216
                Some(node.range()),
217
            )
218
        } else {
219
            FatalParseError::syntax_error(
220
1718
                format!("'{raw_value}' is not a valid integer"),
221
                Some(node.range()),
222
1671
            )
223
1671
        }
224
1671
    })
225
1607
}