// tree_morph/cache.rs
//! Define caching behaviour for TreeMorph.
//! This should help out with repetitive and expensive tree operations as well as long chains of
//! rules on duplicate subtrees.
4
5use std::{
6    hash::{DefaultHasher, Hash, Hasher},
7    marker::PhantomData,
8};
9
10use fxhash::FxHashMap;
11
/// Return type for [`RewriteCache`] lookups.
///
/// Due to the nature of rewriting, there may be repeated subtrees where no rule can be applied.
/// In that case, we can compute it once and store it in cache stating no rules applicable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheResult<T> {
    /// The subtree does not exist in cache.
    Unknown,

    /// The subtree exists in cache but no rule is applicable.
    /// Carries the rule level up to which the subtree is known terminal.
    Terminal(usize),

    /// The subtree exists in cache and there is a pre-computed rewrite value.
    Rewrite(T),
}
25
/// Caching for TreeMorph.
///
/// Outward facing API is simple. Given a tree and the rule application level, check the cache
/// before attempting rule checks.
///
/// If successful, insert it into the cache. The next time we see the exact same subtree, we can
/// immediately obtain the result without redoing all the hard work of recomputing.
pub trait RewriteCache<T> {
    /// Get the cached result for `subtree` at rule-group `level`.
    fn get(&self, subtree: &T, level: usize) -> CacheResult<T>;

    /// Insert the results into the cache. Passing `to = None` records that no
    /// rule applied to `from` at `level` (a terminal entry).
    ///
    /// Note: Any powerful side effects such as changing other parts of the tree or replacing the
    /// root should NOT be inserted into the cache.
    fn insert(&mut self, from: &T, to: Option<T>, level: usize);

    /// Invalidate any internally cached hash for the given node.
    /// This is called on ancestors when a subtree is replaced.
    /// The default implementation is a no-op for caches that don't use node-level hash caching.
    fn invalidate_node(&self, _node: &T) {}

    /// Invalidate cached hashes for the given node and all its descendants.
    /// Called on replacement subtrees after rule application.
    /// The default implementation is a no-op.
    fn invalidate_subtree(&self, _node: &T) {}

    /// Returns `false` if this cache never stores anything (e.g. [`NoCache`]).
    /// The engine uses this to skip clones that would only feed into a no-op insert.
    fn is_active(&self) -> bool {
        true
    }

    /// Record the hash of an ancestor node before descending into a child.
    /// Called by the zipper on every successful `go_down`.
    /// No-op by default for caches without ancestor tracking.
    fn push_ancestor(&mut self, _node: &T) {}

    /// Discard the top ancestor hash after ascending back to a parent.
    /// Called by the zipper on every `go_up` during normal traversal.
    /// No-op by default.
    fn pop_ancestor(&mut self) {}

    /// Pop the top ancestor hash and insert a mapping from the old ancestor
    /// to the new (rebuilt) ancestor at the given level.
    /// Called by `mark_dirty_to_root` as it walks up after a replacement.
    /// No-op by default.
    fn pop_and_map_ancestor(&mut self, _new_ancestor: &T, _level: usize) {}
}
70
/// Forwarding impl so a type-erased `Box<dyn RewriteCache<T>>` can itself be
/// used wherever a [`RewriteCache`] is expected.
///
/// Every method is forwarded explicitly — including the ones that have
/// default bodies on the trait — because relying on the defaults here would
/// silently replace the boxed value's overrides with no-ops.
impl<T> RewriteCache<T> for Box<dyn RewriteCache<T>> {
    fn get(&self, subtree: &T, level: usize) -> CacheResult<T> {
        (**self).get(subtree, level)
    }

    fn insert(&mut self, from: &T, to: Option<T>, level: usize) {
        (**self).insert(from, to, level)
    }

    fn invalidate_node(&self, node: &T) {
        (**self).invalidate_node(node)
    }

    fn invalidate_subtree(&self, node: &T) {
        (**self).invalidate_subtree(node)
    }

    fn is_active(&self) -> bool {
        (**self).is_active()
    }

    fn push_ancestor(&mut self, node: &T) {
        (**self).push_ancestor(node)
    }

    fn pop_ancestor(&mut self) {
        (**self).pop_ancestor()
    }

    fn pop_and_map_ancestor(&mut self, new_ancestor: &T, level: usize) {
        (**self).pop_and_map_ancestor(new_ancestor, level)
    }
}
104
105/// Disable Caching.
106///
107/// This should compile out if statically selected.
108pub struct NoCache;
109impl<T> RewriteCache<T> for NoCache {
110    fn get(&self, _: &T, _: usize) -> CacheResult<T> {
111        CacheResult::Unknown
112    }
113
114    fn insert(&mut self, _: &T, _: Option<T>, _: usize) {}
115
116    fn is_active(&self) -> bool {
117        false
118    }
119}
120
/// Abstracts how a cache computes hash keys and invalidates nodes.
///
/// Implement this trait to plug different hashing strategies into [`HashMapCache`].
pub trait CacheKey<T> {
    /// Compute a level-independent hash for `term`.
    fn node_hash(term: &T) -> u64;

