Skip to main content

conjure_cp_essence_parser/parser/
util.rs

1use std::collections::BTreeMap;
2use std::sync::{Mutex, OnceLock};
3
4use tree_sitter::{Node, Parser, Tree};
5use tree_sitter_essence::LANGUAGE;
6
7use super::traversal::WalkDFS;
8use crate::diagnostics::diagnostics_api::SymbolKind;
9use crate::diagnostics::source_map::{HoverInfo, SourceMap, SpanId, span_with_hover};
10use crate::errors::RecoverableParseError;
11use conjure_cp_core::ast::{Name, SymbolTablePtr};
12
13/// Context for parsing, containing shared state passed through parser functions.
14pub struct ParseContext<'a> {
15    pub source_code: &'a str,
16    pub root: &'a Node<'a>,
17    pub symbols: Option<SymbolTablePtr>,
18    pub errors: &'a mut Vec<RecoverableParseError>,
19    pub source_map: &'a mut SourceMap,
20    pub decl_spans: &'a mut BTreeMap<Name, SpanId>,
21    /// What type the current expression/literal itself should be
22    pub typechecking_context: TypecheckingContext,
23    /// What type the elements within a collection should be
24    pub inner_typechecking_context: TypecheckingContext,
25}
26
27impl<'a> ParseContext<'a> {
28    pub fn new(
29        source_code: &'a str,
30        root: &'a Node<'a>,
31        symbols: Option<SymbolTablePtr>,
32        errors: &'a mut Vec<RecoverableParseError>,
33        source_map: &'a mut SourceMap,
34        decl_spans: &'a mut BTreeMap<Name, SpanId>,
35    ) -> Self {
36        Self {
37            source_code,
38            root,
39            symbols,
40            errors,
41            source_map,
42            decl_spans,
43            typechecking_context: TypecheckingContext::Unknown,
44            inner_typechecking_context: TypecheckingContext::Unknown,
45        }
46    }
47
48    pub fn record_error(&mut self, error: RecoverableParseError) {
49        self.errors.push(error);
50    }
51
52    /// Create a new ParseContext with different symbols but sharing source_code, root, errors, and source_map.
53    pub fn with_new_symbols(&mut self, symbols: Option<SymbolTablePtr>) -> ParseContext<'_> {
54        ParseContext {
55            source_code: self.source_code,
56            root: self.root,
57            symbols,
58            errors: self.errors,
59            source_map: self.source_map,
60            decl_spans: self.decl_spans,
61            typechecking_context: self.typechecking_context,
62            inner_typechecking_context: self.inner_typechecking_context,
63        }
64    }
65
66    pub fn save_decl_span(&mut self, name: Name, span_id: SpanId) {
67        self.decl_spans.insert(name, span_id);
68    }
69
70    pub fn lookup_decl_span(&self, name: &Name) -> Option<SpanId> {
71        self.decl_spans.get(name).copied()
72    }
73
74    pub fn lookup_decl_line(&self, name: &Name) -> Option<u32> {
75        let span_id = self.lookup_decl_span(name)?;
76        let span = self.source_map.spans.get(span_id as usize)?;
77        Some(span.start_point.line + 1)
78    }
79
80    /// Helper to add to span and documentation hover info into the source map
81    pub fn add_span_and_doc_hover(
82        &mut self,
83        node: &tree_sitter::Node,
84        doc_key: &str, // name of the documentation file in Bits
85        kind: SymbolKind,
86        ty: Option<String>,
87        decl_span: Option<u32>,
88    ) {
89        let hover = HoverInfo {
90            description: String::new(),
91            doc_key: Some(normalise_documentation_key(doc_key)),
92            kind: Some(kind),
93            ty,
94            decl_span,
95        };
96        span_with_hover(node, self.source_code, self.source_map, hover);
97    }
98}
99
/// The type the parser currently expects, used to detect type mismatches during parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypecheckingContext {
    /// Expected: a boolean.
    Boolean,
    /// Expected: an arithmetic value.
    Arithmetic,
    /// Expected: a set.
    Set,
    /// Expected: either a set or a matrix.
    SetOrMatrix,
    /// Expected: a multi-set.
    MSet,
    /// Expected: a matrix.
    Matrix,
    /// Expected: a tuple.
    Tuple,
    /// Expected: a record.
    Record,
    /// Expected: a partition.
    Partition,
    /// Expected: a sequence.
    Sequence,
    /// Context is unknown or flexible
    Unknown,
}
116
117/// Parse the given source code into a syntax tree using tree-sitter.
118///
119/// If successful, returns a tuple containing the syntax tree and the raw source code.
120/// If the source code is not valid Essence, returns None.
121pub fn get_tree(src: &str) -> Option<(Tree, String)> {
122    let mut parser = Parser::new();
123    parser.set_language(&LANGUAGE.into()).unwrap();
124
125    parser.parse(src, None).and_then(|tree| {
126        let root = tree.root_node();
127        if root.is_error() {
128            return None;
129        }
130        Some((tree, src.to_string()))
131    })
132}
133
134/// Parse an expression fragment, allowing a dummy prefix for error recovery.
135///
136/// NOTE: The new source code may be different from the original source code.
137///       See implementation for details.
138pub fn get_expr_tree(src: &str) -> Option<(Tree, String)> {
139    let mut parser = Parser::new();
140    parser.set_language(&LANGUAGE.into()).unwrap();
141
142    parser.parse(src, None).and_then(|tree| {
143        let root = tree.root_node();
144        if root.is_error() {
145            return None;
146        }
147
148        let children: Vec<_> = named_children(&root).collect();
149        let first_child = children.first()?;
150
151        // HACK: Tree-sitter can only parse a complete program from top to bottom, not an individual bit of syntax.
152        // See: https://github.com/tree-sitter/tree-sitter/issues/711 and linked issues.
153        // However, we can use a dummy _FRAGMENT_EXPRESSION prefix (which we insert as necessary)
154        // to trick the parser into accepting an isolated expression.
155        // This way we can parse an isolated expression and it is only slightly cursed :)
156        if first_child.is_error() {
157            if src.starts_with("_FRAGMENT_EXPRESSION") {
158                None
159            } else {
160                get_expr_tree(&format!("_FRAGMENT_EXPRESSION {src}"))
161            }
162        } else {
163            Some((tree, src.to_string()))
164        }
165    })
166}
167
168/// Get the named children of a node
169pub fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
170    (0..node.named_child_count())
171        .filter_map(|i| u32::try_from(i).ok().and_then(|i| node.named_child(i)))
172}
173
174pub fn node_is_expression(node: &Node) -> bool {
175    matches!(
176        node.kind(),
177        "bool_expr" | "arithmetic_expr" | "comparison_expr" | "atom"
178    )
179}
180
181/// Get all top-level nodes that match the given predicate
182pub fn query_toplevel<'a>(
183    node: &'a Node<'a>,
184    predicate: &'a dyn Fn(&Node<'a>) -> bool,
185) -> impl Iterator<Item = Node<'a>> + 'a {
186    WalkDFS::with_retract(node, predicate).filter(|n| n.is_named() && predicate(n))
187}
188
189/// Get all meta-variable names in a node
190pub fn get_metavars<'a>(node: &'a Node<'a>, src: &'a str) -> impl Iterator<Item = String> + 'a {
191    query_toplevel(node, &|n| n.kind() == "metavar").filter_map(|child| {
192        child
193            .named_child(0)
194            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
195    })
196}
197
198/// Fetch Essence syntax documentation from Conjure's `docs/bits/` folder on GitHub.
199///
200/// `name` is the name of the documentation file (without .md suffix). If the file is not found or an error occurs, returns None.
201pub fn get_documentation(name: &str) -> Option<String> {
202    static DOCUMENTATION_CACHE: OnceLock<Mutex<BTreeMap<String, Option<String>>>> = OnceLock::new();
203
204    let base = normalise_documentation_key(name);
205    let cache = DOCUMENTATION_CACHE.get_or_init(|| Mutex::new(BTreeMap::new()));
206
207    if let Some(cached) = cache.lock().ok()?.get(&base).cloned() {
208        return cached;
209    }
210
211    // This url is for raw Markdown bytes
212    let url =
213        format!("https://raw.githubusercontent.com/conjure-cp/conjure/main/docs/bits/{base}.md");
214
215    let output = std::process::Command::new("curl")
216        .args(["-fsSL", &url])
217        .output()
218        .ok();
219
220    let documentation = output
221        .filter(|output| output.status.success())
222        .and_then(|output| String::from_utf8(output.stdout).ok());
223
224    if let Ok(mut cache) = cache.lock() {
225        cache.insert(base, documentation.clone());
226    }
227
228    documentation
229}
230
/// Turns a documentation file name into its cache/lookup key.
///
/// Removes one trailing `.md`, if present; otherwise returns `name` unchanged.
fn normalise_documentation_key(name: &str) -> String {
    match name.strip_suffix(".md") {
        Some(stem) => stem.to_owned(),
        None => name.to_owned(),
    }
}
234
/// Unit tests for the parser utilities. Compiled only for test builds.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_get_metavars() {
        let src = "such that &x = y";
        let (tree, _) = get_tree(src).expect("test source should parse");
        let root = tree.root_node();
        let metavars = get_metavars(&root, src).collect::<Vec<_>>();
        assert_eq!(metavars, vec!["x"]);
    }
}