//! Define Caching behaviour for TreeMorph
//! This should help out with repetitive and expensive tree operations as well as long chains of
//! rules on duplicate subtrees

            
5
use std::{
6
    hash::{DefaultHasher, Hash, Hasher},
7
    marker::PhantomData,
8
};
9

            
10
use fxhash::FxHashMap;
11

            
/// Return type for RewriteCache
/// Due to the nature of Rewriting, there may be repeated subtrees where no rule can be applied.
/// In that case, we can compute it once and store it in cache stating no rules applicable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheResult<T> {
    /// The Subtree does not exist in cache.
    Unknown,

    /// The Subtree exists in cache but no rule is applicable.
    ///
    /// The payload is the rule application level the cache reports the subtree as terminal for.
    Terminal(usize),

    /// The Subtree exists in cache and there is a pre computed value.
    Rewrite(T),
}

            
26
/// Caching for Treemorph.
27
///
28
/// Outward facing API is simple. Given a tree and the rule application level, check the cache
29
/// before attempting rule checks.
30
///
31
/// If successful, insert it into the cache. The next time we see the exact same subtree, we can
32
/// immediately obtain the result without redoing all the hard work of recomputing.
33
pub trait RewriteCache<T> {
34
    /// Get the cached result
35
    fn get(&self, subtree: &T, level: usize) -> CacheResult<T>;
36

            
37
    /// Insert the results into the cache.
38
    /// Note: Any powerful side effects such as changing other parts of the tree or replacing the
39
    /// root should NOT be inserted into the cache.
40
    fn insert(&mut self, from: &T, to: Option<T>, level: usize);
41

            
42
    /// Invalidate any internally cached hash for the given node.
43
    /// This is called on ancestors when a subtree is replaced.
44
    /// The default implementation is a no-op for caches that don't use node-level hash caching.
45
638181
    fn invalidate_node(&self, _node: &T) {}
46

            
47
    /// Invalidate cached hashes for the given node and all its descendants.
48
    /// Called on replacement subtrees after rule application.
49
290888
    fn invalidate_subtree(&self, _node: &T) {}
50

            
51
    /// Returns `false` if this cache never stores anything (e.g. [`NoCache`]).
52
    /// The engine uses this to skip clones that would only feed into a no-op insert.
53
31
    fn is_active(&self) -> bool {
54
31
        true
55
31
    }
56

            
57
    /// Record the hash of an ancestor node before descending into a child.
58
    /// Called by the zipper on every successful `go_down`.
59
95210599
    fn push_ancestor(&mut self, _node: &T) {}
60

            
61
    /// Discard the top ancestor hash after ascending back to a parent.
62
    /// Called by the zipper on every `go_up` during normal traversal.
63
94573787
    fn pop_ancestor(&mut self) {}
64

            
65
    /// Pop the top ancestor hash and insert a mapping from the old ancestor
66
    /// to the new (rebuilt) ancestor at the given level.
67
    /// Called by `mark_dirty_to_root` as it walks up after a replacement.
68
636811
    fn pop_and_map_ancestor(&mut self, _new_ancestor: &T, _level: usize) {}
69
}
70

            
71
impl<T> RewriteCache<T> for Box<dyn RewriteCache<T>> {
72
95501188
    fn get(&self, subtree: &T, level: usize) -> CacheResult<T> {
73
95501188
        (**self).get(subtree, level)
74
95501188
    }
75

            
76
    fn insert(&mut self, from: &T, to: Option<T>, level: usize) {
77
        (**self).insert(from, to, level)
78
    }
79

            
80
638128
    fn invalidate_node(&self, node: &T) {
81
638128
        (**self).invalidate_node(node)
82
638128
    }
83

            
84
290852
    fn invalidate_subtree(&self, node: &T) {
85
290852
        (**self).invalidate_subtree(node)
86
290852
    }
87

            
88
95501286
    fn is_active(&self) -> bool {
89
95501286
        (**self).is_active()
90
95501286
    }
91

            
92
95210538
    fn push_ancestor(&mut self, node: &T) {
93
95210538
        (**self).push_ancestor(node)
94
95210538
    }
95

            
96
94573744
    fn pop_ancestor(&mut self) {
97
94573744
        (**self).pop_ancestor()
98
94573744
    }
99

            
100
636794
    fn pop_and_map_ancestor(&mut self, new_ancestor: &T, level: usize) {
101
636794
        (**self).pop_and_map_ancestor(new_ancestor, level)
102
636794
    }
103
}
104

            
/// Disable Caching.
///
/// This should compile out if statically selected.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoCache;
109
impl<T> RewriteCache<T> for NoCache {
110
95501286
    fn get(&self, _: &T, _: usize) -> CacheResult<T> {
111
95501286
        CacheResult::Unknown
112
95501286
    }
113

            
114
    fn insert(&mut self, _: &T, _: Option<T>, _: usize) {}
115

            
116
95501383
    fn is_active(&self) -> bool {
117
95501383
        false
118
95501383
    }
119
}
120

            
/// Abstracts how a cache computes hash keys and invalidates nodes.
///
/// Implement this trait to plug different hashing strategies into [`HashMapCache`].
/// Only [`node_hash`](CacheKey::node_hash) is required.
pub trait CacheKey<T> {
    /// Compute a level-independent hash for `term`.
    fn node_hash(term: &T) -> u64;

