1
use std::collections::VecDeque;
2

            
3
use tree_sitter::{Node, TreeCursor};
4

            
5
/// An iterator that traverses the syntax tree in pre-order DFS order.
6
pub struct WalkDFS<'a> {
7
    cursor: Option<TreeCursor<'a>>,
8
    retract: Option<&'a dyn Fn(&Node<'a>) -> bool>,
9
}
10

            
11
#[allow(dead_code)]
12
impl<'a> WalkDFS<'a> {
13
    pub fn new(node: &'a Node<'a>) -> Self {
14
        Self {
15
            cursor: Some(node.walk()),
16
            retract: None,
17
        }
18
    }
19

            
20
193
    pub fn with_retract(node: &'a Node<'a>, retract: &'a dyn Fn(&Node<'a>) -> bool) -> Self {
21
193
        Self {
22
193
            cursor: Some(node.walk()),
23
193
            retract: Some(retract),
24
193
        }
25
193
    }
26
}
27

            
28
impl<'a> Iterator for WalkDFS<'a> {
29
    type Item = Node<'a>;
30

            
31
774
    fn next(&mut self) -> Option<Self::Item> {
32
774
        let cursor = self.cursor.as_mut()?;
33
581
        let node = cursor.node();
34

            
35
581
        if self.retract.is_none() || !self.retract.as_ref().unwrap()(&node) {
36
            // Try to descend into the first child.
37
387
            if cursor.goto_first_child() {
38
193
                return Some(node);
39
194
            }
40
194
        }
41

            
42
        // If we are at a leaf, try its next sibling instead.
43
388
        if cursor.goto_next_sibling() {
44
195
            return Some(node);
45
193
        }
46

            
47
        // If neither has worked, we need to ascend until we can go to a sibling
48
        loop {
49
            // If we can't go to the parent, then that means we've reached the root, and our
50
            // iterator will be done in the next iteration
51
386
            if !cursor.goto_parent() {
52
193
                self.cursor = None;
53
193
                break;
54
193
            }
55

            
56
            // If we get to a sibling, then this will be the first time we touch that node,
57
            // so it'll be the next starting node
58
193
            if cursor.goto_next_sibling() {
59
                break;
60
193
            }
61
        }
62

            
63
193
        Some(node)
64
774
    }
65
}
66

            
67
/// An iterator that traverses the syntax tree in breadth-first order.
68
pub struct WalkBFS<'a> {
69
    queue: VecDeque<Node<'a>>,
70
}
71

            
72
#[allow(dead_code)]
73
impl<'a> WalkBFS<'a> {
74
    pub fn new(root: &'a Node<'a>) -> Self {
75
        Self {
76
            queue: VecDeque::from([*root]),
77
        }
78
    }
79
}
80

            
81
impl<'a> Iterator for WalkBFS<'a> {
82
    type Item = Node<'a>;
83

            
84
    fn next(&mut self) -> Option<Self::Item> {
85
        let node = self.queue.pop_front()?;
86
        node.children(&mut node.walk()).for_each(|child| {
87
            self.queue.push_back(child);
88
        });
89
        Some(node)
90
    }
91
}
92

            
93
#[cfg(test)]
mod test {
    use super::super::util::get_tree;
    use super::*;

    /// Returns the kinds of the first `limit` named nodes produced by `walk`.
    fn first_named_kinds<'t>(
        walk: impl Iterator<Item = Node<'t>>,
        limit: usize,
    ) -> Vec<&'static str> {
        walk.filter(|node| node.is_named())
            .take(limit)
            .map(|node| node.kind())
            .collect()
    }

    #[test]
    pub fn test_bfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",    // depth = 0
            "bool_expr",  // depth = 1
            "bool_expr",  // depth = 1
            "atom",       // depth = 2
            "atom",       // depth = 2
            "identifier", // depth = 3
            "constant",   // depth = 3
            "TRUE",       // depth = 4
        ];
        assert_eq!(
            first_named_kinds(WalkBFS::new(&root), expected.len()),
            expected
        );
    }

    #[test]
    pub fn test_dfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",    // top level
            "bool_expr",  // first branch ("x")
            "atom",
            "identifier",
            "bool_expr",  // second branch ("true")
            "atom",
            "constant",
            "TRUE",
        ];
        assert_eq!(
            first_named_kinds(WalkDFS::new(&root), expected.len()),
            expected
        );
    }

    #[test]
    pub fn test_dfs_retract() {
        let (tree, _) = get_tree("(x / 42) > (5 + y)").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",
            "comparison_expr",
            "arithmetic_expr", // first branch ("x / 42"); subexpressions are skipped
            "arithmetic_expr", // second branch ("5 + y"); subexpressions are skipped
        ];
        assert_eq!(
            first_named_kinds(
                WalkDFS::with_retract(&root, &|n: &Node<'_>| n.kind() == "arithmetic_expr"),
                expected.len(),
            ),
            expected
        );
    }
}