1use std::{
6 hash::{DefaultHasher, Hash, Hasher},
7 marker::PhantomData,
8};
9
10use fxhash::FxHashMap;
11
/// Outcome of looking up a subtree in a [`RewriteCache`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheResult<T> {
    /// Nothing is cached for this subtree at the queried level.
    Unknown,

    /// No rewrite applies: the subtree is known to be terminal. The payload is a level, at least
    /// the one queried, at which the subtree was recorded as terminal.
    Terminal(usize),

    /// The subtree is known to rewrite to the contained value.
    Rewrite(T),
}
25
26pub trait RewriteCache<T> {
34 fn get(&self, subtree: &T, level: usize) -> CacheResult<T>;
36
37 fn insert(&mut self, from: &T, to: Option<T>, level: usize);
41
42 fn invalidate_node(&self, _node: &T) {}
46
47 fn invalidate_subtree(&self, _node: &T) {}
50
51 fn is_active(&self) -> bool {
54 true
55 }
56
57 fn push_ancestor(&mut self, _node: &T) {}
60
61 fn pop_ancestor(&mut self) {}
64
65 fn pop_and_map_ancestor(&mut self, _new_ancestor: &T, _level: usize) {}
69}
70
71impl<T> RewriteCache<T> for Box<dyn RewriteCache<T>> {
72 fn get(&self, subtree: &T, level: usize) -> CacheResult<T> {
73 (**self).get(subtree, level)
74 }
75
76 fn insert(&mut self, from: &T, to: Option<T>, level: usize) {
77 (**self).insert(from, to, level)
78 }
79
80 fn invalidate_node(&self, node: &T) {
81 (**self).invalidate_node(node)
82 }
83
84 fn invalidate_subtree(&self, node: &T) {
85 (**self).invalidate_subtree(node)
86 }
87
88 fn is_active(&self) -> bool {
89 (**self).is_active()
90 }
91
92 fn push_ancestor(&mut self, node: &T) {
93 (**self).push_ancestor(node)
94 }
95
96 fn pop_ancestor(&mut self) {
97 (**self).pop_ancestor()
98 }
99
100 fn pop_and_map_ancestor(&mut self, new_ancestor: &T, level: usize) {
101 (**self).pop_and_map_ancestor(new_ancestor, level)
102 }
103}
104
105pub struct NoCache;
109impl<T> RewriteCache<T> for NoCache {
110 fn get(&self, _: &T, _: usize) -> CacheResult<T> {
111 CacheResult::Unknown
112 }
113
114 fn insert(&mut self, _: &T, _: Option<T>, _: usize) {}
115
116 fn is_active(&self) -> bool {
117 false
118 }
119}
120
/// Strategy for turning a term, optionally paired with a level, into a `u64` cache key.
pub trait CacheKey<T> {
    /// Key identifying `term` alone, independent of any level.
    fn node_hash(term: &T) -> u64;

    /// Mixes a node key with a level into one key using a fresh [`DefaultHasher`].
    ///
    /// The two writes are the same ones `Hash::hash` would issue for a `u64` and a `usize`.
    fn combine(node_hash: u64, level: usize) -> u64 {
        let mut state = DefaultHasher::new();
        state.write_u64(node_hash);
        state.write_usize(level);
        state.finish()
    }

    /// Key for `term` at `level`: `combine(node_hash(term), level)`.
    fn hash(term: &T, level: usize) -> u64 {
        let node = Self::node_hash(term);
        Self::combine(node, level)
    }

    /// Hook for dropping any per-node state behind `node_hash`. Defaults to a no-op.
    fn invalidate(_node: &T) {}

    /// Hook for dropping per-node state for `_node` and its descendants. Defaults to a no-op.
    fn invalidate_subtree(_node: &T) {}
}
149
150pub struct StdHashKey;
152
153impl<T: Hash> CacheKey<T> for StdHashKey {
154 fn node_hash(term: &T) -> u64 {
155 let mut hasher = DefaultHasher::new();
156 term.hash(&mut hasher);
157 hasher.finish()
158 }
159}
160
/// Terms that can supply, and forget, a hash of themselves.
///
/// How (or whether) the hash is memoized is up to the implementor.
pub trait CacheHashable {
    /// Computes this term's hash.
    fn calculate_hash(&self) -> u64;

    /// Returns this term's hash, possibly served from a memoized value.
    fn get_cached_hash(&self) -> u64;

    /// Forgets any memoized hash held by this term.
    fn invalidate_cache(&self);

