Skip to main content

conjure_cp_essence_parser/parser/
util.rs

1use tree_sitter::{Node, Parser, Tree};
2use tree_sitter_essence::LANGUAGE;
3
4use super::traversal::WalkDFS;
5use crate::diagnostics::source_map::SourceMap;
6use crate::errors::RecoverableParseError;
7use conjure_cp_core::ast::SymbolTablePtr;
8
9/// Context for parsing, containing shared state passed through parser functions.
10pub struct ParseContext<'a> {
11    pub source_code: &'a str,
12    pub root: &'a Node<'a>,
13    pub symbols: Option<SymbolTablePtr>,
14    pub errors: &'a mut Vec<RecoverableParseError>,
15    pub source_map: &'a mut SourceMap,
16    pub typechecking_context: TypecheckingContext,
17}
18
19impl<'a> ParseContext<'a> {
20    pub fn new(
21        source_code: &'a str,
22        root: &'a Node<'a>,
23        symbols: Option<SymbolTablePtr>,
24        errors: &'a mut Vec<RecoverableParseError>,
25        source_map: &'a mut SourceMap,
26    ) -> Self {
27        Self {
28            source_code,
29            root,
30            symbols,
31            errors,
32            source_map,
33            typechecking_context: TypecheckingContext::Unknown,
34        }
35    }
36
37    pub fn record_error(&mut self, error: RecoverableParseError) {
38        self.errors.push(error);
39    }
40
41    /// Create a new ParseContext with different symbols but sharing source_code, root, errors, and source_map.
42    pub fn with_new_symbols(&mut self, symbols: Option<SymbolTablePtr>) -> ParseContext<'_> {
43        ParseContext {
44            source_code: self.source_code,
45            root: self.root,
46            symbols,
47            errors: self.errors,
48            source_map: self.source_map,
49            typechecking_context: self.typechecking_context,
50        }
51    }
52}
53
/// Typechecking context tracked while parsing; used to detect type mismatches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypecheckingContext {
    /// Presumably: a boolean expression is expected here (TODO confirm against callers).
    Boolean,
    /// Presumably: an arithmetic expression is expected here (TODO confirm against callers).
    Arithmetic,
    /// Context is unknown or flexible
    Unknown,
}
62
63/// Parse the given source code into a syntax tree using tree-sitter.
64///
65/// If successful, returns a tuple containing the syntax tree and the raw source code.
66/// If the source code is not valid Essence, returns None.
67///
68/// NOTE: The new source code may be different from the original source code.
69///       See implementation for details.
70pub fn get_tree(src: &str) -> Option<(Tree, String)> {
71    let mut parser = Parser::new();
72    parser.set_language(&LANGUAGE.into()).unwrap();
73
74    parser.parse(src, None).and_then(|tree| {
75        let root = tree.root_node();
76        if root.is_error() {
77            return None;
78        }
79
80        let children: Vec<_> = named_children(&root).collect();
81        let first_child = children.first()?;
82
83        // HACK: Tree-sitter can only parse a complete program from top to bottom, not an individual bit of syntax.
84        // See: https://github.com/tree-sitter/tree-sitter/issues/711 and linked issues.
85        // However, we can use a dummy _FRAGMENT_EXPRESSION prefix (which we insert as necessary)
86        // to trick the parser into accepting an isolated expression.
87        // This way we can parse an isolated expression and it is only slightly cursed :)
88        if first_child.is_error() {
89            if src.starts_with("_FRAGMENT_EXPRESSION") {
90                None
91            } else {
92                get_tree(&format!("_FRAGMENT_EXPRESSION {src}"))
93            }
94        } else {
95            Some((tree, src.to_string()))
96        }
97    })
98}
99
100/// Get the named children of a node
101pub fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
102    (0..node.named_child_count())
103        .filter_map(|i| u32::try_from(i).ok().and_then(|i| node.named_child(i)))
104}
105
106pub fn node_is_expression(node: &Node) -> bool {
107    matches!(
108        node.kind(),
109        "bool_expr" | "arithmetic_expr" | "comparison_expr" | "atom"
110    )
111}
112
113/// Get all top-level nodes that match the given predicate
114pub fn query_toplevel<'a>(
115    node: &'a Node<'a>,
116    predicate: &'a dyn Fn(&Node<'a>) -> bool,
117) -> impl Iterator<Item = Node<'a>> + 'a {
118    WalkDFS::with_retract(node, predicate).filter(|n| n.is_named() && predicate(n))
119}
120
121/// Get all meta-variable names in a node
122pub fn get_metavars<'a>(node: &'a Node<'a>, src: &'a str) -> impl Iterator<Item = String> + 'a {
123    query_toplevel(node, &|n| n.kind() == "metavar").filter_map(|child| {
124        child
125            .named_child(0)
126            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
127    })
128}
129
// Compiled only for test builds, so the glob import below is always used and
// the module adds nothing to regular builds.
#[cfg(test)]
mod test {
    use super::*;

    /// A `&x` meta-variable inside a constraint is reported by its bare name.
    #[test]
    fn test_get_metavars() {
        let src = "such that &x = y";
        // Use the source returned by `get_tree`: it is the text the tree's
        // byte offsets refer to, and may differ from `src`.
        let (tree, parsed_src) = get_tree(src).expect("test input should parse");
        let root = tree.root_node();
        let metavars = get_metavars(&root, &parsed_src).collect::<Vec<_>>();
        assert_eq!(metavars, vec!["x"]);
    }
}