1
use std::collections::VecDeque;
2

            
3
use tree_sitter::{Node, TreeCursor};
4

            
5
/// An iterator that traverses the syntax tree in pre-order DFS order.
6
pub struct WalkDFS<'a> {
7
    cursor: Option<TreeCursor<'a>>,
8
    retract: Option<&'a dyn Fn(&Node<'a>) -> bool>,
9
}
10

            
11
#[allow(dead_code)]
12
impl<'a> WalkDFS<'a> {
13
1
    pub fn new(node: &'a Node<'a>) -> Self {
14
1
        Self {
15
1
            cursor: Some(node.walk()),
16
1
            retract: None,
17
1
        }
18
1
    }
19

            
20
560
    pub fn with_retract(node: &'a Node<'a>, retract: &'a dyn Fn(&Node<'a>) -> bool) -> Self {
21
560
        Self {
22
560
            cursor: Some(node.walk()),
23
560
            retract: Some(retract),
24
560
        }
25
560
    }
26
}
27

            
28
impl<'a> Iterator for WalkDFS<'a> {
29
    type Item = Node<'a>;
30

            
31
8526
    fn next(&mut self) -> Option<Self::Item> {
32
8526
        let cursor = self.cursor.as_mut()?;
33
7969
        let node = cursor.node();
34

            
35
7969
        if self.retract.is_none() || !self.retract.as_ref().unwrap()(&node) {
36
            // Try to descend into the first child.
37
7264
            if cursor.goto_first_child() {
38
3809
                return Some(node);
39
3455
            }
40
705
        }
41

            
42
        // If we are at a leaf, try its next sibling instead.
43
4160
        if cursor.goto_next_sibling() {
44
2144
            return Some(node);
45
2016
        }
46

            
47
        // If neither has worked, we need to ascend until we can go to a sibling
48
        loop {
49
            // If we can't go to the parent, then that means we've reached the root, and our
50
            // iterator will be done in the next iteration
51
4360
            if !cursor.goto_parent() {
52
559
                self.cursor = None;
53
559
                break;
54
3801
            }
55

            
56
            // If we get to a sibling, then this will be the first time we touch that node,
57
            // so it'll be the next starting node
58
3801
            if cursor.goto_next_sibling() {
59
1457
                break;
60
2344
            }
61
        }
62

            
63
2016
        Some(node)
64
8526
    }
65
}
66

            
67
/// An iterator that traverses the syntax tree in breadth-first order.
68
pub struct WalkBFS<'a> {
69
    queue: VecDeque<Node<'a>>,
70
}
71

            
72
#[allow(dead_code)]
73
impl<'a> WalkBFS<'a> {
74
1
    pub fn new(root: &'a Node<'a>) -> Self {
75
1
        Self {
76
1
            queue: VecDeque::from([*root]),
77
1
        }
78
1
    }
79
}
80

            
81
impl<'a> Iterator for WalkBFS<'a> {
82
    type Item = Node<'a>;
83

            
84
8
    fn next(&mut self) -> Option<Self::Item> {
85
8
        let node = self.queue.pop_front()?;
86
8
        node.children(&mut node.walk()).for_each(|child| {
87
8
            self.queue.push_back(child);
88
8
        });
89
8
        Some(node)
90
8
    }
91
}
92

            
93
#[cfg(test)]
mod test {
    use super::super::util::get_tree;
    use super::*;

    /// Pulls one node from `nodes` per entry of `expected`, in order, and
    /// checks that its kind matches. Panics if `nodes` runs out early.
    fn assert_kinds<'a>(mut nodes: impl Iterator<Item = Node<'a>>, expected: &[&str]) {
        for &kind in expected {
            assert_eq!(nodes.next().unwrap().kind(), kind);
        }
    }

    #[test]
    pub fn test_bfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let named = WalkBFS::new(&root).filter(|n| n.is_named());

        // Expected named nodes, shallowest first.
        assert_kinds(
            named,
            &[
                "program",    // root
                "atom",       // first branch ("x")
                "atom",       // second branch ("true")
                "identifier", // below the first atom
                "constant",   // below the second atom
                "TRUE",       // below the constant
            ],
        );
    }

    #[test]
    pub fn test_dfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let named = WalkDFS::new(&root).filter(|n| n.is_named());

        assert_kinds(
            named,
            &[
                "program",    // top level
                "atom",       // first branch ("x")
                "identifier", //
                "atom",       // second branch ("true")
                "constant",   //
                "TRUE",       //
            ],
        );
    }

    #[test]
    pub fn test_dfs_retract() {
        let (tree, _) = get_tree("(x / 42) > (5 + y)").unwrap();
        let root = tree.root_node();
        let is_arithmetic = |n: &Node<'_>| n.kind() == "arithmetic_expr";
        let named = WalkDFS::with_retract(&root, &is_arithmetic).filter(|n| n.is_named());

        // Both arithmetic expressions are yielded, but the walk must not
        // descend into their subexpressions.
        assert_kinds(
            named,
            &[
                "program",
                "comparison_expr",
                "arithmetic_expr", // first branch ("x / 42")
                "arithmetic_expr", // second branch ("5 + y")
            ],
        );
    }
}