1
use std::collections::BTreeMap;
2

            
3
use tree_sitter::{Node, Parser, Tree};
4
use tree_sitter_essence::LANGUAGE;
5

            
6
use super::traversal::WalkDFS;
7
use crate::diagnostics::diagnostics_api::SymbolKind;
8
use crate::diagnostics::source_map::{HoverInfo, SourceMap, SpanId, span_with_hover};
9
use crate::errors::RecoverableParseError;
10
use conjure_cp_core::ast::{Name, SymbolTablePtr};
11

            
12
/// Context for parsing, containing shared state passed through parser functions.
13
pub struct ParseContext<'a> {
14
    pub source_code: &'a str,
15
    pub root: &'a Node<'a>,
16
    pub symbols: Option<SymbolTablePtr>,
17
    pub errors: &'a mut Vec<RecoverableParseError>,
18
    pub source_map: &'a mut SourceMap,
19
    pub decl_spans: &'a mut BTreeMap<Name, SpanId>,
20
    pub typechecking_context: TypecheckingContext,
21
}
22

            
23
impl<'a> ParseContext<'a> {
24
19021
    pub fn new(
25
19021
        source_code: &'a str,
26
19021
        root: &'a Node<'a>,
27
19021
        symbols: Option<SymbolTablePtr>,
28
19021
        errors: &'a mut Vec<RecoverableParseError>,
29
19021
        source_map: &'a mut SourceMap,
30
19021
        decl_spans: &'a mut BTreeMap<Name, SpanId>,
31
19021
    ) -> Self {
32
19021
        Self {
33
19021
            source_code,
34
19021
            root,
35
19021
            symbols,
36
19021
            errors,
37
19021
            source_map,
38
19021
            decl_spans,
39
19021
            typechecking_context: TypecheckingContext::Unknown,
40
19021
        }
41
19021
    }
42

            
43
586
    pub fn record_error(&mut self, error: RecoverableParseError) {
44
586
        self.errors.push(error);
45
586
    }
46

            
47
    /// Create a new ParseContext with different symbols but sharing source_code, root, errors, and source_map.
48
3974
    pub fn with_new_symbols(&mut self, symbols: Option<SymbolTablePtr>) -> ParseContext<'_> {
49
3974
        ParseContext {
50
3974
            source_code: self.source_code,
51
3974
            root: self.root,
52
3974
            symbols,
53
3974
            errors: self.errors,
54
3974
            source_map: self.source_map,
55
3974
            decl_spans: self.decl_spans,
56
3974
            typechecking_context: self.typechecking_context,
57
3974
        }
58
3974
    }
59

            
60
47157
    pub fn save_decl_span(&mut self, name: Name, span_id: SpanId) {
61
47157
        self.decl_spans.insert(name, span_id);
62
47157
    }
63

            
64
145900
    pub fn lookup_decl_span(&self, name: &Name) -> Option<SpanId> {
65
145900
        self.decl_spans.get(name).copied()
66
145900
    }
67

            
68
65
    pub fn lookup_decl_line(&self, name: &Name) -> Option<u32> {
69
65
        let span_id = self.lookup_decl_span(name)?;
70
65
        let span = self.source_map.spans.get(span_id as usize)?;
71
65
        Some(span.start_point.line + 1)
72
65
    }
73

            
74
    /// Helper to add to span and documentation hover info into the source map
75
203632
    pub fn add_span_and_doc_hover(
76
203632
        &mut self,
77
203632
        node: &tree_sitter::Node,
78
203632
        doc_key: &str, // name of the documentation file in Bits
79
203632
        kind: SymbolKind,
80
203632
        ty: Option<String>,
81
203632
        decl_span: Option<u32>,
82
203632
    ) {
83
203632
        if let Some(description) = get_documentation(doc_key) {
84
139804
            let hover = HoverInfo {
85
139804
                description,
86
139804
                kind: Some(kind),
87
139804
                ty,
88
139804
                decl_span,
89
139804
            };
90
139804
            span_with_hover(node, self.source_code, self.source_map, hover);
91
139810
        }
92
        // If documentation is not found, do nothing (no fallback, no addition to source map)
93
203632
    }