    /// Combine a node hash with a rule-group level to produce a cache key.
    /// The default folds the `(node_hash, level)` pair through the std
    /// `DefaultHasher`.
    fn combine(node_hash: u64, level: usize) -> u64 {
        let mut hasher = DefaultHasher::new();
        node_hash.hash(&mut hasher);
        level.hash(&mut hasher);
        hasher.finish()
    }

    /// Compute a cache key for `term` at the given rule application `level`.
    /// Equivalent to `Self::combine(Self::node_hash(term), level)`.
    fn hash(term: &T, level: usize) -> u64 {
        Self::combine(Self::node_hash(term), level)
    }

    /// Invalidate any internally cached hash for the given node.
    /// The default is a no-op (used by [`StdHashKey`]).
    fn invalidate(_node: &T) {}

    /// Invalidate cached hashes for the given node and all its descendants.
    /// The default is a no-op (used by [`StdHashKey`]).
    fn invalidate_subtree(_node: &T) {}
}
149
150/// Hashing strategy that delegates to the standard [`Hash`] trait.
151pub struct StdHashKey;
152
153impl<T: Hash> CacheKey<T> for StdHashKey {
154    fn node_hash(term: &T) -> u64 {
155        let mut hasher = DefaultHasher::new();
156        term.hash(&mut hasher);
157        hasher.finish()
158    }
159}
160
/// Types with an internally cached hash value.
///
/// Implementors store a precomputed hash (e.g. in metadata) to avoid rehashing
/// entire subtrees on every cache lookup. The cached hash must be invalidated
/// whenever the node's content changes.
///
/// Use [`invalidate_cache`](CacheHashable::invalidate_cache) for single-node
/// invalidation (e.g. when walking up ancestors after a replacement), and
/// [`invalidate_cache_recursive`](CacheHashable::invalidate_cache_recursive)
/// for full-subtree invalidation (e.g. on rule replacement subtrees that may
/// contain cloned nodes with stale hashes from `with_children` reassembly).
///
/// Note: the methods take `&self`, so implementors are expected to use
/// interior mutability for the stored hash.
pub trait CacheHashable {
    /// Invalidate the cached hash for this node only.
    /// Used by `mark_dirty_to_root` when walking up ancestors after a child replacement.
    fn invalidate_cache(&self);

    /// Invalidate the cached hash for this node and all descendants.
    /// Used on replacement subtrees after rule application to clear stale hashes
    /// from cloned-and-reassembled nodes.
    fn invalidate_cache_recursive(&self);

    /// Return the cached hash, computing and storing it if not yet cached.
    fn get_cached_hash(&self) -> u64;

