1
#![allow(dead_code)]
2
use std::sync::Arc;
3

            
4
use uniplate::zipper::Zipper;
5

            
6
use crate::ast::Expression;
7

            
8
/// Traverses expressions in a root expression tree, but not into nested scopes.
9
///
10
/// Same types and usage as `Biplate::contexts_bi`.
11
1445726
pub(super) fn expression_ctx(
12
1445726
    root_expression: Expression,
13
1445726
) -> impl Iterator<Item = (Expression, Arc<dyn Fn(Expression) -> Expression>)> {
14
1445726
    ExpressionCtx {
15
1445726
        zipper: SubmodelZipper {
16
1445726
            inner: Zipper::new(root_expression),
17
1445726
        },
18
1445726
        done: false,
19
1445726
    }
20
1445726
}
21

            
22
/// A zipper that traverses over the current expression tree and does not traverse into nested
23
/// scopes.
24
#[derive(Clone)]
25
#[doc(hidden)]
26
pub struct SubmodelZipper {
27
    inner: Zipper<Expression>,
28
}
29

            
30
impl SubmodelZipper {
31
    #[doc(hidden)]
32
    pub fn go_left(&mut self) -> Option<()> {
33
        self.inner.go_left()
34
    }
35

            
36
    #[doc(hidden)]
37
40272820
    pub fn go_right(&mut self) -> Option<()> {
38
40272820
        self.inner.go_right()
39
40272820
    }
40

            
41
    #[doc(hidden)]
42
18269164
    pub fn go_up(&mut self) -> Option<()> {
43
18269164
        self.inner.go_up()
44
18269164
    }
45

            
46
    #[doc(hidden)]
47
204364
    pub fn rebuild_root(self) -> Expression {
48
204364
        self.inner.rebuild_root()
49
204364
    }
50

            
51
    #[doc(hidden)]
52
40654056
    pub fn go_down(&mut self) -> Option<()> {
53
        // Do not enter comprehensions, which have their own local symbol table.
54
40654056
        if matches!(self.inner.focus(), Expression::Comprehension(_, _)) {
55
156928
            None
56
        } else {
57
40497128
            self.inner.go_down()
58
        }
59
40654056
    }
60

            
61
    #[doc(hidden)]
62
40600160
    pub fn focus(&self) -> &Expression {
63
40600160
        self.inner.focus()
64
40600160
    }
65

            
66
    #[doc(hidden)]
67
6100
    pub fn replace_focus(&mut self, new_focus: Expression) -> Expression {
68
6100
        self.inner.replace_focus(new_focus)
69
6100
    }
70

            
71
    #[doc(hidden)]
72
150468
    pub fn focus_mut(&mut self) -> &mut Expression {
73
150468
        self.inner.focus_mut()
74
150468
    }
75

            
76
    #[doc(hidden)]
77
53896
    pub fn new(root_expression: Expression) -> Self {
78
53896
        SubmodelZipper {
79
53896
            inner: Zipper::new(root_expression),
80
53896
        }
81
53896
    }
82
}
83

            
84
pub struct ExpressionCtx {
85
    zipper: SubmodelZipper,
86
    done: bool,
87
}
88

            
89
impl Iterator for ExpressionCtx {
90
    type Item = (Expression, Arc<dyn Fn(Expression) -> Expression>);
91

            
92
40548578
    fn next(&mut self) -> Option<Self::Item> {
93
40548578
        if self.done {
94
1295258
            return None;
95
39253320
        }
96
39253320
        let node = self.zipper.focus().clone();
97
39253320
        let zipper = self.zipper.clone();
98

            
99
        #[allow(clippy::arc_with_non_send_sync)]
100
39253320
        let ctx = Arc::new(move |x| {
101
150468
            let mut zipper2 = zipper.clone();
102
150468
            *zipper2.focus_mut() = x;
103
150468
            zipper2.rebuild_root()
104
150468
        });
105

            
106
        // prepare iterator for next element.
107
        // try moving down or right. if we can't move up the tree until we can move right.
108
39253320
        if self.zipper.go_down().is_none() {
109
38872084
            while self.zipper.go_right().is_none() {
110
17625114
                if self.zipper.go_up().is_none() {
111
                    // at the top again, so this will be the last time we return a node
112
1298808
                    self.done = true;
113
1298808
                    break;
114
16326306
                };
115
            }
116
16707542
        }
117

            
118
39253320
        Some((node, ctx))
119
40548578
    }
120
}