94
}
95

            
96
/// Typechecking context of the expression currently being parsed.
///
/// Used to detect type mismatches during parsing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TypecheckingContext {
    /// Boolean context.
    Boolean,
    /// Arithmetic context.
    Arithmetic,
    /// Set context.
    Set,
    /// Context is unknown or flexible
    Unknown,
}
105

            
106
/// Parse the given source code into a syntax tree using tree-sitter.
107
///
108
/// If successful, returns a tuple containing the syntax tree and the raw source code.
109
/// If the source code is not valid Essence, returns None.
110
18790
pub fn get_tree(src: &str) -> Option<(Tree, String)> {
111
18790
    let mut parser = Parser::new();
112
18790
    parser.set_language(&LANGUAGE.into()).unwrap();
113

            
114
18790
    parser.parse(src, None).and_then(|tree| {
115
18790
        let root = tree.root_node();
116
18790
        if root.is_error() {
117
            return None;
118
18790
        }
119
18790
        Some((tree, src.to_string()))
120
18790
    })
121
18790
}
122

            
123
/// Parse an expression fragment, allowing a dummy prefix for error recovery.
124
///
125
/// NOTE: The new source code may be different from the original source code.
126
///       See implementation for details.
127
472
pub fn get_expr_tree(src: &str) -> Option<(Tree, String)> {
128
472
    let mut parser = Parser::new();
129
472
    parser.set_language(&LANGUAGE.into()).unwrap();
130

            
131
472
    parser.parse(src, None).and_then(|tree| {
132
472
        let root = tree.root_node();
133
472
        if root.is_error() {
134
            return None;
135
472
        }
136

            
137
472
        let children: Vec<_> = named_children(&root).collect();
138
472
        let first_child = children.first()?;
139

            
140
        // HACK: Tree-sitter can only parse a complete program from top to bottom, not an individual bit of syntax.
141
        // See: https://github.com/tree-sitter/tree-sitter/issues/711 and linked issues.
142
        // However, we can use a dummy _FRAGMENT_EXPRESSION prefix (which we insert as necessary)
143
        // to trick the parser into accepting an isolated expression.
144
        // This way we can parse an isolated expression and it is only slightly cursed :)
145
472
        if first_child.is_error() {
146
236
            if src.starts_with("_FRAGMENT_EXPRESSION") {
147
                None
148
            } else {
149
236
                get_expr_tree(&format!("_FRAGMENT_EXPRESSION {src}"))
150
            }
151
        } else {
152
236
            Some((tree, src.to_string()))
153
        }
154
472
    })
155
472
}
156

            
157
/// Get the named children of a node
158
122180
pub fn named_children<'a>(node: &'a Node<'a>) -> impl Iterator<Item = Node<'a>> + 'a {
159
122180
    (0..node.named_child_count())
160
148359
        .filter_map(|i| u32::try_from(i).ok().and_then(|i| node.named_child(i)))
161
122180
}
162

            
163
1181
pub fn node_is_expression(node: &Node) -> bool {
164
474
    matches!(
165
1181
        node.kind(),
166
1181
        "bool_expr" | "arithmetic_expr" | "comparison_expr" | "atom"
167
    )
168
1181
}
169

            
170
/// Get all top-level nodes that match the given predicate
171
236
pub fn query_toplevel<'a>(
172
236
    node: &'a Node<'a>,
173
236
    predicate: &'a dyn Fn(&Node<'a>) -> bool,
174
236
) -> impl Iterator<Item = Node<'a>> + 'a {
175
718
    WalkDFS::with_retract(node, predicate).filter(|n| n.is_named() && predicate(n))
176
236
}
177

            
178
/// Get all meta-variable names in a node
179
1
pub fn get_metavars<'a>(node: &'a Node<'a>, src: &'a str) -> impl Iterator<Item = String> + 'a {
180
16
    query_toplevel(node, &|n| n.kind() == "metavar").filter_map(|child| {
181
1
        child
182
1
            .named_child(0)
183
1
            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
184
1
    })
185
1
}
186

            
187
/// Fetch Essence syntax documentation from Conjure's `docs/bits/` folder on GitHub.
///
/// `name` is the name of the documentation file. A trailing `.md` suffix is optional.
///
/// Returns the raw Markdown, or `None` in any of these cases:
/// - the file is not found;
/// - `curl` cannot be run or reports failure;
/// - the response is not valid UTF-8.
///
/// Results, including misses, are cached for the lifetime of the process. Each
/// documentation file is therefore fetched at most once, instead of spawning a `curl`
/// process and making a network request on every call.
pub fn get_documentation(name: &str) -> Option<String> {
    type DocCache = std::sync::Mutex<std::collections::HashMap<String, Option<String>>>;
    static CACHE: std::sync::OnceLock<DocCache> = std::sync::OnceLock::new();

    /// Download `docs/bits/{base}.md` with `curl`; `None` on any failure.
    fn fetch(base: &str) -> Option<String> {
        // This url is for raw Markdown bytes
        let url =
            format!("https://raw.githubusercontent.com/conjure-cp/conjure/main/docs/bits/{base}.md");

        // NOTE(review): no timeout is passed to curl, so an unreachable network can stall
        // the caller on the first lookup of each key.
        let output = std::process::Command::new("curl")
            .args(["-fsSL", &url])
            .output()
            .ok()?;

        if !output.status.success() {
            return None;
        }

        String::from_utf8(output.stdout).ok()
    }

    let base = name.strip_suffix(".md").unwrap_or(name);
    let cache = CACHE.get_or_init(Default::default);

    // The map is only ever inserted into, so a poisoned lock still holds consistent data.
    let cached = cache
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .get(base)
        .cloned();
    if let Some(hit) = cached {
        return hit;
    }

    // Fetch without holding the lock so lookups of other keys are not blocked meanwhile.
    let fetched = fetch(base);
    cache
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .insert(base.to_string(), fetched.clone());
    fetched
}
211

            
212
/// Unit tests for the parser utilities; compiled only for test builds.
#[cfg(test)]
mod test {
    use super::*;

    /// A single `&name` metavariable is reported by its bare name.
    #[test]
    fn test_get_metavars() {
        let src = "such that &x = y";
        let (tree, _) = get_tree(src).unwrap();
        let root = tree.root_node();
        let metavars = get_metavars(&root, src).collect::<Vec<_>>();
        assert_eq!(metavars, vec!["x"]);
    }
}