1
use std::{
2
    collections::HashMap,
3
    ops::DerefMut,
4
    sync::atomic::{AtomicU64, Ordering},
5
};
6

            
7
use tree_morph::{
8
    cache::{HashMapCache, NoCache, StdHashKey},
9
    prelude::*,
10
};
11
use tree_morph_macros::named_rule;
12
use uniplate::Uniplate;
13

            
14
#[derive(Debug, Clone, PartialEq, Eq, Uniplate, Hash)]
15
#[uniplate()]
16
enum Expr {
17
    Triple(Box<Expr>, Box<Expr>, Box<Expr>),
18
    Quad(Box<Expr>, Box<Expr>, Box<Expr>, Box<Expr>),
19
    A,
20
    B,
21
    C,
22
    D,
23
}
24

            
25
// Per-rule bookkeeping passed through `Engine::morph` as its metadata value.
// Both maps are keyed by rule name ("a->b", "b->c", "c->d").
#[derive(Debug)]
struct Meta {
    // Bumped by the `mut_meta` command each rule queues at the top of its body.
    attempts: HashMap<String, usize>,
    // Bumped only on the branch where a rule actually produces a rewrite.
    applied: HashMap<String, usize>,
}
30

            
31
#[named_rule("a->b")]
32
46
fn a_to_b(cmd: &mut Commands<Expr, Meta>, expr: &Expr, _meta: &Meta) -> Option<Expr> {
33
46
    cmd.mut_meta(Box::new(|m| {
34
2
        *m.attempts.entry("a->b".into()).or_default().deref_mut() += 1
35
2
    }));
36
46
    match expr {
37
        Expr::A => {
38
2
            cmd.mut_meta(Box::new(|m| {
39
2
                *m.applied.entry("a->b".into()).or_default().deref_mut() += 1
40
2
            }));
41
2
            Some(Expr::B)
42
        }
43
44
        _ => None,
44
    }
45
46
}
46

            
47
#[named_rule("b->c")]
48
46
fn b_to_c(cmd: &mut Commands<Expr, Meta>, expr: &Expr, _meta: &Meta) -> Option<Expr> {
49
46
    cmd.mut_meta(Box::new(|m| {
50
3
        *m.attempts.entry("b->c".into()).or_default().deref_mut() += 1
51
3
    }));
52
46
    match expr {
53
        Expr::B => {
54
3
            cmd.mut_meta(Box::new(|m| {
55
3
                *m.applied.entry("b->c".into()).or_default().deref_mut() += 1
56
3
            }));
57
3
            Some(Expr::C)
58
        }
59
43
        _ => None,
60
    }
61
46
}
62

            
63
#[named_rule("c->d")]
64
46
fn c_to_d(cmd: &mut Commands<Expr, Meta>, expr: &Expr, _meta: &Meta) -> Option<Expr> {
65
46
    cmd.mut_meta(Box::new(|m| {
66
9
        *m.attempts.entry("c->d".into()).or_default().deref_mut() += 1
67
9
    }));
68
46
    match expr {
69
        Expr::C => {
70
9
            cmd.mut_meta(Box::new(|m| {
71
9
                *m.applied.entry("c->d".into()).or_default().deref_mut() += 1
72
9
            }));
73
9
            Some(Expr::D)
74
        }
75
37
        _ => None,
76
    }
77
46
}
78

            
79
2
fn setup() -> (
80
2
    Meta,
81
2
    Engine<
82
2
        Expr,
83
2
        Meta,
84
2
        NamedRule<
85
2
            for<'a, 'b, 'c> fn(
86
2
                &'a mut tree_morph::commands::Commands<Expr, Meta>,
87
2
                &'b Expr,
88
2
                &'c Meta,
89
2
            ) -> Option<Expr>,
90
2
        >,
91
2
        HashMapCache<Expr>,
92
2
    >,