    /// Compute the hash from scratch, store it, and return it.
    fn calculate_hash(&self) -> u64;
}
188
189/// Hashing strategy that delegates to [`CacheHashable::get_cached_hash`],
190/// allowing types with internally cached hashes to avoid rehashing entire subtrees.
191pub struct CachedHashKey;
192
193impl<T: CacheHashable> CacheKey<T> for CachedHashKey {
194    fn node_hash(term: &T) -> u64 {
195        term.get_cached_hash()
196    }
197
198    fn invalidate(node: &T) {
199        node.invalidate_cache();
200    }
201
202    fn invalidate_subtree(node: &T) {
203        node.invalidate_cache_recursive();
204    }
205}
206
207/// RewriteCache implemented with a HashMap, generic over a [`CacheKey`] hashing strategy.
208///
209/// Use `HashMapCache<T>` (defaults to [`StdHashKey`]) for standard `Hash` types,
210/// or `HashMapCache<T, CachedHashKey>` for types implementing [`CacheHashable`].
211pub struct HashMapCache<T, K = StdHashKey>
212where
213    K: CacheKey<T>,
214    T: Clone,
215{
216    map: FxHashMap<u64, Option<T>>,
217    predecessors: FxHashMap<u64, Vec<u64>>,
218    ancestor_stack: Vec<u64>,
219    clean_levels: FxHashMap<u64, usize>,
220    _key: PhantomData<K>,
221}
222
223/// Convenience alias for a [`HashMapCache`] using [`CachedHashKey`].
224pub type CachedHashMapCache<T> = HashMapCache<T, CachedHashKey>;
225
226impl<T, K> HashMapCache<T, K>
227where
228    K: CacheKey<T>,
229    T: Clone,
230{
231    /// Creates a new HashMapCache that can be used as a RewriteCache
232    pub fn new() -> Self {
233        Self {
234            map: FxHashMap::default(),
235            predecessors: FxHashMap::default(),
236            ancestor_stack: Vec::new(),
237            clean_levels: FxHashMap::default(),
238            _key: PhantomData,
239        }
240    }
241}
242
243impl<T, K> Default for HashMapCache<T, K>
244where
245    K: CacheKey<T>,
246    T: Clone,
247{
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
253impl<T, K> RewriteCache<T> for HashMapCache<T, K>
254where
255    K: CacheKey<T>,
256    T: Clone,
257{
258    fn invalidate_node(&self, node: &T) {
259        K::invalidate(node);
260    }
261
262    fn invalidate_subtree(&self, node: &T) {
263        K::invalidate_subtree(node);
264    }
265
266    fn get(&self, subtree: &T, level: usize) -> CacheResult<T> {
267        let node_hash = K::node_hash(subtree);
268        if let Some(&max_clean) = self.clean_levels.get(&node_hash)
269            && max_clean >= level
270        {
271            return CacheResult::Terminal(max_clean);
272        }
273
274        let hashed = K::combine(node_hash, level);
275
276        match self.map.get(&hashed) {
277            None => CacheResult::Unknown,
278            Some(entry) => match entry {
279                Some(res) => CacheResult::Rewrite(res.clone()),
280                None => CacheResult::Terminal(level),
281            },
282        }
283    }
284
285    fn insert(&mut self, from: &T, to: Option<T>, level: usize) {
286        let node_hash = K::node_hash(from);
287        let from_hash = K::combine(node_hash, level);
288
289        if to.is_none() {
290            self.map.insert(from_hash, None);
291            self.clean_levels
292                .entry(node_hash)
293                .and_modify(|l| *l = (*l).max(level))
294                .or_insert(level);
295            return;
296        }
297
298        let to_hash = K::hash(to.as_ref().unwrap(), level);
299
300        if from_hash == to_hash {
301            panic!("From and To have the same Hash - Cycle Detected!");
302        }
303
304        if self.map.contains_key(&from_hash) {
305            panic!("Overriding an existing mapping loses transitive closure.");
306        }
307
308        // Forward Resolution
309        let to = match self.map.get(&to_hash) {
310            Some(stored) => stored.clone(),
311            None => to,
312        };
313
314        let to_hash = match &to {
315            Some(resolved) => K::hash(resolved, level),
316            None => {
317                self.map.insert(from_hash, None);
318                return;
319            }
320        };
321
322        self.map.insert(from_hash, to.clone());
323
324        if let Some(mut predecessors) = self.predecessors.remove(&from_hash) {
325            for &dependant in &predecessors {
326                self.map.insert(dependant, to.clone());
327            }
328
329            self.predecessors
330                .entry(to_hash)
331                .or_default()
332                .append(&mut predecessors);
333        }
334
335        self.predecessors
336            .entry(to_hash)
337            .or_default()
338            .push(from_hash);
339    }
340
341    fn push_ancestor(&mut self, node: &T) {
342        self.ancestor_stack.push(K::node_hash(node));
343    }
344
345    fn pop_ancestor(&mut self) {
346        self.ancestor_stack.pop();
347    }
348
349    fn pop_and_map_ancestor(&mut self, new_ancestor: &T, level: usize) {
350        if let Some(old_node_hash) = self.ancestor_stack.pop() {
351            let old_key = K::combine(old_node_hash, level);
352            let new_key = K::hash(new_ancestor, level);
353
354            // No change at this ancestor level
355            if old_key == new_key {
356                return;
357            }
358
359            // If old_key has a rewrite mapping, don't override (preserves transitive closure).
360            // But DO override terminal entries
361            if let Some(existing) = self.map.get(&old_key) {
362                if existing.is_some() {
363                    return;
364                }
365                // Remove the stale terminal entry so we can insert the ancestor mapping
366                self.map.remove(&old_key);
367            }
368
369            // Forward resolution
370            let to = match self.map.get(&new_key) {
371                Some(stored) => stored.clone(),
372                None => Some(new_ancestor.clone()),
373            };
374
375            let to_key = match &to {
376                Some(resolved) => K::hash(resolved, level),
377                None => {
378                    self.map.insert(old_key, None);
379                    return;
380                }
381            };
382
383            self.map.insert(old_key, to.clone());
384
385            // Predecessor tracking
386            if let Some(mut preds) = self.predecessors.remove(&old_key) {
387                for &dep in &preds {
388                    self.map.insert(dep, to.clone());
389                }
390                self.predecessors
391                    .entry(to_key)
392                    .or_default()
393                    .append(&mut preds);
394            }
395
396            self.predecessors.entry(to_key).or_default().push(old_key);
397        }
398    }
399}