// conjure_cp_essence_parser/parser/traversal.rs

use std::collections::VecDeque;
use std::iter::FusedIterator;

use tree_sitter::{Node, TreeCursor};
4
5/// An iterator that traverses the syntax tree in pre-order DFS order.
6pub struct WalkDFS<'a> {
7    cursor: Option<TreeCursor<'a>>,
8    retract: Option<&'a dyn Fn(&Node<'a>) -> bool>,
9}
10
11#[allow(dead_code)]
12impl<'a> WalkDFS<'a> {
13    pub fn new(node: &'a Node<'a>) -> Self {
14        Self {
15            cursor: Some(node.walk()),
16            retract: None,
17        }
18    }
19
20    pub fn with_retract(node: &'a Node<'a>, retract: &'a dyn Fn(&Node<'a>) -> bool) -> Self {
21        Self {
22            cursor: Some(node.walk()),
23            retract: Some(retract),
24        }
25    }
26}
27
28impl<'a> Iterator for WalkDFS<'a> {
29    type Item = Node<'a>;
30
31    fn next(&mut self) -> Option<Self::Item> {
32        let cursor = self.cursor.as_mut()?;
33        let node = cursor.node();
34
35        if self.retract.is_none() || !self.retract.as_ref().unwrap()(&node) {
36            // Try to descend into the first child.
37            if cursor.goto_first_child() {
38                return Some(node);
39            }
40        }
41
42        // If we are at a leaf, try its next sibling instead.
43        if cursor.goto_next_sibling() {
44            return Some(node);
45        }
46
47        // If neither has worked, we need to ascend until we can go to a sibling
48        loop {
49            // If we can't go to the parent, then that means we've reached the root, and our
50            // iterator will be done in the next iteration
51            if !cursor.goto_parent() {
52                self.cursor = None;
53                break;
54            }
55
56            // If we get to a sibling, then this will be the first time we touch that node,
57            // so it'll be the next starting node
58            if cursor.goto_next_sibling() {
59                break;
60            }
61        }
62
63        Some(node)
64    }
65}
66
67/// An iterator that traverses the syntax tree in breadth-first order.
68pub struct WalkBFS<'a> {
69    queue: VecDeque<Node<'a>>,
70}
71
72#[allow(dead_code)]
73impl<'a> WalkBFS<'a> {
74    pub fn new(root: &'a Node<'a>) -> Self {
75        Self {
76            queue: VecDeque::from([*root]),
77        }
78    }
79}
80
81impl<'a> Iterator for WalkBFS<'a> {
82    type Item = Node<'a>;
83
84    fn next(&mut self) -> Option<Self::Item> {
85        let node = self.queue.pop_front()?;
86        node.children(&mut node.walk()).for_each(|child| {
87            self.queue.push_back(child);
88        });
89        Some(node)
90    }
91}
92
#[cfg(test)]
mod test {
    use super::super::util::get_tree;
    use super::*;

    /// Returns the kinds of the first `count` named nodes produced by `walk`.
    fn first_named_kinds<'t>(
        walk: impl Iterator<Item = Node<'t>>,
        count: usize,
    ) -> Vec<&'static str> {
        walk.filter(|n| n.is_named())
            .map(|n| n.kind())
            .take(count)
            .collect()
    }

    #[test]
    fn test_bfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",    // depth = 0
            "bool_expr",  // depth = 1
            "bool_expr",  // depth = 1
            "atom",       // depth = 2
            "atom",       // depth = 2
            "identifier", // depth = 3
            "constant",   // depth = 3
            "TRUE",       // depth = 4
        ];
        assert_eq!(first_named_kinds(WalkBFS::new(&root), expected.len()), expected);
    }

    #[test]
    fn test_dfs() {
        let (tree, _) = get_tree("such that x, true").unwrap();
        let root = tree.root_node();
        let expected = [
            "program",   // top level
            "bool_expr", // first branch ("x")
            "atom",
            "identifier",
            "bool_expr", // second branch ("true")
            "atom",
            "constant",
            "TRUE",
        ];
        assert_eq!(first_named_kinds(WalkDFS::new(&root), expected.len()), expected);
    }

    #[test]
    fn test_dfs_retract() {
        let (tree, _) = get_tree("(x / 42) > (5 + y)").unwrap();
        let root = tree.root_node();
        let walk = WalkDFS::with_retract(&root, &|n: &Node<'_>| n.kind() == "arithmetic_expr");
        let expected = [
            "program",
            "comparison_expr",
            // Both branches are arithmetic_exprs; their subexpressions must not be visited.
            "arithmetic_expr", // first branch ("x / 42")
            "arithmetic_expr", // second branch ("5 + y")
        ];
        assert_eq!(first_named_kinds(walk, expected.len()), expected);
    }
}