1
use std::collections::VecDeque;
2

            
3
use tree_sitter::{Node, TreeCursor};
4

            
5
/// An iterator that traverses the syntax tree in pre-order DFS order.
6
pub struct WalkDFS<'a> {
7
    cursor: Option<TreeCursor<'a>>,
8
    retract: Option<&'a dyn Fn(&Node<'a>) -> bool>,
9
}
10

            
11
#[allow(dead_code)]
12
impl<'a> WalkDFS<'a> {
13
3
    pub fn new(node: &'a Node<'a>) -> Self {
14
3
        Self {
15
3
            cursor: Some(node.walk()),
16
3
            retract: None,
17
3
        }
18
3
    }
19

            
20
1645
    pub fn with_retract(node: &'a Node<'a>, retract: &'a dyn Fn(&Node<'a>) -> bool) -> Self {
21
1645
        Self {
22
1645
            cursor: Some(node.walk()),
23
1645
            retract: Some(retract),
24
1645
        }
25
1645
    }
26
}
27

            
28
impl<'a> Iterator for WalkDFS<'a> {
29
    type Item = Node<'a>;
30

            
31
26400
    fn next(&mut self) -> Option<Self::Item> {
32
26400
        let cursor = self.cursor.as_mut()?;
33
24761
        let node = cursor.node();
34

            
35
24761
        if self.retract.is_none() || !self.retract.as_ref().unwrap()(&node) {
36
            // Try to descend into the first child.
37
22639
            if cursor.goto_first_child() {
38
11939
                return Some(node);
39
10700
            }
40
2122
        }
41

            
42
        // If we are at a leaf, try its next sibling instead.
43
12822
        if cursor.goto_next_sibling() {
44
6623
            return Some(node);
45
6199
        }
46

            
47
        // If neither has worked, we need to ascend until we can go to a sibling
48
        loop {
49
            // If we can't go to the parent, then that means we've reached the root, and our
50
            // iterator will be done in the next iteration
51
13557
            if !cursor.goto_parent() {
52
1642
                self.cursor = None;
53
1642
                break;
54
11915
            }
55

            
56
            // If we get to a sibling, then this will be the first time we touch that node,
57
            // so it'll be the next starting node
58
11915
            if cursor.goto_next_sibling() {
59
4557
                break;
60
7358
            }
61
        }
62

            
63
6199
        Some(node)
64
26400
    }
65
}
66

            
67
/// An iterator that traverses the syntax tree in breadth-first order.
68
pub struct WalkBFS<'a> {
69
    queue: VecDeque<Node<'a>>,
70
}
71

            
72
#[allow(dead_code)]
73
impl<'a> WalkBFS<'a> {
74
3
    pub fn new(root: &'a Node<'a>) -> Self {
75
3
        Self {
76
3
            queue: VecDeque::from([*root]),
77
3
        }
78
3
    }
79
}
80

            
81
impl<'a> Iterator for WalkBFS<'a> {
82
    type Item = Node<'a>;
83

            
84
24
    fn next(&mut self) -> Option<Self::Item> {
85
24
        let node = self.queue.pop_front()?;
86
24
        node.children(&mut node.walk()).for_each(|child| {
87
24
            self.queue.push_back(child);
88
24
        });
89
24
        Some(node)
90
24
    }
91
}
92

            
93
#[cfg(test)]
mod test {
    use super::super::util::get_tree;
    use super::*;

    /// Collects the kinds of the first `count` named nodes produced by `nodes`.
    fn first_named_kinds<'t>(
        nodes: impl Iterator<Item = Node<'t>>,
        count: usize,
    ) -> Vec<&'static str> {
        nodes
            .filter(|node| node.is_named())
            .take(count)
            .map(|node| node.kind())
            .collect()
    }

    #[test]
    pub fn test_bfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",    // depth = 0
            "atom",       // depth = 1
            "atom",       // depth = 2
            "identifier", // depth = 2
            "constant",   // depth = 2
            "TRUE",       // depth = 3
        ];
        let kinds = first_named_kinds(WalkBFS::new(&root), expected.len());
        assert_eq!(kinds, expected);
    }

    #[test]
    pub fn test_dfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",    // top level
            "atom",       // first branch ("x")
            "identifier",
            "atom",       // second branch ("true")
            "constant",
            "TRUE",
        ];
        let kinds = first_named_kinds(WalkDFS::new(&root), expected.len());
        assert_eq!(kinds, expected);
    }

    #[test]
    pub fn test_dfs_retract() {
        let (tree, _) = get_tree("(x / 42) > (5 + y)").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",
            "comparison_expr",
            "arithmetic_comparison", // intermediate node for > comparison
            "arithmetic_expr", // first branch ("x / 42"). Don't descend into subexpressions.
            "arithmetic_expr", // second branch ("5 + y"). Don't descend into subexpressions.
        ];
        let kinds = first_named_kinds(
            WalkDFS::with_retract(&root, &|n: &Node<'_>| n.kind() == "arithmetic_expr"),
            expected.len(),
        );
        assert_eq!(kinds, expected);
    }
}