1
use tree_sitter::{Node, Parser, Tree};
2
use tree_sitter_essence::LANGUAGE;
3

            
4
use super::traversal::WalkDFS;
5
use crate::diagnostics::source_map::SourceMap;
6
use crate::errors::RecoverableParseError;
7
use conjure_cp_core::ast::SymbolTablePtr;
8

            
9
/// Context for parsing, containing shared state passed through parser functions.
10
pub struct ParseContext<'a> {
11
    pub source_code: &'a str,
12
    pub root: &'a Node<'a>,
13
    pub symbols: Option<SymbolTablePtr>,
14
    pub errors: &'a mut Vec<RecoverableParseError>,
15
    pub source_map: &'a mut SourceMap,
16
    pub typechecking_context: TypecheckingContext,
17
}
18

            
19
impl<'a> ParseContext<'a> {
20
2817
    pub fn new(
21
2817
        source_code: &'a str,
22
2817
        root: &'a Node<'a>,
23
2817
        symbols: Option<SymbolTablePtr>,
24
2817
        errors: &'a mut Vec<RecoverableParseError>,
25
2817
        source_map: &'a mut SourceMap,
26
2817
    ) -> Self {
27
2817
        Self {
28
2817
            source_code,
29
2817
            root,
30
2817
            symbols,
31
2817
            errors,
32
2817
            source_map,
33
2817
            typechecking_context: TypecheckingContext::Unknown,
34
2817
        }
35
2817
    }
36

            
37
851
    pub fn record_error(&mut self, error: RecoverableParseError) {
38
851
        self.errors.push(error);
39
851
    }
40

            
41
    /// Create a new ParseContext with different symbols but sharing source_code, root, errors, and source_map.
42
68
    pub fn with_new_symbols(&mut self, symbols: Option<SymbolTablePtr>) -> ParseContext<'_> {
43
68
        ParseContext {
44
68
            source_code: self.source_code,
45
68
            root: self.root,
46
68
            symbols,
47
68
            errors: self.errors,
48
68
            source_map: self.source_map,
49
68
            typechecking_context: self.typechecking_context,
50
68
        }
51
68
    }
52
}
53

            
54
/// The kind of expression the parser currently expects.
///
/// Used to detect type mismatches during parsing.
/// The default is [`TypecheckingContext::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TypecheckingContext {
    /// A boolean expression is expected.
    Boolean,
    /// An arithmetic expression is expected.
    Arithmetic,
    /// Context is unknown or flexible
    #[default]
    Unknown,
}
62

            
63
/// Parse the given source code into a syntax tree using tree-sitter.
64
///
65
/// If successful, returns a tuple containing the syntax tree and the raw source code.
66
/// If the source code is not valid Essence, returns None.
67
///
68
/// NOTE: The new source code may be different from the original source code.
69
///       See implementation for details.
70
5102
pub fn get_tree(src: &str) -> Option<(Tree, String)> {
71
5102
    let mut parser = Parser::new();
72
5102
    parser.set_language(&LANGUAGE.into()).unwrap();
73

            
74
5102
    parser.parse(src, None).and_then(|tree| {
75
5102
        let root = tree.root_node();
76
5102
        if root.is_error() {
77
            return None;
78
5102
        }
79

            
80
5102
        let children: Vec<_> = named_children(&root).collect();
81
5102
        let first_child = children.first()?;
82

            
83
        // HACK: Tree-sitter can only parse a complete program from top to bottom, not an individual bit of syntax.
84
        // See: https://github.com/tree-sitter/tree-sitter/issues/711 and linked issues.
85
        // However, we can use a dummy _FRAGMENT_EXPRESSION prefix (which we insert as necessary)
86
        // to trick the parser into accepting an isolated expression.
87
        // This way we can parse an isolated expression and it is only slightly cursed :)
88
5102
        if first_child.is_error() {
89
812
            if src.starts_with("_FRAGMENT_EXPRESSION") {
90
                None
91
            } else {
92
812
                get_tree(&format!("_FRAGMENT_EXPRESSION {src}"))
93
            }
94
        } else {
95
4290
            Some((tree, src.to_string()))
96
        }
97
5102
    })
98
5102
}
99

            
100
/// Get the named children of a node
101
13611
pub fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
102
13611
    (0..node.named_child_count())
103
24248
        .filter_map(|i| u32::try_from(i).ok().and_then(|i| node.named_child(i)))
104
13611
}
105

            
106
2672
pub fn node_is_expression(node: &Node) -> bool {
107
1072
    matches!(
108
2672
        node.kind(),
109
2672
        "bool_expr" | "arithmetic_expr" | "comparison_expr" | "atom"
110
    )
111
2672
}
112

            
113
/// Get all top-level nodes that match the given predicate
114
535
pub fn query_toplevel<'a>(
115
535
    node: &'a Node<'a>,
116
535
    predicate: &'a dyn Fn(&Node<'a>) -> bool,
117
535
) -> impl Iterator<Item = Node<'a>> + 'a {
118
1631
    WalkDFS::with_retract(node, predicate).filter(|n| n.is_named() && predicate(n))
119
535
}
120

            
121
/// Get all meta-variable names in a node
122
3
pub fn get_metavars<'a>(node: &'a Node<'a>, src: &'a str) -> impl Iterator<Item = String> + 'a {
123
48
    query_toplevel(node, &|n| n.kind() == "metavar").filter_map(|child| {
124
3
        child
125
3
            .named_child(0)
126
3
            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
127
3
    })
128
3
}
129

            
130
// Compiled only for test builds, so the glob import below is always used.
#[cfg(test)]
mod test {
    use super::*;

    /// A single `&x` meta-variable in a constraint is reported by name.
    #[test]
    fn test_get_metavars() {
        let src = "such that &x = y";
        // `get_tree` may return a modified source; byte offsets in the tree refer to that
        // returned string, so slice names out of it rather than out of `src`.
        let (tree, parsed_src) = get_tree(src).unwrap();
        let root = tree.root_node();
        let metavars = get_metavars(&root, &parsed_src).collect::<Vec<_>>();
        assert_eq!(metavars, vec!["x"]);
    }
}