93
2
) {
94
2
    let meta = Meta {
95
2
        applied: HashMap::new(),
96
2
        attempts: HashMap::new(),
97
2
    };
98
2
    let engine = EngineBuilder::new()
99
2
        .add_rule_group(vec![a_to_b, b_to_c, c_to_d])
100
2
        .add_cacher(HashMapCache::<_, StdHashKey>::new())
101
2
        .build();
102
2
    (meta, engine)
103
2
}
104

            
105
2
fn setup_nocache() -> (
106
2
    Meta,
107
2
    Engine<
108
2
        Expr,
109
2
        Meta,
110
2
        NamedRule<
111
2
            for<'a, 'b, 'c> fn(
112
2
                &'a mut tree_morph::commands::Commands<Expr, Meta>,
113
2
                &'b Expr,
114
2
                &'c Meta,
115
2
            ) -> Option<Expr>,
116
2
        >,
117
2
        NoCache,
118
2
    >,
119
2
) {
120
2
    let meta = Meta {
121
2
        applied: HashMap::new(),
122
2
        attempts: HashMap::new(),
123
2
    };
124
2
    let engine = EngineBuilder::new()
125
2
        .add_rule_group(vec![a_to_b, b_to_c, c_to_d])
126
2
        .build();
127
2
    (meta, engine)
128
2
}
129

            
130
/// Without a cache, each of the four `C` leaves is rewritten on its own:
/// `c->d` is applied four times and no other rule shows up in the metadata.
#[test]
fn no_cache() {
    let (meta, mut engine) = setup_nocache();

    let all_c = Expr::Quad(
        Box::new(Expr::C),
        Box::new(Expr::C),
        Box::new(Expr::C),
        Box::new(Expr::C),
    );
    let (morphed, meta) = engine.morph(all_c, meta);

    let all_d = Expr::Quad(
        Box::new(Expr::D),
        Box::new(Expr::D),
        Box::new(Expr::D),
        Box::new(Expr::D),
    );
    assert_eq!(morphed, all_d);

    assert_eq!(meta.attempts.len(), 1);
    assert_eq!(meta.applied.len(), 1);
    assert_eq!(meta.applied.get("c->d"), Some(&4));
}
159

            
160
/// With the cache enabled, the four identical `C` leaves still all become `D`,
/// but `c->d` is applied only once; the other leaves are expected to be
/// answered from the cache.
#[test]
fn basic_caching() {
    let (meta, mut engine) = setup();

    let all_c = Expr::Quad(
        Box::new(Expr::C),
        Box::new(Expr::C),
        Box::new(Expr::C),
        Box::new(Expr::C),
    );
    let (morphed, meta) = engine.morph(all_c, meta);

    let all_d = Expr::Quad(
        Box::new(Expr::D),
        Box::new(Expr::D),
        Box::new(Expr::D),
        Box::new(Expr::D),
    );
    assert_eq!(morphed, all_d);

    assert_eq!(meta.attempts.len(), 1);
    assert_eq!(meta.applied.len(), 1);
    assert_eq!(meta.applied.get("c->d"), Some(&1));
}
189

            
190
/// Without a cache, `A`, `B` and `C` each walk the rest of the chain to `D`
/// independently: `a->b` fires once, `b->c` twice and `c->d` three times.
#[test]
fn transitive_no_caching() {
    let (meta, mut engine) = setup_nocache();
    let start = Expr::Triple(Box::new(Expr::A), Box::new(Expr::B), Box::new(Expr::C));

    let (end, meta) = engine.morph(start, meta);

    let expected = Expr::Triple(Box::new(Expr::D), Box::new(Expr::D), Box::new(Expr::D));
    assert_eq!(end, expected);

    assert_eq!(meta.applied.get("a->b"), Some(&1));
    assert_eq!(meta.applied.get("b->c"), Some(&2));
    assert_eq!(meta.applied.get("c->d"), Some(&3));
}
206

            
207
/// With the cache enabled, each link of the chain `A -> B -> C -> D` is
/// applied exactly once; later occurrences are expected to reuse the cache.
#[test]
fn transitive_caching() {
    let (meta, mut engine) = setup();
    let start = Expr::Triple(Box::new(Expr::A), Box::new(Expr::B), Box::new(Expr::C));

    let (end, meta) = engine.morph(start, meta);

    let expected = Expr::Triple(Box::new(Expr::D), Box::new(Expr::D), Box::new(Expr::D));
    assert_eq!(end, expected);

    assert_eq!(meta.applied.get("a->b"), Some(&1));
    assert_eq!(meta.applied.get("b->c"), Some(&1));
    assert_eq!(meta.applied.get("c->d"), Some(&1));
}
223

            
224
// --- Ancestor caching tests ---
225

            
226
#[derive(Debug, Clone, PartialEq, Eq, Uniplate, Hash)]
227
#[uniplate()]
228
enum ArithExpr {
229
    Add(Box<ArithExpr>, Box<ArithExpr>),
230
    Mul(Box<ArithExpr>, Box<ArithExpr>),
231
    Pair(Box<ArithExpr>, Box<ArithExpr>),
232
    Val(i32),
233
}
234

            
235
6
fn val(n: i32) -> ArithExpr {
236
6
    ArithExpr::Val(n)
237
6
}
238
2
fn add(a: ArithExpr, b: ArithExpr) -> ArithExpr {
239
2
    ArithExpr::Add(Box::new(a), Box::new(b))
240
2
}
241
1
fn mul(a: ArithExpr, b: ArithExpr) -> ArithExpr {
242
1
    ArithExpr::Mul(Box::new(a), Box::new(b))
243
1
}
244
2
fn pair(a: ArithExpr, b: ArithExpr) -> ArithExpr {
245
2
    ArithExpr::Pair(Box::new(a), Box::new(b))
246
2
}
247

            
248
14
fn eval_add(_: &mut Commands<ArithExpr, ()>, expr: &ArithExpr, _: &()) -> Option<ArithExpr> {
249
14
    if let ArithExpr::Add(a, b) = expr {
250
2
        if let (ArithExpr::Val(x), ArithExpr::Val(y)) = (a.as_ref(), b.as_ref()) {
251
2
            return Some(ArithExpr::Val(x + y));
252
        }
253
12
    }
254
12
    None
255
14
}
256

            
257
14
fn eval_mul(_: &mut Commands<ArithExpr, ()>, expr: &ArithExpr, _: &()) -> Option<ArithExpr> {
258
14
    if let ArithExpr::Mul(a, b) = expr {
259
3
        if let (ArithExpr::Val(x), ArithExpr::Val(y)) = (a.as_ref(), b.as_ref()) {
260
1
            return Some(ArithExpr::Val(x * y));
261
2
        }
262
11
    }
263
13
    None
264
14
}
265

            
266
/// Ancestor caching: the tree holds two identical copies of `(1+2)*(3+4)`.
/// After the first copy has been reduced to `Val(21)`, the cache should
/// already hold the result for that whole subtree. The second copy can then
/// be resolved from the cache instead of being re-reduced rule by rule.
///
/// The exact number of hits depends on how the engine traverses and revisits
/// the tree. The test therefore only checks two things: the cache was
/// populated (at least one miss) and it was actually used (at least one hit).
#[test]
fn ancestor_caching() {
    // The hooks below are nested `fn` items, which cannot capture locals, so
    // the counts live in statics and are reset before the engine runs.
    static HITS: AtomicU64 = AtomicU64::new(0);
    static MISSES: AtomicU64 = AtomicU64::new(0);

    fn on_hit(_: &ArithExpr, _: &mut ()) {
        HITS.fetch_add(1, Ordering::Relaxed);
    }
    fn on_miss(_: &ArithExpr, _: &mut ()) {
        MISSES.fetch_add(1, Ordering::Relaxed);
    }

    // Pair( (1+2)*(3+4), (1+2)*(3+4) )
    let subtree = mul(add(val(1), val(2)), add(val(3), val(4)));
    let tree = pair(subtree.clone(), subtree);

    HITS.store(0, Ordering::Relaxed);
    MISSES.store(0, Ordering::Relaxed);

    let mut engine = EngineBuilder::new()
        .add_rule_group(rule_fns![eval_add, eval_mul])
        .add_cacher(HashMapCache::<_, StdHashKey>::new())
        .add_on_cache_hit(on_hit)
        .add_on_cache_miss(on_miss)
        .build();

    let (result, _) = engine.morph(tree, ());
    assert_eq!(result, pair(val(21), val(21)));

    let hits = HITS.load(Ordering::Relaxed);
    let misses = MISSES.load(Ordering::Relaxed);

    // The first copy can only be reduced after the cache has missed on it.
    assert!(
        misses > 0,
        "Expected cache misses while reducing the first copy"
    );
    assert!(
        hits > 0,
        "Expected at least one cache hit from ancestor caching"
    );
}