From 509eabd0205e354a3c916bd80bf7d6a64fb1e581 Mon Sep 17 00:00:00 2001 From: Arvid Norberg Date: Fri, 13 Dec 2024 10:46:27 +0100 Subject: [PATCH] ObjectCache doesn't need the Allocator as member. Passing it in on-demand is more flexible from a borrow-checker point of view --- src/serde/de_br.rs | 4 ++-- src/serde/object_cache.rs | 48 +++++++++++++++++++++++---------------- src/serde/ser_br.rs | 8 +++---- 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/serde/de_br.rs b/src/serde/de_br.rs index ecb74506..2277dd9e 100644 --- a/src/serde/de_br.rs +++ b/src/serde/de_br.rs @@ -126,8 +126,8 @@ mod tests { let buf = Vec::from_hex(serialization_as_hex).unwrap(); let mut allocator = Allocator::new(); let node = node_from_bytes_backrefs(&mut allocator, &buf).unwrap(); - let mut oc = ObjectCache::new(&allocator, treehash); - let calculated_hash = oc.get_or_calculate(&node).unwrap(); + let mut oc = ObjectCache::new(treehash); + let calculated_hash = oc.get_or_calculate(&allocator, &node).unwrap(); let ch: &[u8] = calculated_hash; let expected_hash: Vec = Vec::from_hex(expected_hash_as_hex).unwrap(); assert_eq!(expected_hash, ch); diff --git a/src/serde/object_cache.rs b/src/serde/object_cache.rs index 197c3cf6..72fc78ba 100644 --- a/src/serde/object_cache.rs +++ b/src/serde/object_cache.rs @@ -11,9 +11,8 @@ use std::collections::HashMap; type CachedFunction = fn(&mut ObjectCache, &Allocator, NodePtr) -> Option; use super::bytes32::{hash_blobs, Bytes32}; -pub struct ObjectCache<'a, T> { +pub struct ObjectCache { cache: HashMap, - allocator: &'a Allocator, /// The function `f` is expected to calculate its T value recursively based /// on the T values for the left and right child for a pair. For an atom, the @@ -25,19 +24,18 @@ pub struct ObjectCache<'a, T> { f: CachedFunction, } -impl<'a, T: Clone> ObjectCache<'a, T> { - pub fn new(allocator: &'a Allocator, f: CachedFunction) -> Self { +impl ObjectCache { + pub fn new(f: CachedFunction) -> Self { Self { cache: HashMap::new(), - allocator, f, } } /// return the function value for this node, either from cache /// or by calculating it - pub fn get_or_calculate(&mut self, node: &NodePtr) -> Option<&T> { - self.calculate(node); + pub fn get_or_calculate(&mut self, allocator: &Allocator, node: &NodePtr) -> Option<&T> { + self.calculate(allocator, node); self.get_from_cache(node) } @@ -53,14 +51,14 @@ impl<'a, T: Clone> ObjectCache<'a, T> { /// calculate the function's value for the given node, traversing uncached children /// as necessary - fn calculate(&mut self, root_node: &NodePtr) { + fn calculate(&mut self, allocator: &Allocator, root_node: &NodePtr) { let mut obj_list = vec![*root_node]; while let Some(node) = obj_list.pop() { let v = self.get_from_cache(&node); match v { Some(_) => {} - None => match (self.f)(self, self.allocator, node) { - None => match self.allocator.sexp(node) { + None => match (self.f)(self, allocator, node) { + None => match allocator.sexp(node) { SExp::Pair(left, right) => { obj_list.push(node); obj_list.push(left); @@ -166,21 +164,27 @@ mod tests { let blob: Vec = Vec::from_hex(obj_as_hex).unwrap(); let mut cursor: Cursor<&[u8]> = Cursor::new(&blob); let obj = node_from_stream(&mut allocator, &mut cursor).unwrap(); - let mut oc = ObjectCache::new(&allocator, f); + let mut oc = ObjectCache::new(f); assert_eq!(oc.get_from_cache(&obj), None); - oc.calculate(&obj); + oc.calculate(&allocator, &obj); assert_eq!(oc.get_from_cache(&obj), Some(&expected_value)); - assert_eq!(oc.get_or_calculate(&obj).unwrap().clone(), expected_value); + assert_eq!( + oc.get_or_calculate(&allocator, &obj).unwrap().clone(), + expected_value + ); assert_eq!(oc.get_from_cache(&obj), Some(&expected_value)); // do it again, but the simple way - let mut oc = ObjectCache::new(&allocator, f); - assert_eq!(oc.get_or_calculate(&obj).unwrap().clone(), expected_value); + let mut oc = ObjectCache::new(f); + assert_eq!( + oc.get_or_calculate(&allocator, &obj).unwrap().clone(), + expected_value + ); } #[test] @@ -241,14 +245,20 @@ mod tests { } let expected_value = LIST_SIZE * 2 + 1; - let mut oc = ObjectCache::new(&allocator, serialized_length); - assert_eq!(oc.get_or_calculate(&top).unwrap().clone(), expected_value); + let mut oc = ObjectCache::new(serialized_length); + assert_eq!( + oc.get_or_calculate(&allocator, &top).unwrap().clone(), + expected_value + ); let expected_value = <[u8; 32]>::from_hex( "a168fce695099a30c0745075e6db3722ed7f059e0d7cc4d7e7504e215db5017b", ) .unwrap(); - let mut oc = ObjectCache::new(&allocator, treehash); - assert_eq!(oc.get_or_calculate(&top).unwrap().clone(), expected_value); + let mut oc = ObjectCache::new(treehash); + assert_eq!( + oc.get_or_calculate(&allocator, &top).unwrap().clone(), + expected_value + ); } } diff --git a/src/serde/ser_br.rs b/src/serde/ser_br.rs index 9bc63e04..d92f0a54 100644 --- a/src/serde/ser_br.rs +++ b/src/serde/ser_br.rs @@ -28,18 +28,18 @@ pub fn node_to_stream_backrefs( let mut read_cache_lookup = ReadCacheLookup::new(); - let mut thc = ObjectCache::new(allocator, treehash); - let mut slc = ObjectCache::new(allocator, serialized_length); + let mut thc = ObjectCache::new(treehash); + let mut slc = ObjectCache::new(serialized_length); while let Some(node_to_write) = write_stack.pop() { let op = read_op_stack.pop(); assert!(op == Some(ReadOp::Parse)); let node_serialized_length = *slc - .get_or_calculate(&node_to_write) + .get_or_calculate(allocator, &node_to_write) .expect("couldn't calculate serialized length"); let node_tree_hash = thc - .get_or_calculate(&node_to_write) + .get_or_calculate(allocator, &node_to_write) .expect("can't get treehash"); match read_cache_lookup.find_path(node_tree_hash, node_serialized_length) { Some(path) => {