    /// Combine a node hash with a rule-group level to produce a cache key.
    fn combine(node_hash: u64, level: usize) -> u64 {
        // Feed both parts through one hasher so the same node gets a distinct key per level.
        let mut hasher = DefaultHasher::new();
        node_hash.hash(&mut hasher);
        level.hash(&mut hasher);
        hasher.finish()
    }

    /// Compute a cache key for `term` at the given rule application `level`.
    ///
    /// Equivalent to `combine(node_hash(term), level)`.
    fn hash(term: &T, level: usize) -> u64 {
        Self::combine(Self::node_hash(term), level)
    }

    /// Invalidate any internally cached hash for the given node.
    /// The default is a no-op (used by [`StdHashKey`]).
    fn invalidate(_node: &T) {}

    /// Invalidate cached hashes for the given node and all its descendants.
    /// The default is a no-op (used by [`StdHashKey`]).
    fn invalidate_subtree(_node: &T) {}
}

            
/// Hashing strategy that delegates to the standard [`Hash`] trait.
#[derive(Debug, Clone, Copy, Default)]
pub struct StdHashKey;

            
153
impl<T: Hash> CacheKey<T> for StdHashKey {
154
162
    fn node_hash(term: &T) -> u64 {
155
162
        let mut hasher = DefaultHasher::new();
156
162
        term.hash(&mut hasher);
157
162
        hasher.finish()
158
162
    }
159
}
160

            
/// Types with an internally cached hash value.
///
/// Implementors store a precomputed hash (e.g. in metadata) to avoid rehashing
/// entire subtrees on every cache lookup. The cached hash must be invalidated
/// whenever the node's content changes.
///
/// Use [`invalidate_cache`](CacheHashable::invalidate_cache) for single-node
/// invalidation (e.g. when walking up ancestors after a replacement), and
/// [`invalidate_cache_recursive`](CacheHashable::invalidate_cache_recursive)
/// for full-subtree invalidation (e.g. on rule replacement subtrees that may
/// contain cloned nodes with stale hashes from `with_children` reassembly).
pub trait CacheHashable {
    /// Invalidate the cached hash for this node only.
    /// Used by `mark_dirty_to_root` when walking up ancestors after a child replacement.
    fn invalidate_cache(&self);

    /// Invalidate the cached hash for this node and all descendants.
    /// Used on replacement subtrees after rule application to clear stale hashes
    /// from cloned-and-reassembled nodes.
    fn invalidate_cache_recursive(&self);

