1
#![allow(dead_code)]
2
use std::sync::Arc;
3

            
4
use uniplate::zipper::Zipper;
5

            
6
use crate::ast::Expression;
7

            
8
/// Traverses expressions in a root expression tree, but not into nested scopes.
9
///
10
/// Same types and usage as `Biplate::contexts_bi`.
11
4741590
pub(super) fn expression_ctx(
12
4741590
    root_expression: Expression,
13
4741590
) -> impl Iterator<Item = (Expression, Arc<dyn Fn(Expression) -> Expression>)> {
14
4741590
    ExpressionCtx {
15
4741590
        zipper: SubmodelZipper {
16
4741590
            inner: Zipper::new(root_expression),
17
4741590
        },
18
4741590
        done: false,
19
4741590
    }
20
4741590
}
21

            
22
/// A zipper that traverses over the current expression tree and does not traverse into nested
/// scopes.
///
/// Thin wrapper around `uniplate`'s `Zipper<Expression>`: every movement delegates to the inner
/// zipper, except that `go_down` refuses to descend into `Expression::Comprehension` nodes.
#[derive(Clone)]
#[doc(hidden)]
pub struct SubmodelZipper {
    // The underlying unrestricted zipper; all navigation and focus access delegates to it.
    inner: Zipper<Expression>,
}
29

            
30
impl SubmodelZipper {
31
    #[doc(hidden)]
32
    pub fn go_left(&mut self) -> Option<()> {
33
        self.inner.go_left()
34
    }
35

            
36
    #[doc(hidden)]
37
92895172
    pub fn go_right(&mut self) -> Option<()> {
38
92895172
        self.inner.go_right()
39
92895172
    }
40

            
41
    #[doc(hidden)]
42
43830834
    pub fn go_up(&mut self) -> Option<()> {
43
43830834
        self.inner.go_up()
44
43830834
    }
45

            
46
    #[doc(hidden)]
47
524522
    pub fn rebuild_root(self) -> Expression {
48
524522
        self.inner.rebuild_root()
49
524522
    }
50

            
51
    #[doc(hidden)]
52
93777160
    pub fn go_down(&mut self) -> Option<()> {
53
        // Do not enter comprehensions, which have their own local symbol table.
54
93777160
        if matches!(self.inner.focus(), Expression::Comprehension(_, _)) {
55
399614
            None
56
        } else {
57
93377546
            self.inner.go_down()
58
        }
59
93777160
    }
60

            
61
    #[doc(hidden)]
62
93586094
    pub fn focus(&self) -> &Expression {
63
93586094
        self.inner.focus()
64
93586094
    }
65

            
66
    #[doc(hidden)]
67
22866
    pub fn replace_focus(&mut self, new_focus: Expression) -> Expression {
68
22866
        self.inner.replace_focus(new_focus)
69
22866
    }
70

            
71
    #[doc(hidden)]
72
333456
    pub fn focus_mut(&mut self) -> &mut Expression {
73
333456
        self.inner.focus_mut()
74
333456
    }
75

            
76
    #[doc(hidden)]
77
191066
    pub fn new(root_expression: Expression) -> Self {
78
191066
        SubmodelZipper {
79
191066
            inner: Zipper::new(root_expression),
80
191066
        }
81
191066
    }
82
}
83

            
84
pub struct ExpressionCtx {
85
    zipper: SubmodelZipper,
86
    done: bool,
87
}
88

            
89
impl Iterator for ExpressionCtx {
90
    type Item = (Expression, Arc<dyn Fn(Expression) -> Expression>);
91

            
92
91896202
    fn next(&mut self) -> Option<Self::Item> {
93
91896202
        if self.done {
94
4408134
            return None;
95
87488068
        }
96
87488068
        let node = self.zipper.focus().clone();
97
87488068
        let zipper = self.zipper.clone();
98

            
99
        #[allow(clippy::arc_with_non_send_sync)]
100
87488068
        let ctx = Arc::new(move |x| {
101
333456
            let mut zipper2 = zipper.clone();
102
333456
            *zipper2.focus_mut() = x;
103
333456
            zipper2.rebuild_root()
104
333456
        });
105

            
106
        // prepare iterator for next element.
107
        // try moving down or right. if we can't move up the tree until we can move right.
108
87488068
        if self.zipper.go_down().is_none() {
109
86606080
            while self.zipper.go_right().is_none() {
110
41096852
                if self.zipper.go_up().is_none() {
111
                    // at the top again, so this will be the last time we return a node
112
4418158
                    self.done = true;
113
4418158
                    break;
114
36678694
                };
115
            }
116
37560682
        }
117

            
118
87488068
        Some((node, ctx))
119
91896202
    }
120
}