1
use tree_sitter::{Node, Parser, Tree};
2
use tree_sitter_essence::LANGUAGE;
3

            
4
use super::traversal::WalkDFS;
5
use crate::diagnostics::source_map::SourceMap;
6
use crate::errors::RecoverableParseError;
7
use conjure_cp_core::ast::SymbolTablePtr;
8

            
9
/// Context for parsing, containing shared state passed through parser functions.
10
pub struct ParseContext<'a> {
11
    pub source_code: &'a str,
12
    pub root: &'a Node<'a>,
13
    pub symbols: Option<SymbolTablePtr>,
14
    pub errors: &'a mut Vec<RecoverableParseError>,
15
    pub source_map: &'a mut SourceMap,
16
    pub typechecking_context: TypecheckingContext,
17
}
18

            
19
impl<'a> ParseContext<'a> {
20
1411
    pub fn new(
21
1411
        source_code: &'a str,
22
1411
        root: &'a Node<'a>,
23
1411
        symbols: Option<SymbolTablePtr>,
24
1411
        errors: &'a mut Vec<RecoverableParseError>,
25
1411
        source_map: &'a mut SourceMap,
26
1411
    ) -> Self {
27
1411
        Self {
28
1411
            source_code,
29
1411
            root,
30
1411
            symbols,
31
1411
            errors,
32
1411
            source_map,
33
1411
            typechecking_context: TypecheckingContext::Unknown,
34
1411
        }
35
1411
    }
36

            
37
310
    pub fn record_error(&mut self, error: RecoverableParseError) {
38
310
        self.errors.push(error);
39
310
    }
40

            
41
    /// Create a new ParseContext with different symbols but sharing source_code, root, errors, and source_map.
42
44
    pub fn with_new_symbols(&mut self, symbols: Option<SymbolTablePtr>) -> ParseContext<'_> {
43
44
        ParseContext {
44
44
            source_code: self.source_code,
45
44
            root: self.root,
46
44
            symbols,
47
44
            errors: self.errors,
48
44
            source_map: self.source_map,
49
44
            typechecking_context: self.typechecking_context,
50
44
        }
51
44
    }
52
}
53

            
54
/// The kind of expression expected at the current point of parsing.
///
/// Used to detect type mismatches during parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypecheckingContext {
    /// A boolean-valued expression is (presumably) expected here.
    Boolean,
    /// An arithmetic (numeric) expression is (presumably) expected here.
    Arithmetic,
    /// Context is unknown or flexible.
    Unknown,
}
62

            
63
/// Parse the given source code into a syntax tree using tree-sitter.
64
///
65
/// If successful, returns a tuple containing the syntax tree and the raw source code.
66
/// If the source code is not valid Essence, returns None.
67
///
68
/// NOTE: The new source code may be different from the original source code.
69
///       See implementation for details.
70
2681
pub fn get_tree(src: &str) -> Option<(Tree, String)> {
71
2681
    let mut parser = Parser::new();
72
2681
    parser.set_language(&LANGUAGE.into()).unwrap();
73

            
74
2681
    parser.parse(src, None).and_then(|tree| {
75
2681
        let root = tree.root_node();
76
2681
        if root.is_error() {
77
            return None;
78
2681
        }
79

            
80
2681
        let children: Vec<_> = named_children(&root).collect();
81
2681
        let first_child = children.first()?;
82

            
83
        // HACK: Tree-sitter can only parse a complete program from top to bottom, not an individual bit of syntax.
84
        // See: https://github.com/tree-sitter/tree-sitter/issues/711 and linked issues.
85
        // However, we can use a dummy _FRAGMENT_EXPRESSION prefix (which we insert as necessary)
86
        // to trick the parser into accepting an isolated expression.
87
        // This way we can parse an isolated expression and it is only slightly cursed :)
88
2681
        if first_child.is_error() {
89
569
            if src.starts_with("_FRAGMENT_EXPRESSION") {
90
                None
91
            } else {
92
569
                get_tree(&format!("_FRAGMENT_EXPRESSION {src}"))
93
            }
94
        } else {
95
2112
            Some((tree, src.to_string()))
96
        }
97
2681
    })
98
2681
}
99

            
100
/// Get the named children of a node
101
6647
pub fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
102
6647
    (0..node.named_child_count())
103
11652
        .filter_map(|i| u32::try_from(i).ok().and_then(|i| node.named_child(i)))
104
6647
}
105

            
106
2154
pub fn node_is_expression(node: &Node) -> bool {
107
864
    matches!(
108
2154
        node.kind(),
109
2154
        "bool_expr" | "arithmetic_expr" | "comparison_expr" | "atom"
110
    )
111
2154
}
112

            
113
/// Get all top-level nodes that match the given predicate
114
431
pub fn query_toplevel<'a>(
115
431
    node: &'a Node<'a>,
116
431
    predicate: &'a dyn Fn(&Node<'a>) -> bool,
117
431
) -> impl Iterator<Item = Node<'a>> + 'a {
118
1311
    WalkDFS::with_retract(node, predicate).filter(|n| n.is_named() && predicate(n))
119
431
}
120

            
121
/// Get all meta-variable names in a node
122
2
pub fn get_metavars<'a>(node: &'a Node<'a>, src: &'a str) -> impl Iterator<Item = String> + 'a {
123
32
    query_toplevel(node, &|n| n.kind() == "metavar").filter_map(|child| {
124
2
        child
125
2
            .named_child(0)
126
2
            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
127
2
    })
128
2
}
129

            
130
// Compiled only for `cargo test`, so the test code and its imports stay out of normal builds.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_get_metavars() {
        let src = "such that &x = y";
        // `get_tree` may prefix the source, so slice against the text it actually parsed.
        let (tree, parsed_src) = get_tree(src).expect("test source should parse");
        let root = tree.root_node();
        let metavars = get_metavars(&root, &parsed_src).collect::<Vec<_>>();
        assert_eq!(metavars, vec!["x"]);
    }
}