    /// Return the cached hash, computing and storing it if not yet cached.
    fn get_cached_hash(&self) -> u64;

    /// Compute the hash from scratch, store it, and return it.
    fn calculate_hash(&self) -> u64;
}

            
/// Hashing strategy that delegates to [`CacheHashable::get_cached_hash`],
/// allowing types with internally cached hashes to avoid rehashing entire subtrees.
#[derive(Debug, Clone, Copy, Default)]
pub struct CachedHashKey;

            
193
impl<T: CacheHashable> CacheKey<T> for CachedHashKey {
194
    fn node_hash(term: &T) -> u64 {
195
        term.get_cached_hash()
196
    }
197

            
198
    fn invalidate(node: &T) {
199
        node.invalidate_cache();
200
    }
201

            
202
    fn invalidate_subtree(node: &T) {
203
        node.invalidate_cache_recursive();
204
    }
205
}
206

            
207
/// RewriteCache implemented with a HashMap, generic over a [`CacheKey`] hashing strategy.
208
///
209
/// Use `HashMapCache<T>` (defaults to [`StdHashKey`]) for standard `Hash` types,
210
/// or `HashMapCache<T, CachedHashKey>` for types implementing [`CacheHashable`].
211
pub struct HashMapCache<T, K = StdHashKey>
212
where
213
    K: CacheKey<T>,
214
    T: Clone,
215
{
216
    map: FxHashMap<u64, Option<T>>,
217
    predecessors: FxHashMap<u64, Vec<u64>>,
218
    ancestor_stack: Vec<u64>,
219
    clean_levels: FxHashMap<u64, usize>,
220
    _key: PhantomData<K>,
221
}
222

            
223
/// Convenience alias for a [`HashMapCache`] using [`CachedHashKey`].
224
pub type CachedHashMapCache<T> = HashMapCache<T, CachedHashKey>;
225

            
226
impl<T, K> HashMapCache<T, K>
227
where
228
    K: CacheKey<T>,
229
    T: Clone,
230
{
231
    /// Creates a new HashMapCache that can be used as a RewriteCache
232
3
    pub fn new() -> Self {
233
3
        Self {
234
3
            map: FxHashMap::default(),
235
3
            predecessors: FxHashMap::default(),
236
3
            ancestor_stack: Vec::new(),
237
3
            clean_levels: FxHashMap::default(),
238
3
            _key: PhantomData,
239
3
        }
240
3
    }
241
}
242

            
243
impl<T, K> Default for HashMapCache<T, K>
244
where
245
    K: CacheKey<T>,
246
    T: Clone,
247
{
248
    fn default() -> Self {
249
        Self::new()
250
    }
251
}
252

            
253
impl<T, K> RewriteCache<T> for HashMapCache<T, K>
254
where
255
    K: CacheKey<T>,
256
    T: Clone,
