Skip to content

Commit

Permalink
Merge pull request #511 from Chia-Network/object-cache-allocator
Browse files Browse the repository at this point in the history
ObjectCache doesn't need the Allocator as member
  • Loading branch information
arvidn authored Dec 14, 2024
2 parents 734207b + 509eabd commit 210d7e5
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
4 changes: 2 additions & 2 deletions src/serde/de_br.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ mod tests {
let buf = Vec::from_hex(serialization_as_hex).unwrap();
let mut allocator = Allocator::new();
let node = node_from_bytes_backrefs(&mut allocator, &buf).unwrap();
let mut oc = ObjectCache::new(&allocator, treehash);
let calculated_hash = oc.get_or_calculate(&node).unwrap();
let mut oc = ObjectCache::new(treehash);
let calculated_hash = oc.get_or_calculate(&allocator, &node).unwrap();
let ch: &[u8] = calculated_hash;
let expected_hash: Vec<u8> = Vec::from_hex(expected_hash_as_hex).unwrap();
assert_eq!(expected_hash, ch);
Expand Down
48 changes: 29 additions & 19 deletions src/serde/object_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ use std::collections::HashMap;
type CachedFunction<T> = fn(&mut ObjectCache<T>, &Allocator, NodePtr) -> Option<T>;
use super::bytes32::{hash_blobs, Bytes32};

pub struct ObjectCache<'a, T> {
pub struct ObjectCache<T> {
cache: HashMap<NodePtr, T>,
allocator: &'a Allocator,

/// The function `f` is expected to calculate its T value recursively based
/// on the T values for the left and right child for a pair. For an atom, the
Expand All @@ -25,19 +24,18 @@ pub struct ObjectCache<'a, T> {
f: CachedFunction<T>,
}

impl<'a, T: Clone> ObjectCache<'a, T> {
pub fn new(allocator: &'a Allocator, f: CachedFunction<T>) -> Self {
impl<T: Clone> ObjectCache<T> {
pub fn new(f: CachedFunction<T>) -> Self {
Self {
cache: HashMap::new(),
allocator,
f,
}
}

/// return the function value for this node, either from cache
/// or by calculating it
pub fn get_or_calculate(&mut self, node: &NodePtr) -> Option<&T> {
self.calculate(node);
pub fn get_or_calculate(&mut self, allocator: &Allocator, node: &NodePtr) -> Option<&T> {
self.calculate(allocator, node);
self.get_from_cache(node)
}

Expand All @@ -53,14 +51,14 @@ impl<'a, T: Clone> ObjectCache<'a, T> {

/// calculate the function's value for the given node, traversing uncached children
/// as necessary
fn calculate(&mut self, root_node: &NodePtr) {
fn calculate(&mut self, allocator: &Allocator, root_node: &NodePtr) {
let mut obj_list = vec![*root_node];
while let Some(node) = obj_list.pop() {
let v = self.get_from_cache(&node);
match v {
Some(_) => {}
None => match (self.f)(self, self.allocator, node) {
None => match self.allocator.sexp(node) {
None => match (self.f)(self, allocator, node) {
None => match allocator.sexp(node) {
SExp::Pair(left, right) => {
obj_list.push(node);
obj_list.push(left);
Expand Down Expand Up @@ -166,21 +164,27 @@ mod tests {
let blob: Vec<u8> = Vec::from_hex(obj_as_hex).unwrap();
let mut cursor: Cursor<&[u8]> = Cursor::new(&blob);
let obj = node_from_stream(&mut allocator, &mut cursor).unwrap();
let mut oc = ObjectCache::new(&allocator, f);
let mut oc = ObjectCache::new(f);

assert_eq!(oc.get_from_cache(&obj), None);

oc.calculate(&obj);
oc.calculate(&allocator, &obj);

assert_eq!(oc.get_from_cache(&obj), Some(&expected_value));

assert_eq!(oc.get_or_calculate(&obj).unwrap().clone(), expected_value);
assert_eq!(
oc.get_or_calculate(&allocator, &obj).unwrap().clone(),
expected_value
);

assert_eq!(oc.get_from_cache(&obj), Some(&expected_value));

// do it again, but the simple way
let mut oc = ObjectCache::new(&allocator, f);
assert_eq!(oc.get_or_calculate(&obj).unwrap().clone(), expected_value);
let mut oc = ObjectCache::new(f);
assert_eq!(
oc.get_or_calculate(&allocator, &obj).unwrap().clone(),
expected_value
);
}

#[test]
Expand Down Expand Up @@ -241,14 +245,20 @@ mod tests {
}

let expected_value = LIST_SIZE * 2 + 1;
let mut oc = ObjectCache::new(&allocator, serialized_length);
assert_eq!(oc.get_or_calculate(&top).unwrap().clone(), expected_value);
let mut oc = ObjectCache::new(serialized_length);
assert_eq!(
oc.get_or_calculate(&allocator, &top).unwrap().clone(),
expected_value
);

let expected_value = <[u8; 32]>::from_hex(
"a168fce695099a30c0745075e6db3722ed7f059e0d7cc4d7e7504e215db5017b",
)
.unwrap();
let mut oc = ObjectCache::new(&allocator, treehash);
assert_eq!(oc.get_or_calculate(&top).unwrap().clone(), expected_value);
let mut oc = ObjectCache::new(treehash);
assert_eq!(
oc.get_or_calculate(&allocator, &top).unwrap().clone(),
expected_value
);
}
}
8 changes: 4 additions & 4 deletions src/serde/ser_br.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@ pub fn node_to_stream_backrefs<W: io::Write>(

let mut read_cache_lookup = ReadCacheLookup::new();

let mut thc = ObjectCache::new(allocator, treehash);
let mut slc = ObjectCache::new(allocator, serialized_length);
let mut thc = ObjectCache::new(treehash);
let mut slc = ObjectCache::new(serialized_length);

while let Some(node_to_write) = write_stack.pop() {
let op = read_op_stack.pop();
assert!(op == Some(ReadOp::Parse));

let node_serialized_length = *slc
.get_or_calculate(&node_to_write)
.get_or_calculate(allocator, &node_to_write)
.expect("couldn't calculate serialized length");
let node_tree_hash = thc
.get_or_calculate(&node_to_write)
.get_or_calculate(allocator, &node_to_write)
.expect("can't get treehash");
match read_cache_lookup.find_path(node_tree_hash, node_serialized_length) {
Some(path) => {
Expand Down

0 comments on commit 210d7e5

Please sign in to comment.