1
use tree_sitter::{Node, Parser, Tree};
2
use tree_sitter_essence::LANGUAGE;
3

            
4
use super::traversal::WalkDFS;
5
use crate::diagnostics::source_map::SourceMap;
6
use crate::errors::RecoverableParseError;
7
use conjure_cp_core::ast::SymbolTablePtr;
8

            
9
/// Context for parsing, containing shared state passed through parser functions.
10
pub struct ParseContext<'a> {
11
    pub source_code: &'a str,
12
    pub root: &'a Node<'a>,
13
    pub symbols: Option<SymbolTablePtr>,
14
    pub errors: &'a mut Vec<RecoverableParseError>,
15
    pub source_map: &'a mut SourceMap,
16
    pub typechecking_context: TypecheckingContext,
17
}
18

            
19
impl<'a> ParseContext<'a> {
20
1116
    pub fn new(
21
1116
        source_code: &'a str,
22
1116
        root: &'a Node<'a>,
23
1116
        symbols: Option<SymbolTablePtr>,
24
1116
        errors: &'a mut Vec<RecoverableParseError>,
25
1116
        source_map: &'a mut SourceMap,
26
1116
    ) -> Self {
27
1116
        Self {
28
1116
            source_code,
29
1116
            root,
30
1116
            symbols,
31
1116
            errors,
32
1116
            source_map,
33
1116
            typechecking_context: TypecheckingContext::Unknown,
34
1116
        }
35
1116
    }
36

            
37
243
    pub fn record_error(&mut self, error: RecoverableParseError) {
38
243
        self.errors.push(error);
39
243
    }
40

            
41
    /// Create a new ParseContext with different symbols but sharing source_code, root, errors, and source_map.
42
22
    pub fn with_new_symbols(&mut self, symbols: Option<SymbolTablePtr>) -> ParseContext<'_> {
43
22
        ParseContext {
44
22
            source_code: self.source_code,
45
22
            root: self.root,
46
22
            symbols,
47
22
            errors: self.errors,
48
22
            source_map: self.source_map,
49
22
            typechecking_context: self.typechecking_context,
50
22
        }
51
22
    }
52
}
53

            
54
/// The kind of expression the parser currently expects.
///
/// Used to detect type mismatches during parsing. The default is
/// [`TypecheckingContext::Unknown`], which is also what a freshly created
/// `ParseContext` starts with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TypecheckingContext {
    /// A boolean expression is expected.
    Boolean,
    /// An arithmetic expression is expected.
    Arithmetic,
    /// Context is unknown or flexible
    #[default]
    Unknown,
}
62

            
63
/// Parse the given source code into a syntax tree using tree-sitter.
64
///
65
/// If successful, returns a tuple containing the syntax tree and the raw source code.
66
/// If the source code is not valid Essence, returns None.
67
///
68
/// NOTE: The new source code may be different from the original source code.
69
///       See implementation for details.
70
2085
pub fn get_tree(src: &str) -> Option<(Tree, String)> {
71
2085
    let mut parser = Parser::new();
72
2085
    parser.set_language(&LANGUAGE.into()).unwrap();
73

            
74
2085
    parser.parse(src, None).and_then(|tree| {
75
2085
        let root = tree.root_node();
76
2085
        if root.is_error() {
77
            return None;
78
2085
        }
79

            
80
2085
        let children: Vec<_> = named_children(&root).collect();
81
2085
        let first_child = children.first()?;
82

            
83
        // HACK: Tree-sitter can only parse a complete program from top to bottom, not an individual bit of syntax.
84
        // See: https://github.com/tree-sitter/tree-sitter/issues/711 and linked issues.
85
        // However, we can use a dummy _FRAGMENT_EXPRESSION prefix (which we insert as necessary)
86
        // to trick the parser into accepting an isolated expression.
87
        // This way we can parse an isolated expression and it is only slightly cursed :)
88
2085
        if first_child.is_error() {
89
599
            if src.starts_with("_FRAGMENT_EXPRESSION") {
90
                None
91
            } else {
92
599
                get_tree(&format!("_FRAGMENT_EXPRESSION {src}"))
93
            }
94
        } else {
95
1486
            Some((tree, src.to_string()))
96
        }
97
2085
    })
98
2085
}
99

            
100
/// Get the named children of a node
101
4277
pub fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
102
4277
    (0..node.named_child_count())
103
6958
        .filter_map(|i| u32::try_from(i).ok().and_then(|i| node.named_child(i)))
104
4277
}
105

            
106
2602
pub fn node_is_expression(node: &Node) -> bool {
107
1044
    matches!(
108
2602
        node.kind(),
109
2602
        "bool_expr" | "arithmetic_expr" | "comparison_expr" | "atom"
110
    )
111
2602
}
112

            
113
/// Get all top-level nodes that match the given predicate
114
519
pub fn query_toplevel<'a>(
115
519
    node: &'a Node<'a>,
116
519
    predicate: &'a dyn Fn(&Node<'a>) -> bool,
117
519
) -> impl Iterator<Item = Node<'a>> + 'a {
118
1571
    WalkDFS::with_retract(node, predicate).filter(|n| n.is_named() && predicate(n))
119
519
}
120

            
121
/// Get all meta-variable names in a node
122
1
pub fn get_metavars<'a>(node: &'a Node<'a>, src: &'a str) -> impl Iterator<Item = String> + 'a {
123
16
    query_toplevel(node, &|n| n.kind() == "metavar").filter_map(|child| {
124
1
        child
125
1
            .named_child(0)
126
1
            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
127
1
    })
128
1
}
129

            
130
// Compiled only for test builds, so the glob import below is always used.
#[cfg(test)]
mod test {
    use super::*;

    /// A single meta-variable `&x` in a constraint is reported by name.
    #[test]
    fn test_get_metavars() {
        // `get_tree` may hand back a different source than it was given (it can
        // prepend a fragment prefix), and the tree's byte offsets refer to that
        // returned source — so names must be sliced out of it, not the input.
        let (tree, parsed_src) = get_tree("such that &x = y").unwrap();
        let root = tree.root_node();
        let metavars = get_metavars(&root, &parsed_src).collect::<Vec<_>>();
        assert_eq!(metavars, vec!["x"]);
    }
}