    /// Forgets memoized hashes held by this term and, per the implementor, its sub-terms.
    fn invalidate_cache_recursive(&self);
}
188
189pub struct CachedHashKey;
192
193impl<T: CacheHashable> CacheKey<T> for CachedHashKey {
194 fn node_hash(term: &T) -> u64 {
195 term.get_cached_hash()
196 }
197
198 fn invalidate(node: &T) {
199 node.invalidate_cache();
200 }
201
202 fn invalidate_subtree(node: &T) {
203 node.invalidate_cache_recursive();
204 }
205}
206
207pub struct HashMapCache<T, K = StdHashKey>
212where
213 K: CacheKey<T>,
214 T: Clone,
215{
216 map: FxHashMap<u64, Option<T>>,
217 predecessors: FxHashMap<u64, Vec<u64>>,
218 ancestor_stack: Vec<u64>,
219 clean_levels: FxHashMap<u64, usize>,
220 _key: PhantomData<K>,
221}
222
223pub type CachedHashMapCache<T> = HashMapCache<T, CachedHashKey>;
225
226impl<T, K> HashMapCache<T, K>
227where
228 K: CacheKey<T>,
229 T: Clone,
230{
231 pub fn new() -> Self {
233 Self {
234 map: FxHashMap::default(),
235 predecessors: FxHashMap::default(),
236 ancestor_stack: Vec::new(),
237 clean_levels: FxHashMap::default(),
238 _key: PhantomData,
239 }
240 }
241}
242
243impl<T, K> Default for HashMapCache<T, K>
244where
245 K: CacheKey<T>,
246 T: Clone,
247{
248 fn default() -> Self {
249 Self::new()
250 }
251}
252
253impl<T, K> RewriteCache<T> for HashMapCache<T, K>
254where
255 K: CacheKey<T>,
256 T: Clone,
257{
258 fn invalidate_node(&self, node: &T) {
259 K::invalidate(node);
260 }
261
262 fn invalidate_subtree(&self, node: &T) {
263 K::invalidate_subtree(node);
264 }
265
266 fn get(&self, subtree: &T, level: usize) -> CacheResult<T> {
267 let node_hash = K::node_hash(subtree);
268 if let Some(&max_clean) = self.clean_levels.get(&node_hash)
269 && max_clean >= level
270 {
271 return CacheResult::Terminal(max_clean);
272 }
273
274 let hashed = K::combine(node_hash, level);
275
276 match self.map.get(&hashed) {
277 None => CacheResult::Unknown,
278 Some(entry) => match entry {
279 Some(res) => CacheResult::Rewrite(res.clone()),
280 None => CacheResult::Terminal(level),
281 },
282 }
283 }
284
285 fn insert(&mut self, from: &T, to: Option<T>, level: usize) {
286 let node_hash = K::node_hash(from);
287 let from_hash = K::combine(node_hash, level);
288
289 if to.is_none() {
290 self.map.insert(from_hash, None);
291 self.clean_levels
292 .entry(node_hash)
293 .and_modify(|l| *l = (*l).max(level))
294 .or_insert(level);
295 return;
296 }
297
298 let to_hash = K::hash(to.as_ref().unwrap(), level);
299
300 if from_hash == to_hash {
301 panic!("From and To have the same Hash - Cycle Detected!");
302 }
303
304 if self.map.contains_key(&from_hash) {
305 panic!("Overriding an existing mapping loses transitive closure.");
306 }
307
308 let to = match self.map.get(&to_hash) {
310 Some(stored) => stored.clone(),
311 None => to,
312 };
313
314 let to_hash = match &to {
315 Some(resolved) => K::hash(resolved, level),
316 None => {
317 self.map.insert(from_hash, None);
318 return;
319 }
320 };
321
322 self.map.insert(from_hash, to.clone());
323
324 if let Some(mut predecessors) = self.predecessors.remove(&from_hash) {
325 for &dependant in &predecessors {
326 self.map.insert(dependant, to.clone());
327 }
328
329 self.predecessors
330 .entry(to_hash)
331 .or_default()
332 .append(&mut predecessors);
333 }
334
335 self.predecessors
336 .entry(to_hash)
337 .or_default()
338 .push(from_hash);
339 }
340
341 fn push_ancestor(&mut self, node: &T) {
342 self.ancestor_stack.push(K::node_hash(node));
343 }
344
345 fn pop_ancestor(&mut self) {
346 self.ancestor_stack.pop();
347 }
348
349 fn pop_and_map_ancestor(&mut self, new_ancestor: &T, level: usize) {
350 if let Some(old_node_hash) = self.ancestor_stack.pop() {
351 let old_key = K::combine(old_node_hash, level);
352 let new_key = K::hash(new_ancestor, level);
353
354 if old_key == new_key {
356 return;
357 }
358
359 if let Some(existing) = self.map.get(&old_key) {
362 if existing.is_some() {
363 return;
364 }
365 self.map.remove(&old_key);
367 }
368
369 let to = match self.map.get(&new_key) {
371 Some(stored) => stored.clone(),
372 None => Some(new_ancestor.clone()),
373 };
374
375 let to_key = match &to {
376 Some(resolved) => K::hash(resolved, level),
377 None => {
378 self.map.insert(old_key, None);
379 return;
380 }
381 };
382
383 self.map.insert(old_key, to.clone());
384
385 if let Some(mut preds) = self.predecessors.remove(&old_key) {
387 for &dep in &preds {
388 self.map.insert(dep, to.clone());
389 }
390 self.predecessors
391 .entry(to_key)
392 .or_default()
393 .append(&mut preds);
394 }
395
396 self.predecessors.entry(to_key).or_default().push(old_key);
397 }
398 }
399}