257
{
258
34
    fn invalidate_node(&self, node: &T) {
259
34
        K::invalidate(node);
260
34
    }
261

            
262
7
    fn invalidate_subtree(&self, node: &T) {
263
7
        K::invalidate_subtree(node);
264
7
    }
265

            
266
48
    fn get(&self, subtree: &T, level: usize) -> CacheResult<T> {
267
48
        let node_hash = K::node_hash(subtree);
268
48
        if let Some(&max_clean) = self.clean_levels.get(&node_hash)
269
9
            && max_clean >= level
270
        {
271
9
            return CacheResult::Terminal(max_clean);
272
39
        }
273

            
274
39
        let hashed = K::combine(node_hash, level);
275

            
276
39
        match self.map.get(&hashed) {
277
31
            None => CacheResult::Unknown,
278
8
            Some(entry) => match entry {
279
8
                Some(res) => CacheResult::Rewrite(res.clone()),
280
                None => CacheResult::Terminal(level),
281
            },
282
        }
283
48
    }
284

            
285
31
    fn insert(&mut self, from: &T, to: Option<T>, level: usize) {
286
31
        let node_hash = K::node_hash(from);
287
31
        let from_hash = K::combine(node_hash, level);
288

            
289
31
        if to.is_none() {
290
24
            self.map.insert(from_hash, None);
291
24
            self.clean_levels
292
24
                .entry(node_hash)
293
24
                .and_modify(|l| *l = (*l).max(level))
294
24
                .or_insert(level);
295
24
            return;
296
7
        }
297

            
298
7
        let to_hash = K::hash(to.as_ref().unwrap(), level);
299

            
300
7
        if from_hash == to_hash {
301
            panic!("From and To have the same Hash - Cycle Detected!");
302
7
        }
303

            
304
7
        if self.map.contains_key(&from_hash) {
305
            panic!("Overriding an existing mapping loses transitive closure.");
306
7
        }
307

            
308
        // Forward Resolution
309
7
        let to = match self.map.get(&to_hash) {
310
            Some(stored) => stored.clone(),
311
7
            None => to,
312
        };
313

            
314
7
        let to_hash = match &to {
315
7
            Some(resolved) => K::hash(resolved, level),
316
            None => {
317
                self.map.insert(from_hash, None);
318
                return;
319
            }
320
        };
321

            
322
7
        self.map.insert(from_hash, to.clone());
323

            
324
7
        if let Some(mut predecessors) = self.predecessors.remove(&from_hash) {
325
5
            for &dependant in &predecessors {
326
5
                self.map.insert(dependant, to.clone());
327
5
            }
328

            
329
3
            self.predecessors
330
3
                .entry(to_hash)
331
3
                .or_default()
332
3
                .append(&mut predecessors);
333
4
        }
334

            
335
7
        self.predecessors
336
7
            .entry(to_hash)
337
7
            .or_default()
338
7
            .push(from_hash);
339
31
    }
340

            
341
33
    fn push_ancestor(&mut self, node: &T) {
342
33
        self.ancestor_stack.push(K::node_hash(node));
343
33
    }
344

            
345
14
    fn pop_ancestor(&mut self) {
346
14
        self.ancestor_stack.pop();
347
14
    }
348

            
349
19
    fn pop_and_map_ancestor(&mut self, new_ancestor: &T, level: usize) {
350
19
        if let Some(old_node_hash) = self.ancestor_stack.pop() {
351
19
            let old_key = K::combine(old_node_hash, level);
352
19
            let new_key = K::hash(new_ancestor, level);
353

            
354
            // No change at this ancestor level
355
19
            if old_key == new_key {
356
                return;
357
19
            }
358

            
359
            // If old_key has a rewrite mapping, don't override (preserves transitive closure).
360
            // But DO override terminal entries
361
19
            if let Some(existing) = self.map.get(&old_key) {
362
19
                if existing.is_some() {
363
2
                    return;
364
17
                }
365
                // Remove the stale terminal entry so we can insert the ancestor mapping
366
17
                self.map.remove(&old_key);
367
            }
368

            
369
            // Forward resolution
370
17
            let to = match self.map.get(&new_key) {
371
                Some(stored) => stored.clone(),
372
17
                None => Some(new_ancestor.clone()),
373
            };
374

            
375
17
            let to_key = match &to {
376
17
                Some(resolved) => K::hash(resolved, level),
377
                None => {
378
                    self.map.insert(old_key, None);
379
                    return;
380
                }
381
            };
382

            
383
17
            self.map.insert(old_key, to.clone());
384

            
385
            // Predecessor tracking
386
17
            if let Some(mut preds) = self.predecessors.remove(&old_key) {
387
32
                for &dep in &preds {
388
32
                    self.map.insert(dep, to.clone());
389
32
                }
390
13
                self.predecessors
391
13
                    .entry(to_key)
392
13
                    .or_default()
393
13
                    .append(&mut preds);
394
4
            }
395

            
396
17
            self.predecessors.entry(to_key).or_default().push(old_key);
397
        }
398
19
    }
399
}