diff --git a/firewood/src/db.rs b/firewood/src/db.rs index 013f24c8c..b278b564a 100644 --- a/firewood/src/db.rs +++ b/firewood/src/db.rs @@ -320,14 +320,15 @@ impl + Send + Sync> api::DbView for DbRev { impl + Send + Sync> DbRev { pub fn stream(&self) -> merkle::MerkleKeyValueStream<'_, S, Bincode> { - self.merkle.iter(self.header.kv_root) + self.merkle.key_value_iter(self.header.kv_root) } pub fn stream_from( &self, start_key: Box<[u8]>, ) -> merkle::MerkleKeyValueStream<'_, S, Bincode> { - self.merkle.iter_from(self.header.kv_root, start_key) + self.merkle + .key_value_iter_from_key(self.header.kv_root, start_key) } fn flush_dirty(&mut self) -> Option<()> { diff --git a/firewood/src/merkle.rs b/firewood/src/merkle.rs index 25a9aea33..56f89a9a0 100644 --- a/firewood/src/merkle.rs +++ b/firewood/src/merkle.rs @@ -1715,11 +1715,11 @@ impl + Send + Sync, T> Merkle { self.store.flush_dirty() } - pub(crate) fn iter(&self, root: DiskAddress) -> MerkleKeyValueStream<'_, S, T> { + pub(crate) fn key_value_iter(&self, root: DiskAddress) -> MerkleKeyValueStream<'_, S, T> { MerkleKeyValueStream::new(self, root) } - pub(crate) fn iter_from( + pub(crate) fn key_value_iter_from_key( &self, root: DiskAddress, key: Box<[u8]>, @@ -1750,8 +1750,10 @@ impl + Send + Sync, T> Merkle { let mut stream = match first_key { // TODO: fix the call-site to force the caller to do the allocation - Some(key) => self.iter_from(root, key.as_ref().to_vec().into_boxed_slice()), - None => self.iter(root), + Some(key) => { + self.key_value_iter_from_key(root, key.as_ref().to_vec().into_boxed_slice()) + } + None => self.key_value_iter(root), }; // fetch the first key from the stream diff --git a/firewood/src/merkle/stream.rs b/firewood/src/merkle/stream.rs index 6cd2c7f2a..e035ef69b 100644 --- a/firewood/src/merkle/stream.rs +++ b/firewood/src/merkle/stream.rs @@ -7,399 +7,434 @@ use crate::{ shale::{DiskAddress, ShaleStore}, v2::api, }; -use futures::{stream::FusedStream, Stream}; -use helper_types::{Either, MustUse}; +use futures::{stream::FusedStream, Stream, StreamExt}; use std::task::Poll; +use std::{cmp::Ordering, iter::once}; type Key = Box<[u8]>; type Value = Vec; -enum IteratorState<'a> { - /// Start iterating at the specified key - StartAtKey(Key), - /// Continue iterating after the last node in the `visited_node_path` - Iterating { - check_child_nibble: bool, - visited_node_path: Vec<(NodeObjRef<'a>, u8)>, +/// Represents an ongoing iteration over a node and its children. +enum IterationNode<'a> { + /// This node has not been returned yet. + Unvisited { + /// The key (as nibbles) of this node. + key: Key, + node: NodeObjRef<'a>, + }, + /// This node has been returned. Track which child to visit next. + Visited { + /// The key (as nibbles) of this node. + key: Key, + /// Returns the non-empty children of this node and their positions + /// in the node's children array. + children_iter: Box + Send>, }, } -impl IteratorState<'_> { - fn new() -> Self { - Self::StartAtKey(vec![].into_boxed_slice()) - } +enum NodeStreamState<'a> { + /// The iterator state is lazily initialized when poll_next is called + /// for the first time. The iteration start key is stored here. + StartFromKey(Key), + Iterating { + /// Each element is a node that will be visited (i.e. returned) + /// or has been visited but has unvisited children. + /// On each call to poll_next we pop the next element. + /// If it's unvisited, we visit it. + /// If it's visited, we push its next child onto this stack. + iter_stack: Vec>, + }, +} - fn with_key(key: Key) -> Self { - Self::StartAtKey(key) +impl NodeStreamState<'_> { + fn new(key: Key) -> Self { + Self::StartFromKey(key) } } -/// A MerkleKeyValueStream iterates over keys/values for a merkle trie. -pub struct MerkleKeyValueStream<'a, S, T> { - key_state: IteratorState<'a>, +pub struct MerkleNodeStream<'a, S, T> { + state: NodeStreamState<'a>, merkle_root: DiskAddress, merkle: &'a Merkle, } -impl<'a, S: ShaleStore + Send + Sync, T> FusedStream for MerkleKeyValueStream<'a, S, T> { +impl<'a, S: ShaleStore + Send + Sync, T> FusedStream for MerkleNodeStream<'a, S, T> { fn is_terminated(&self) -> bool { - matches!(&self.key_state, IteratorState::Iterating { visited_node_path, .. } if visited_node_path.is_empty()) + // The top of `iter_stack` is the next node to return. + // If `iter_stack` is empty, there are no more nodes to visit. + matches!(&self.state, NodeStreamState::Iterating { iter_stack } if iter_stack.is_empty()) } } -impl<'a, S, T> MerkleKeyValueStream<'a, S, T> { - pub(super) fn new(merkle: &'a Merkle, merkle_root: DiskAddress) -> Self { - let key_state = IteratorState::new(); - +impl<'a, S, T> MerkleNodeStream<'a, S, T> { + /// Returns a new iterator that will iterate over all the nodes in `merkle` + /// with keys greater than or equal to `key`. + pub(super) fn new(merkle: &'a Merkle, merkle_root: DiskAddress, key: Key) -> Self { Self { - merkle, - key_state, + state: NodeStreamState::new(key), merkle_root, - } - } - - pub(super) fn from_key(merkle: &'a Merkle, merkle_root: DiskAddress, key: Key) -> Self { - let key_state = IteratorState::with_key(key); - - Self { merkle, - key_state, - merkle_root, } } } -impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<'a, S, T> { - type Item = Result<(Key, Value), api::Error>; +impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleNodeStream<'a, S, T> { + type Item = Result<(Key, NodeObjRef<'a>), api::Error>; fn poll_next( mut self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, ) -> Poll> { - // destructuring is necessary here because we need mutable access to `key_state` - // at the same time as immutable access to `merkle` + // destructuring is necessary here because we need mutable access to `state` + // at the same time as immutable access to `merkle`. let Self { - key_state, + state, merkle_root, merkle, } = &mut *self; - match key_state { - IteratorState::StartAtKey(key) => { - let root_node = merkle - .get_node(*merkle_root) - .map_err(|e| api::Error::InternalError(Box::new(e)))?; - - let mut check_child_nibble = false; - - // traverse the trie along each nibble until we find a node with a value - // TODO: merkle.iter_by_key(key) will simplify this entire code-block. - let (found_node, mut visited_node_path) = { - let mut visited_node_path = vec![]; - - let found_node = merkle - .get_node_by_key_with_callbacks( - root_node, - &key, - |node_addr, _| visited_node_path.push(node_addr), - |_, _| {}, - ) - .map_err(|e| api::Error::InternalError(Box::new(e)))?; - - let mut nibbles = Nibbles::<1>::new(key).into_iter(); - - let visited_node_path = visited_node_path - .into_iter() - .map(|node| merkle.get_node(node)) - .map(|node_result| { - let nibbles = &mut nibbles; - - node_result - .map(|node| match node.inner() { - NodeType::Branch(branch) => { - let mut partial_path_iter = branch.path.iter(); - let next_nibble = nibbles - .map(|nibble| (Some(nibble), partial_path_iter.next())) - .find(|(a, b)| a.as_ref() != *b); - - match next_nibble { - // this case will be hit by all but the last nodes - // unless there is a deviation between the key and the path - None | Some((None, _)) => None, - - Some((Some(key_nibble), Some(path_nibble))) => { - check_child_nibble = key_nibble < *path_nibble; - None - } - - // path is subset of the key - Some((Some(nibble), None)) => { - check_child_nibble = true; - Some((node, nibble)) - } - } - } - NodeType::Leaf(_) => Some((node, 0)), - NodeType::Extension(_) => Some((node, 0)), - }) - .transpose() - }) - .take_while(|node| node.is_some()) - .flatten() - .collect::, _>>() - .map_err(|e| api::Error::InternalError(Box::new(e)))?; - - (found_node, visited_node_path) - }; - - if let Some(found_node) = found_node { - let value = match found_node.inner() { - NodeType::Branch(branch) => { - check_child_nibble = true; - branch.value.as_ref() + match state { + NodeStreamState::StartFromKey(key) => { + self.state = get_iterator_intial_state(merkle, *merkle_root, key)?; + self.poll_next(_cx) + } + NodeStreamState::Iterating { iter_stack } => { + while let Some(mut iter_node) = iter_stack.pop() { + match iter_node { + IterationNode::Unvisited { key, node } => { + match node.inner() { + NodeType::Branch(branch) => { + // `node` is a branch node. Visit its children next. + iter_stack.push(IterationNode::Visited { + key: key.clone(), + children_iter: Box::new(as_enumerated_children_iter( + branch, + )), + }); + } + NodeType::Leaf(_) => {} + NodeType::Extension(_) => { + unreachable!("extension nodes shouldn't exist") + } + } + + let key = key_from_nibble_iter(key.iter().copied().skip(1)); + return Poll::Ready(Some(Ok((key, node)))); + } + IterationNode::Visited { + ref key, + ref mut children_iter, + } => { + // We returned `node` already. Visit its next child. + let Some((pos, child_addr)) = children_iter.next() else { + // We visited all this node's descendants. Go back to its parent. + continue; + }; + + let child = merkle.get_node(child_addr)?; + + let partial_path = match child.inner() { + NodeType::Branch(branch) => branch.path.iter().copied(), + NodeType::Leaf(leaf) => leaf.path.iter().copied(), + NodeType::Extension(_) => { + unreachable!("extension nodes shouldn't exist") + } + }; + + // The child's key is its parent's key, followed by the child's index, + // followed by the child's partial path (if any). + let child_key: Box<[u8]> = key + .iter() + .copied() + .chain(once(pos)) + .chain(partial_path) + .collect(); + + // There may be more children of this node to visit. + iter_stack.push(iter_node); + + iter_stack.push(IterationNode::Unvisited { + key: child_key, + node: child, + }); + return self.poll_next(_cx); } - NodeType::Leaf(leaf) => Some(&leaf.data), - NodeType::Extension(_) => None, - }; + } + } + Poll::Ready(None) + } + } + } +} - let next_result = value.map(|value| { - let value = value.to_vec(); +/// Returns the initial state for an iterator over the given `merkle` with root `root_node` +/// which starts at `key`. +fn get_iterator_intial_state<'a, S: ShaleStore + Send + Sync, T>( + merkle: &'a Merkle, + root_node: DiskAddress, + key: &[u8], +) -> Result, api::Error> { + // Invariant: `node`'s key is a prefix of `key`. + let mut node = merkle.get_node(root_node)?; - Ok((std::mem::take(key), value)) - }); + // Invariant: [matched_key_nibbles] is the key of `node` at the start + // of each loop iteration. + let mut matched_key_nibbles = vec![]; - visited_node_path.push((found_node, 0)); + let mut unmatched_key_nibbles = Nibbles::<1>::new(key).into_iter(); - self.key_state = IteratorState::Iterating { - check_child_nibble, - visited_node_path, - }; + let mut iter_stack: Vec = vec![]; - return Poll::Ready(next_result); + loop { + // `next_unmatched_key_nibble` is the first nibble after `matched_key_nibbles`. + let Some(next_unmatched_key_nibble) = unmatched_key_nibbles.next() else { + // The invariant tells us `node` is a prefix of `key`. + // There is no more `key` left so `node` must be at `key`. + // Visit and return `node` first. + match &node.inner { + NodeType::Branch(_) | NodeType::Leaf(_) => { + iter_stack.push(IterationNode::Unvisited { + key: Box::from(matched_key_nibbles), + node, + }); } - - let found_key = nibble_iter_from_parents(&visited_node_path); - let found_key = key_from_nibble_iter(found_key); - - if found_key > *key { - check_child_nibble = false; - visited_node_path.pop(); + NodeType::Extension(_) => { + unreachable!("extension nodes shouldn't exist") } + } - self.key_state = IteratorState::Iterating { - check_child_nibble, - visited_node_path, + return Ok(NodeStreamState::Iterating { iter_stack }); + }; + + match &node.inner { + NodeType::Branch(branch) => { + // The next nibble in `key` is `next_unmatched_key_nibble`, + // so all children of `node` with a position > `next_unmatched_key_nibble` + // should be visited since they are after `key`. + iter_stack.push(IterationNode::Visited { + key: matched_key_nibbles.iter().copied().collect(), + children_iter: Box::new( + as_enumerated_children_iter(branch) + .filter(move |(pos, _)| *pos > next_unmatched_key_nibble), + ), + }); + + // Figure out if the child at `next_unmatched_key_nibble` is a prefix of `key`. + // (i.e. if we should run this loop body again) + #[allow(clippy::indexing_slicing)] + let Some(child_addr) = branch.children[next_unmatched_key_nibble as usize] else { + // There is no child at `next_unmatched_key_nibble`. + // We'll visit `node`'s first child at index > `next_unmatched_key_nibble` + // first (if it exists). + return Ok(NodeStreamState::Iterating { iter_stack }); }; - self.poll_next(_cx) - } + matched_key_nibbles.push(next_unmatched_key_nibble); - IteratorState::Iterating { - check_child_nibble, - visited_node_path, - } => { - let next = find_next_result(merkle, visited_node_path, check_child_nibble) - .map_err(|e| api::Error::InternalError(Box::new(e))) - .transpose(); + let child = merkle.get_node(child_addr)?; - Poll::Ready(next) + let partial_key = match child.inner() { + NodeType::Branch(branch) => &branch.path, + NodeType::Leaf(leaf) => &leaf.path, + NodeType::Extension(_) => { + unreachable!("extension nodes shouldn't exist") + } + }; + + let (comparison, new_unmatched_key_nibbles) = + compare_partial_path(partial_key.iter(), unmatched_key_nibbles); + unmatched_key_nibbles = new_unmatched_key_nibbles; + + match comparison { + Ordering::Less => { + // `child` is before `key`. + return Ok(NodeStreamState::Iterating { iter_stack }); + } + Ordering::Equal => { + // `child` is a prefix of `key`. + matched_key_nibbles.extend(partial_key.iter().copied()); + node = child; + } + Ordering::Greater => { + // `child` is after `key`. + let key = matched_key_nibbles + .iter() + .chain(partial_key.iter()) + .copied() + .collect(); + iter_stack.push(IterationNode::Unvisited { key, node: child }); + + return Ok(NodeStreamState::Iterating { iter_stack }); + } + } } - } + NodeType::Leaf(leaf) => { + if compare_partial_path(leaf.path.iter(), unmatched_key_nibbles).0 + == Ordering::Greater + { + // `child` is after `key`. + let key = matched_key_nibbles + .iter() + .chain(leaf.path.iter()) + .copied() + .collect(); + iter_stack.push(IterationNode::Unvisited { key, node }); + } + return Ok(NodeStreamState::Iterating { iter_stack }); + } + NodeType::Extension(_) => { + unreachable!("extension nodes shouldn't exist") + } + }; } } -enum NodeRef<'a> { - New(NodeObjRef<'a>), - Visited(NodeObjRef<'a>), -} - -#[derive(Debug)] -enum InnerNode<'a> { - New(&'a NodeType), - Visited(&'a NodeType), +enum MerkleKeyValueStreamState<'a, S, T> { + /// The iterator state is lazily initialized when poll_next is called + /// for the first time. The iteration start key is stored here. + Uninitialized(Key), + /// The iterator works by iterating over the nodes in the merkle trie + /// and returning the key-value pairs for nodes that have values. + Initialized { + node_iter: MerkleNodeStream<'a, S, T>, + }, } -impl<'a> NodeRef<'a> { - fn inner(&self) -> InnerNode<'_> { - match self { - Self::New(node) => InnerNode::New(node.inner()), - Self::Visited(node) => InnerNode::Visited(node.inner()), - } +impl<'a, S, T> MerkleKeyValueStreamState<'a, S, T> { + /// Returns a new iterator that will iterate over all the key-value pairs in `merkle`. + fn new() -> Self { + Self::Uninitialized(Box::new([])) } - fn into_node(self) -> NodeObjRef<'a> { - match self { - Self::New(node) => node, - Self::Visited(node) => node, - } + /// Returns a new iterator that will iterate over all the key-value pairs in `merkle` + /// with keys greater than or equal to `key`. + fn with_key(key: Key) -> Self { + Self::Uninitialized(key) } } -fn find_next_result<'a, S: ShaleStore, T>( +pub struct MerkleKeyValueStream<'a, S, T> { + state: MerkleKeyValueStreamState<'a, S, T>, + merkle_root: DiskAddress, merkle: &'a Merkle, - visited_path: &mut Vec<(NodeObjRef<'a>, u8)>, - check_child_nibble: &mut bool, -) -> Result, super::MerkleError> { - let next = find_next_node_with_data(merkle, visited_path, *check_child_nibble)?.map( - |(next_node, value)| { - let partial_path = match next_node.inner() { - NodeType::Leaf(leaf) => leaf.path.iter().copied(), - NodeType::Extension(extension) => extension.path.iter().copied(), - NodeType::Branch(branch) => branch.path.iter().copied(), - }; - - // always check the child for branch nodes with data - *check_child_nibble = next_node.inner().is_branch(); - - let key = - key_from_nibble_iter(nibble_iter_from_parents(visited_path).chain(partial_path)); - - visited_path.push((next_node, 0)); - - (key, value) - }, - ); - - Ok(next) } -fn find_next_node_with_data<'a, S: ShaleStore, T>( - merkle: &'a Merkle, - visited_path: &mut Vec<(NodeObjRef<'a>, u8)>, - check_child_nibble: bool, -) -> Result, Vec)>, super::MerkleError> { - use InnerNode::*; - - let Some((visited_parent, visited_pos)) = visited_path.pop() else { - return Ok(None); - }; - - let mut node = NodeRef::Visited(visited_parent); - let mut pos = visited_pos; - let mut first_loop = true; - - loop { - match node.inner() { - New(NodeType::Leaf(leaf)) => { - let value = leaf.data.to_vec(); - return Ok(Some((node.into_node(), value))); - } - - Visited(NodeType::Leaf(_)) | Visited(NodeType::Extension(_)) => { - let Some((next_parent, next_pos)) = visited_path.pop() else { - return Ok(None); - }; - - node = NodeRef::Visited(next_parent); - pos = next_pos; - } - - New(NodeType::Extension(extension)) => { - let child = merkle.get_node(extension.chd())?; - - pos = 0; - visited_path.push((node.into_node(), pos)); +impl<'a, S: ShaleStore + Send + Sync, T> FusedStream for MerkleKeyValueStream<'a, S, T> { + fn is_terminated(&self) -> bool { + matches!(&self.state, MerkleKeyValueStreamState::Initialized { node_iter } if node_iter.is_terminated()) + } +} - node = NodeRef::New(child); - } +impl<'a, S, T> MerkleKeyValueStream<'a, S, T> { + pub(super) fn new(merkle: &'a Merkle, merkle_root: DiskAddress) -> Self { + Self { + state: MerkleKeyValueStreamState::new(), + merkle_root, + merkle, + } + } - Visited(NodeType::Branch(branch)) => { - // if the first node that we check is a visited branch, that means that the branch had a value - // and we need to visit the first child, for all other cases, we need to visit the next child - let compare_op = if first_loop && check_child_nibble { - ::ge // >= - } else { - ::gt - }; + pub(super) fn from_key(merkle: &'a Merkle, merkle_root: DiskAddress, key: Key) -> Self { + Self { + state: MerkleKeyValueStreamState::with_key(key), + merkle_root, + merkle, + } + } +} - let children = get_children_iter(branch) - .filter(move |(_, child_pos)| compare_op(child_pos, &pos)); +impl<'a, S: ShaleStore + Send + Sync, T> Stream for MerkleKeyValueStream<'a, S, T> { + type Item = Result<(Key, Value), api::Error>; - let found_next_node = - next_node(merkle, children, visited_path, &mut node, &mut pos)?; + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + // destructuring is necessary here because we need mutable access to `key_state` + // at the same time as immutable access to `merkle` + let Self { + state, + merkle_root, + merkle, + } = &mut *self; - if !found_next_node { - return Ok(None); - } + match state { + MerkleKeyValueStreamState::Uninitialized(key) => { + let iter = MerkleNodeStream::new(merkle, *merkle_root, key.clone()); + self.state = MerkleKeyValueStreamState::Initialized { node_iter: iter }; + self.poll_next(_cx) } - - New(NodeType::Branch(branch)) => { - if let Some(value) = branch.value.as_ref() { - let value = value.to_vec(); - return Ok(Some((node.into_node(), value))); - } - - let children = get_children_iter(branch); - - let found_next_node = - next_node(merkle, children, visited_path, &mut node, &mut pos)?; - - if !found_next_node { - return Ok(None); + MerkleKeyValueStreamState::Initialized { node_iter: iter } => { + match iter.poll_next_unpin(_cx) { + Poll::Ready(node) => match node { + Some(Ok((key, node))) => match node.inner() { + NodeType::Branch(branch) => { + let Some(value) = branch.value.as_ref() else { + // This node doesn't have a value to return. + // Continue to the next node. + return self.poll_next(_cx); + }; + + let value = value.to_vec(); + Poll::Ready(Some(Ok((key, value)))) + } + NodeType::Leaf(leaf) => { + let value = leaf.data.to_vec(); + Poll::Ready(Some(Ok((key, value)))) + } + NodeType::Extension(_) => { + unreachable!("extension nodes shouldn't exist") + } + }, + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + }, + Poll::Pending => Poll::Pending, } } } - - first_loop = false; } } -fn get_children_iter(branch: &BranchNode) -> impl Iterator { - branch - .children - .into_iter() - .enumerate() - .filter_map(|(pos, child_addr)| child_addr.map(|child_addr| (child_addr, pos as u8))) -} - -/// This function is a little complicated because we need to be able to early return from the parent -/// when we return `false`. `MustUse` forces the caller to check the inner value of `Result::Ok`. -/// It also replaces `node` -fn next_node<'a, S, T, Iter>( - merkle: &'a Merkle, - mut children: Iter, - parents: &mut Vec<(NodeObjRef<'a>, u8)>, - node: &mut NodeRef<'a>, - pos: &mut u8, -) -> Result, super::MerkleError> +/// Takes in an iterator over a node's partial path and an iterator over the +/// unmatched portion of a key. +/// The first returned element is: +/// * [Ordering::Less] if the node is before the key. +/// * [Ordering::Equal] if the node is a prefix of the key. +/// * [Ordering::Greater] if the node is after the key. +/// The second returned element is the unmatched portion of the key after the +/// partial path has been matched. +fn compare_partial_path<'a, I1, I2>( + partial_path_iter: I1, + mut unmatched_key_nibbles_iter: I2, +) -> (Ordering, I2) where - Iter: Iterator, - S: ShaleStore, + I1: Iterator, + I2: Iterator, { - if let Some((child_addr, child_pos)) = children.next() { - let child = merkle.get_node(child_addr)?; - - *pos = child_pos; - let node = std::mem::replace(node, NodeRef::New(child)); - parents.push((node.into_node(), *pos)); - } else { - let Some((next_parent, next_pos)) = parents.pop() else { - return Ok(false.into()); + for next_partial_path_nibble in partial_path_iter { + let Some(next_key_nibble) = unmatched_key_nibbles_iter.next() else { + return (Ordering::Greater, unmatched_key_nibbles_iter); }; - *node = NodeRef::Visited(next_parent); - *pos = next_pos; + match next_partial_path_nibble.cmp(&next_key_nibble) { + Ordering::Less => return (Ordering::Less, unmatched_key_nibbles_iter), + Ordering::Greater => return (Ordering::Greater, unmatched_key_nibbles_iter), + Ordering::Equal => {} + } } - Ok(true.into()) + (Ordering::Equal, unmatched_key_nibbles_iter) } -/// create an iterator over the key-nibbles from all parents _excluding_ the sentinal node. -fn nibble_iter_from_parents<'a>(parents: &'a [(NodeObjRef, u8)]) -> impl Iterator + 'a { - parents - .iter() - .skip(1) // always skip the sentinal node - .flat_map(|(parent, child_nibble)| match parent.inner() { - NodeType::Branch(branch) => Either::Left( - branch - .path - .iter() - .copied() - .chain(std::iter::once(*child_nibble)), - ), - NodeType::Extension(extension) => Either::Right(extension.path.iter().copied()), - NodeType::Leaf(leaf) => Either::Right(leaf.path.iter().copied()), - }) +/// Returns an iterator that returns (`pos`,`child_addr`) for each non-empty child of `branch`, +/// where `pos` is the position of the child in `branch`'s children array. +fn as_enumerated_children_iter(branch: &BranchNode) -> impl Iterator { + branch + .children + .into_iter() + .enumerate() + .filter_map(|(pos, child_addr)| child_addr.map(|child_addr| (pos as u8, child_addr))) } fn key_from_nibble_iter>(mut nibbles: Iter) -> Key { @@ -412,83 +447,214 @@ fn key_from_nibble_iter>(mut nibbles: Iter) -> Key { data.into_boxed_slice() } -mod helper_types { - use std::ops::Not; +#[cfg(test)] +use super::tests::create_test_merkle; - /// Enums enable stack-based dynamic-dispatch as opposed to heap-based `Box`. - /// This helps us with match arms that return different types that implement the same trait. - /// It's possible that [rust-lang/rust#63065](https://github.com/rust-lang/rust/issues/63065) will make this unnecessary. - /// - /// And this can be replaced by the `either` crate from crates.io if we ever need more functionality. - pub(super) enum Either { - Left(T), - Right(U), - } +#[cfg(test)] +#[allow(clippy::indexing_slicing, clippy::unwrap_used)] +mod tests { + use crate::{ + merkle::Bincode, + shale::{cached::DynamicMem, compact::CompactSpace}, + }; - impl Iterator for Either - where - T: Iterator, - U: Iterator, - { - type Item = T::Item; + use super::*; + use futures::StreamExt; + use test_case::test_case; - fn next(&mut self) -> Option { - match self { - Self::Left(left) => left.next(), - Self::Right(right) => right.next(), - } + impl + Send + Sync, T> Merkle { + pub(crate) fn node_iter(&self, root: DiskAddress) -> MerkleNodeStream<'_, S, T> { + MerkleNodeStream::new(self, root, Box::new([])) + } + + pub(crate) fn node_iter_from( + &self, + root: DiskAddress, + key: Box<[u8]>, + ) -> MerkleNodeStream<'_, S, T> { + MerkleNodeStream::new(self, root, key) } } - #[must_use] - pub(super) struct MustUse(T); + #[tokio::test] + async fn key_value_iterate_empty() { + let merkle = create_test_merkle(); + let root = merkle.init_root().unwrap(); + let stream = merkle.key_value_iter_from_key(root, b"x".to_vec().into_boxed_slice()); + check_stream_is_done(stream).await; + } - impl From for MustUse { - fn from(t: T) -> Self { - Self(t) - } + #[tokio::test] + async fn node_iterate_empty() { + let merkle = create_test_merkle(); + let root = merkle.init_root().unwrap(); + let stream = merkle.node_iter(root); + check_stream_is_done(stream).await; } - impl Not for MustUse { - type Output = T::Output; + #[tokio::test] + async fn node_iterate_root_only() { + let mut merkle = create_test_merkle(); + + let root = merkle.init_root().unwrap(); - fn not(self) -> Self::Output { - self.0.not() - } + merkle.insert(vec![0x00], vec![0x00], root).unwrap(); + + let mut stream = merkle.node_iter(root); + + let (key, node) = stream.next().await.unwrap().unwrap(); + + assert_eq!(key, vec![0x00].into_boxed_slice()); + assert_eq!(node.inner().as_leaf().unwrap().data.to_vec(), vec![0x00]); + + check_stream_is_done(stream).await; } -} -// CAUTION: only use with nibble iterators -trait IntoBytes: Iterator { - fn nibbles_into_bytes(&mut self) -> Vec { - let mut data = Vec::with_capacity(self.size_hint().0 / 2); + /// Returns a new [Merkle] with the following structure: + /// sentinel + /// | 0 + /// 00 <-- branch with no value + /// 0/ D| \F + /// 00 0D0 F <-- leaf with no partial path + /// 0/ \F + /// 1 F + /// + /// Note the 0000 branch has no value and the F0F0 + /// The number next to each branch is the position of the child in the branch's children array. + fn created_populated_merkle() -> (Merkle, Bincode>, DiskAddress) + { + let mut merkle = create_test_merkle(); + let root = merkle.init_root().unwrap(); - while let (Some(hi), Some(lo)) = (self.next(), self.next()) { - data.push((hi << 4) + lo); - } + merkle + .insert(vec![0x00, 0x00, 0x00], vec![0x00, 0x00, 0x00], root) + .unwrap(); + merkle + .insert( + vec![0x00, 0x00, 0x00, 0x01], + vec![0x00, 0x00, 0x00, 0x01], + root, + ) + .unwrap(); + merkle + .insert( + vec![0x00, 0x00, 0x00, 0xFF], + vec![0x00, 0x00, 0x00, 0xFF], + root, + ) + .unwrap(); + merkle + .insert(vec![0x00, 0xD0, 0xD0], vec![0x00, 0xD0, 0xD0], root) + .unwrap(); + merkle + .insert(vec![0x00, 0xFF], vec![0x00, 0xFF], root) + .unwrap(); + (merkle, root) + } + + #[tokio::test] + async fn node_iterator_no_start_key() { + let (merkle, root) = created_populated_merkle(); + + let mut stream = merkle.node_iter(root); + + // Covers case of branch with no value + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00].into_boxed_slice()); + let node = node.inner().as_branch().unwrap(); + assert!(node.value.is_none()); + assert_eq!(node.path.to_vec(), vec![0x00, 0x00]); + + // Covers case of branch with value + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0x00, 0x00].into_boxed_slice()); + let node = node.inner().as_branch().unwrap(); + assert_eq!(node.value.clone().unwrap().to_vec(), vec![0x00, 0x00, 0x00]); + assert_eq!(node.path.to_vec(), vec![0x00, 0x00, 0x00]); + + // Covers case of leaf with partial path + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0x00, 0x00, 0x01].into_boxed_slice()); + let node = node.inner().as_leaf().unwrap(); + assert_eq!(node.clone().data.to_vec(), vec![0x00, 0x00, 0x00, 0x01]); + assert_eq!(node.path.to_vec(), vec![0x01]); + + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0x00, 0x00, 0xFF].into_boxed_slice()); + let node = node.inner().as_leaf().unwrap(); + assert_eq!(node.clone().data.to_vec(), vec![0x00, 0x00, 0x00, 0xFF]); + assert_eq!(node.path.to_vec(), vec![0x0F]); + + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0xD0, 0xD0].into_boxed_slice()); + let node = node.inner().as_leaf().unwrap(); + assert_eq!(node.clone().data.to_vec(), vec![0x00, 0xD0, 0xD0]); + assert_eq!(node.path.to_vec(), vec![0x00, 0x0D, 0x00]); // 0x0D00 becomes 0xDO + + // Covers case of leaf with no partial path + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0xFF].into_boxed_slice()); + let node = node.inner().as_leaf().unwrap(); + assert_eq!(node.clone().data.to_vec(), vec![0x00, 0xFF]); + assert_eq!(node.path.to_vec(), vec![0x0F]); - data + check_stream_is_done(stream).await; } -} -impl> IntoBytes for T {} -#[cfg(test)] -use super::tests::create_test_merkle; + #[tokio::test] + async fn node_iterator_start_key_between_nodes() { + let (merkle, root) = created_populated_merkle(); -#[cfg(test)] -#[allow(clippy::indexing_slicing, clippy::unwrap_used)] -mod tests { - use crate::nibbles::Nibbles; + let mut stream = merkle.node_iter_from(root, vec![0x00, 0x00, 0x01].into_boxed_slice()); - use super::*; - use futures::StreamExt; - use test_case::test_case; + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0xD0, 0xD0].into_boxed_slice()); + assert_eq!( + node.inner().as_leaf().unwrap().clone().data.to_vec(), + vec![0x00, 0xD0, 0xD0] + ); + + // Covers case of leaf with no partial path + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0xFF].into_boxed_slice()); + assert_eq!( + node.inner().as_leaf().unwrap().clone().data.to_vec(), + vec![0x00, 0xFF] + ); + + check_stream_is_done(stream).await; + } #[tokio::test] - async fn iterate_empty() { - let merkle = create_test_merkle(); - let root = merkle.init_root().unwrap(); - let stream = merkle.iter_from(root, b"x".to_vec().into_boxed_slice()); + async fn node_iterator_start_key_on_node() { + let (merkle, root) = created_populated_merkle(); + + let mut stream = merkle.node_iter_from(root, vec![0x00, 0xD0, 0xD0].into_boxed_slice()); + + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0xD0, 0xD0].into_boxed_slice()); + assert_eq!( + node.inner().as_leaf().unwrap().clone().data.to_vec(), + vec![0x00, 0xD0, 0xD0] + ); + + // Covers case of leaf with no partial path + let (key, node) = stream.next().await.unwrap().unwrap(); + assert_eq!(key, vec![0x00, 0xFF].into_boxed_slice()); + assert_eq!( + node.inner().as_leaf().unwrap().clone().data.to_vec(), + vec![0x00, 0xFF] + ); + + check_stream_is_done(stream).await; + } + + #[tokio::test] + async fn node_iterator_start_key_after_last_key() { + let (merkle, root) = created_populated_merkle(); + + let stream = merkle.node_iter_from(root, vec![0xFF].into_boxed_slice()); + check_stream_is_done(stream).await; } @@ -497,7 +663,7 @@ mod tests { #[test_case(Some(&[128u8]); "Starting in middle")] #[test_case(Some(&[u8::MAX]); "Starting at last key")] #[tokio::test] - async fn iterate_many(start: Option<&[u8]>) { + async fn key_value_iterate_many(start: Option<&[u8]>) { let mut merkle = create_test_merkle(); let root = merkle.init_root().unwrap(); @@ -507,8 +673,8 @@ mod tests { } let mut stream = match start { - Some(start) => merkle.iter_from(root, start.to_vec().into_boxed_slice()), - None => merkle.iter(root), + Some(start) => merkle.key_value_iter_from_key(root, start.to_vec().into_boxed_slice()), + None => merkle.key_value_iter(root), }; // we iterate twice because we should get a None then start over @@ -527,14 +693,75 @@ mod tests { } #[tokio::test] - async fn fused_empty() { + async fn key_value_fused_empty() { let merkle = create_test_merkle(); let root = merkle.init_root().unwrap(); - check_stream_is_done(merkle.iter(root)).await; + check_stream_is_done(merkle.key_value_iter(root)).await; + } + + #[tokio::test] + async fn key_value_table_test() { + let mut merkle = create_test_merkle(); + let root = merkle.init_root().unwrap(); + + // Insert key-values in reverse order to ensure iterator + // doesn't just return the keys in insertion order. + for i in (0..=u8::MAX).rev() { + for j in (0..=u8::MAX).rev() { + let key = vec![i, j]; + let value = vec![i, j]; + + merkle.insert(key, value, root).unwrap(); + } + } + + // Test with no start key + let mut stream = merkle.key_value_iter(root); + for i in 0..=u8::MAX { + for j in 0..=u8::MAX { + let expected_key = vec![i, j]; + let expected_value = vec![i, j]; + + assert_eq!( + stream.next().await.unwrap().unwrap(), + (expected_key.into_boxed_slice(), expected_value), + "i: {}, j: {}", + i, + j, + ); + } + } + check_stream_is_done(stream).await; + + // Test with start key + for i in 0..=u8::MAX { + let mut stream = merkle.key_value_iter_from_key(root, vec![i].into_boxed_slice()); + for j in 0..=u8::MAX { + let expected_key = vec![i, j]; + let expected_value = vec![i, j]; + assert_eq!( + stream.next().await.unwrap().unwrap(), + (expected_key.into_boxed_slice(), expected_value), + "i: {}, j: {}", + i, + j, + ); + } + if i == u8::MAX { + check_stream_is_done(stream).await; + } else { + assert_eq!( + stream.next().await.unwrap().unwrap(), + (vec![i + 1, 0].into_boxed_slice(), vec![i + 1, 0]), + "i: {}", + i, + ); + } + } } #[tokio::test] - async fn fused_full() { + async fn key_value_fused_full() { let mut merkle = create_test_merkle(); let root = merkle.init_root().unwrap(); @@ -553,7 +780,7 @@ mod tests { merkle.insert(kv, kv.clone(), root).unwrap(); } - let mut stream = merkle.iter(root); + let mut stream = merkle.key_value_iter(root); for kv in key_values.iter() { let next = stream.next().await.unwrap().unwrap(); @@ -565,7 +792,7 @@ mod tests { } #[tokio::test] - async fn root_with_empty_data() { + async fn key_value_root_with_empty_data() { let mut merkle = create_test_merkle(); let root = merkle.init_root().unwrap(); @@ -574,13 +801,13 @@ mod tests { merkle.insert(&key, value.clone(), root).unwrap(); - let mut stream = merkle.iter(root); + let mut stream = merkle.key_value_iter(root); assert_eq!(stream.next().await.unwrap().unwrap(), (key, value)); } #[tokio::test] - async fn get_branch_and_leaf() { + async fn key_value_get_branch_and_leaf() { let mut merkle = create_test_merkle(); let root = merkle.init_root().unwrap(); @@ -597,7 +824,7 @@ mod tests { merkle.insert(branch, branch.to_vec(), root).unwrap(); - let mut stream = merkle.iter(root); + let mut stream = merkle.key_value_iter(root); assert_eq!( stream.next().await.unwrap().unwrap(), @@ -619,7 +846,7 @@ mod tests { } #[tokio::test] - async fn start_at_key_not_in_trie() { + async fn key_value_start_at_key_not_in_trie() { let mut merkle = create_test_merkle(); let root = merkle.init_root().unwrap(); @@ -640,7 +867,8 @@ mod tests { merkle.insert(key, key.to_vec(), root).unwrap(); } - let mut stream = merkle.iter_from(root, vec![intermediate].into_boxed_slice()); + let mut stream = + merkle.key_value_iter_from_key(root, vec![intermediate].into_boxed_slice()); let first_expected = key_values[1].as_slice(); let first = stream.next().await.unwrap().unwrap(); @@ -658,7 +886,7 @@ mod tests { } #[tokio::test] - async fn start_at_key_on_branch_with_no_value() { + async fn key_value_start_at_key_on_branch_with_no_value() { let sibling_path = 0x00; let branch_path = 0x0f; let children = 0..=0x0f; @@ -687,7 +915,7 @@ mod tests { let start = keys.iter().position(|key| key[0] == branch_path).unwrap(); let keys = &keys[start..]; - let mut stream = merkle.iter_from(root, vec![branch_path].into_boxed_slice()); + let mut stream = merkle.key_value_iter_from_key(root, vec![branch_path].into_boxed_slice()); for key in keys { let next = stream.next().await.unwrap().unwrap(); @@ -700,7 +928,7 @@ mod tests { } #[tokio::test] - async fn start_at_key_on_branch_with_value() { + async fn key_value_start_at_key_on_branch_with_value() { let sibling_path = 0x00; let branch_path = 0x0f; let branch_key = vec![branch_path]; @@ -736,7 +964,7 @@ mod tests { let start = keys.iter().position(|key| key == &branch_key).unwrap(); let keys = &keys[start..]; - let mut stream = merkle.iter_from(root, branch_key.into_boxed_slice()); + let mut stream = merkle.key_value_iter_from_key(root, branch_key.into_boxed_slice()); for key in keys { let next = stream.next().await.unwrap().unwrap(); @@ -749,7 +977,7 @@ mod tests { } #[tokio::test] - async fn start_at_key_on_extension() { + async fn key_value_start_at_key_on_extension() { let missing = 0x0a; let children = (0..=0x0f).filter(|x| *x != missing); let mut merkle = create_test_merkle(); @@ -767,7 +995,7 @@ mod tests { let keys = &keys[(missing as usize)..]; - let mut stream = merkle.iter_from(root, vec![missing].into_boxed_slice()); + let mut stream = merkle.key_value_iter_from_key(root, vec![missing].into_boxed_slice()); for key in keys { let next = stream.next().await.unwrap().unwrap(); @@ -780,7 +1008,7 @@ mod tests { } #[tokio::test] - async fn start_at_key_overlapping_with_extension_but_greater() { + async fn key_value_start_at_key_overlapping_with_extension_but_greater() { let start_key = 0x0a; let shared_path = 0x09; // 0x0900, 0x0901, ... 0x0a0f @@ -794,13 +1022,13 @@ mod tests { merkle.insert(&key, key.clone(), root).unwrap(); }); - let stream = merkle.iter_from(root, vec![start_key].into_boxed_slice()); + let stream = merkle.key_value_iter_from_key(root, vec![start_key].into_boxed_slice()); check_stream_is_done(stream).await; } #[tokio::test] - async fn start_at_key_overlapping_with_extension_but_smaller() { + async fn key_value_start_at_key_overlapping_with_extension_but_smaller() { let start_key = 0x00; let shared_path = 0x09; // 0x0900, 0x0901, ... 0x0a0f @@ -817,7 +1045,7 @@ mod tests { }) .collect(); - let mut stream = merkle.iter_from(root, vec![start_key].into_boxed_slice()); + let mut stream = merkle.key_value_iter_from_key(root, vec![start_key].into_boxed_slice()); for key in keys { let next = stream.next().await.unwrap().unwrap(); @@ -830,7 +1058,7 @@ mod tests { } #[tokio::test] - async fn start_at_key_between_siblings() { + async fn key_value_start_at_key_between_siblings() { let missing = 0xaa; let children = (0..=0xf) .map(|val| (val << 4) + val) // 0x00, 0x11, ... 0xff @@ -850,7 +1078,7 @@ mod tests { let keys = &keys[((missing >> 4) as usize)..]; - let mut stream = merkle.iter_from(root, vec![missing].into_boxed_slice()); + let mut stream = merkle.key_value_iter_from_key(root, vec![missing].into_boxed_slice()); for key in keys { let next = stream.next().await.unwrap().unwrap(); @@ -863,19 +1091,19 @@ mod tests { } #[tokio::test] - async fn start_at_key_greater_than_all_others_leaf() { + async fn key_value_start_at_key_greater_than_all_others_leaf() { let key = vec![0x00]; let greater_key = vec![0xff]; let mut merkle = create_test_merkle(); let root = merkle.init_root().unwrap(); merkle.insert(key.clone(), key, root).unwrap(); - let stream = merkle.iter_from(root, greater_key.into_boxed_slice()); + let stream = merkle.key_value_iter_from_key(root, greater_key.into_boxed_slice()); check_stream_is_done(stream).await; } #[tokio::test] - async fn start_at_key_greater_than_all_others_branch() { + async fn key_value_start_at_key_greater_than_all_others_branch() { let greatest = 0xff; let children = (0..=0xf) .map(|val| (val << 4) + val) // 0x00, 0x11, ... 0xff @@ -895,7 +1123,7 @@ mod tests { let keys = &keys[((greatest >> 4) as usize)..]; - let mut stream = merkle.iter_from(root, vec![greatest].into_boxed_slice()); + let mut stream = merkle.key_value_iter_from_key(root, vec![greatest].into_boxed_slice()); for key in keys { let next = stream.next().await.unwrap().unwrap(); @@ -914,21 +1142,4 @@ mod tests { assert!(stream.next().await.is_none()); assert!(stream.is_terminated()); } - - #[test] - fn remaining_bytes() { - let data = &[1]; - let nib: Nibbles<'_, 0> = Nibbles::<0>::new(data); - let mut it = nib.into_iter(); - assert_eq!(it.nibbles_into_bytes(), data.to_vec()); - } - - #[test] - fn remaining_bytes_off() { - let data = &[1]; - let nib: Nibbles<'_, 0> = Nibbles::<0>::new(data); - let mut it = nib.into_iter(); - it.next(); - assert_eq!(it.nibbles_into_bytes(), vec![]); - } } diff --git a/firewood/src/v2/api.rs b/firewood/src/v2/api.rs index 866ccadd6..a6a413d26 100644 --- a/firewood/src/v2/api.rs +++ b/firewood/src/v2/api.rs @@ -1,6 +1,7 @@ // Copyright (C) 2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE.md for licensing terms. +use crate::merkle::MerkleError; pub use crate::merkle::Proof; use async_trait::async_trait; use std::{fmt::Debug, sync::Arc}; @@ -84,6 +85,13 @@ pub enum Error { RangeTooSmall, } +impl From for Error { + fn from(err: MerkleError) -> Self { + // TODO: do a better job + Error::InternalError(Box::new(err)) + } +} + /// A range proof, consisting of a proof of the first key and the last key, /// and a vector of all key/value pairs #[derive(Debug)]