diff --git a/core/store/src/trie/iterator.rs b/core/store/src/trie/iterator.rs index 16cdf8047ae..f1de3244142 100644 --- a/core/store/src/trie/iterator.rs +++ b/core/store/src/trie/iterator.rs @@ -143,7 +143,9 @@ impl<'a> TrieIterator<'a> { TrieNode::Empty => break, TrieNode::Leaf(leaf_key, _) => { let existing_key = NibbleSlice::from_encoded(leaf_key).0; + println!("see leaf {:?} {:?}", leaf_key, existing_key); if !check_ext_key(&key, &existing_key) { + println!("yes"); self.key_nibbles.extend(existing_key.iter()); *status = CrumbStatus::Exiting; } @@ -356,6 +358,9 @@ impl<'a> TrieIterator<'a> { } IterStep::Continue => {} IterStep::Value(hash) => { + if self.key_nibbles[prefix..] >= path_end[prefix..] { + break; + } self.trie.storage.retrieve_raw_bytes(&hash)?; nodes_list.push(TrieTraversalItem { hash, @@ -368,6 +373,7 @@ impl<'a> TrieIterator<'a> { } } +#[derive(Debug)] enum IterStep { Continue, PopTrail, diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs index 891a6d7a882..b10ef1085f9 100644 --- a/core/store/src/trie/mod.rs +++ b/core/store/src/trie/mod.rs @@ -23,6 +23,7 @@ pub use raw_node::{Children, RawTrieNode, RawTrieNodeWithSize}; use std::cell::RefCell; use std::collections::HashMap; use std::fmt::Write; +use std::hash::{Hash, Hasher}; use std::rc::Rc; use std::str; @@ -353,6 +354,12 @@ impl TrieRefcountChange { } } +impl Hash for TrieRefcountChange { + fn hash(&self, state: &mut H) { + state.write(&self.trie_node_or_value_hash.0); + state.write_u32(self.rc.into()); + } +} /// /// TrieChanges stores delta for refcount. /// Multiple versions of the state work the following way: diff --git a/core/store/src/trie/state_parts.rs b/core/store/src/trie/state_parts.rs index 41295a17b2a..4053907e0f7 100644 --- a/core/store/src/trie/state_parts.rs +++ b/core/store/src/trie/state_parts.rs @@ -68,7 +68,7 @@ impl Trie { } let root_node = self.retrieve_node(&self.root)?.1; let total_size = root_node.memory_usage; - let size_start = (total_size + num_parts - 1) / num_parts * part_id; + let size_start = total_size / num_parts * part_id + part_id.min(total_size % num_parts); self.find_node_in_dfs_order(&root_node, size_start) } @@ -115,7 +115,7 @@ impl Trie { /// Tells a child in which we should go to find a node in dfs order corresponding /// to `memory_threshold`. /// Accumulates `memory_skipped` as memory used by all skipped nodes. - /// Returns false if we already find desired node and should stop the process. + /// Returns false if we already found desired node and should stop the process. fn find_child_in_dfs_order( &self, memory_threshold: u64, @@ -123,21 +123,23 @@ impl Trie { memory_skipped: &mut u64, key_nibbles: &mut Vec, ) -> Result { - let node_size = node.node.memory_usage_direct_no_memory(); - if *memory_skipped + node_size <= memory_threshold { - *memory_skipped += node_size; - } else if node.node.has_value() { - return Ok(false); - } + *memory_skipped += node.node.memory_usage_direct_no_memory(); match &node.node { TrieNode::Empty => Ok(false), TrieNode::Leaf(key, _) => { - let (slice, _is_leaf) = NibbleSlice::from_encoded(key); + let (slice, _) = NibbleSlice::from_encoded(key); key_nibbles.extend(slice.iter()); + + // Leaf must contain value, so we found the boundary. Ok(false) } - TrieNode::Branch(children, _) => { + TrieNode::Branch(children, value_handle) => { + if *memory_skipped > memory_threshold && value_handle.is_some() { + // If we skipped enough memory and found some value, we found the boundary. + return Ok(false); + } + let mut iter = children.iter(); while let Some((index, child)) = iter.next() { let child = if let NodeHandle::Hash(h) = child { @@ -147,6 +149,7 @@ impl Trie { }; if *memory_skipped + child.memory_usage > memory_threshold { core::mem::drop(iter); + // println!("push {:?}", vec![index]); key_nibbles.push(index); *node = child; return Ok(true); @@ -174,7 +177,7 @@ impl Trie { NodeHandle::InMemory(_) => unreachable!("only possible while mutating"), NodeHandle::Hash(h) => self.retrieve_node(h)?.1, }; - let (slice, _is_leaf) = NibbleSlice::from_encoded(key); + let (slice, _) = NibbleSlice::from_encoded(key); key_nibbles.extend(slice.iter()); *node = child; Ok(true) @@ -245,6 +248,7 @@ impl Trie { contract_codes: vec![], }); } + println!("NODES: {:?}", part); let trie = Trie::from_recorded_storage(PartialStorage { nodes: part }, *state_root); let path_begin = trie.find_state_part_boundary(part_id.idx, part_id.total)?; let path_end = trie.find_state_part_boundary(part_id.idx + 1, part_id.total)?; @@ -295,9 +299,13 @@ impl Trie { } } +/// TODO (#8997): test set seems incomplete. Perhaps `get_trie_items_for_part` +/// should also belong to this file. We need to use it to check that state +/// parts are continuous and disjoint. Maybe it is checked in split_state.rs. #[cfg(test)] mod tests { - use std::collections::HashMap; + use assert_matches::assert_matches; + use std::collections::{HashMap, HashSet}; use std::sync::Arc; use rand::prelude::ThreadRng; @@ -312,6 +320,85 @@ mod tests { use super::*; use near_primitives::shard_layout::ShardUId; + fn nibbles_to_bytes(nibbles: &[u8]) -> Vec { + assert_eq!(nibbles.len() % 2, 0); + let encoded = NibbleSlice::encode_nibbles(&nibbles, false); + encoded[1..].to_vec() + } + + /// Checks that sampling state boundaries results always gives valid state + /// keys, even if trie contains intermediate nodes. + #[test] + fn boundary_is_state_key() { + // Trie should contain at least two intermediate branches, for strings + // "a" and "b". + let trie_changes = vec![ + (b"after".to_vec(), Some(vec![1])), + (b"alley".to_vec(), Some(vec![2])), + (b"berry".to_vec(), Some(vec![3])), + (b"brave".to_vec(), Some(vec![4])), + ]; + + // Number of state parts. Must be larger than number of state items to ensure + // that boundaries are nontrivial. + let num_parts = 10u64; + + let tries = create_tries(); + let state_root = + test_populate_trie(&tries, &Trie::EMPTY_ROOT, ShardUId::single_shard(), trie_changes); + let trie = tries.get_trie_for_shard(ShardUId::single_shard(), state_root); + + let nibbles_boundary = trie.find_state_part_boundary(0, num_parts).unwrap(); + assert!(nibbles_boundary.is_empty()); + + // Check that all boundaries correspond to some state key by calling `Trie::get`. + // Note that some state parts can be trivial, which is not a concern. + for part_id in 1..num_parts { + let nibbles_boundary = trie.find_state_part_boundary(part_id, num_parts).unwrap(); + let key_boundary = nibbles_to_bytes(&nibbles_boundary); + assert_matches!(trie.get(&key_boundary), Ok(Some(_))); + } + + let nibbles_boundary = trie.find_state_part_boundary(num_parts, num_parts).unwrap(); + assert_eq!(nibbles_boundary, LAST_STATE_PART_BOUNDARY); + } + + /// Checks that on degenerate case when trie is a single path state + /// parts are still distributed evenly. + #[test] + fn single_path_trie() { + // Values should be big enough to ensure that node and key overhead are + // not significant. + let value_len = 1000usize; + // Corner case when trie is a single path from empty string to "aaaa". + let trie_changes = vec![ + (b"a".to_vec(), Some(vec![1; value_len])), + (b"aa".to_vec(), Some(vec![2; value_len])), + (b"aaa".to_vec(), Some(vec![3; value_len])), + (b"aaaa".to_vec(), Some(vec![4; value_len])), + ]; + // We split state into `num_keys + 1` parts for convenience of testing, + // because right boundaries are exclusive. This way first part is + // empty and other parts contain exactly one key. + let num_parts = trie_changes.len() + 1; + + let tries = create_tries(); + let state_root = test_populate_trie( + &tries, + &Trie::EMPTY_ROOT, + ShardUId::single_shard(), + trie_changes.clone(), + ); + let trie = tries.get_trie_for_shard(ShardUId::single_shard(), state_root); + + for part_id in 1..num_parts { + let nibbles_boundary = + trie.find_state_part_boundary(part_id as u64, num_parts as u64).unwrap(); + let key_boundary = nibbles_to_bytes(&nibbles_boundary); + assert_eq!(key_boundary, trie_changes[part_id - 1].0); + } + } + impl Trie { /// Combines all parts and returns TrieChanges that can be applied to storage. /// @@ -504,6 +591,10 @@ mod tests { trie_changes } + /// Helper function checking that for given trie generator size of each + /// part is approximately bounded by `total_size / num_parts` with overhead + /// for proof and trie items irregularity. + /// TODO (#8997): run it on largest keys (2KB) and values (4MB) allowed in mainnet. fn run_test_parts_not_huge(gen_trie_changes: F, big_value_length: u64) where F: FnOnce(&mut ThreadRng, u64, u64) -> Vec<(Vec, Option>)>, @@ -511,41 +602,56 @@ mod tests { let mut rng = rand::thread_rng(); let max_key_length = 50u64; let max_key_length_in_nibbles = max_key_length * 2; - let max_node_serialized_size = 32 * 16 + 100; // DEVNOTE nodes can be pretty big let max_node_children = 16; + let max_node_serialized_size = 32 * max_node_children + 100; // Full branch node overhead. + let max_proof_overhead = + max_key_length_in_nibbles * max_node_children * max_node_serialized_size; let max_part_overhead = - max_key_length_in_nibbles * max_node_serialized_size * max_node_children * 2 - + big_value_length * 2; - println!("Max allowed overhead: {}", max_part_overhead); + big_value_length.max(max_key_length_in_nibbles * max_node_serialized_size * 2); let trie_changes = gen_trie_changes(&mut rng, max_key_length, big_value_length); - println!("Number of nodes: {}", trie_changes.len()); let tries = create_tries(); let state_root = test_populate_trie(&tries, &Trie::EMPTY_ROOT, ShardUId::single_shard(), trie_changes); let trie = tries.get_trie_for_shard(ShardUId::single_shard(), state_root); let memory_size = trie.retrieve_root_node().unwrap().memory_usage; - println!("Total memory size: {}", memory_size); for num_parts in [2, 3, 5, 10, 50].iter().cloned() { - let approximate_size_per_part = memory_size / num_parts; - let parts = (0..num_parts) - .map(|part_id| { - trie.get_trie_nodes_for_part(PartId::new(part_id, num_parts)).unwrap() - }) - .collect::>(); - let part_nodecounts_vec = - parts.iter().map(|PartialState::Nodes(nodes)| nodes.len()).collect::>(); - let sizes_vec = parts - .iter() - .map(|PartialState::Nodes(nodes)| { - nodes.iter().map(|node| node.len()).sum::() - }) - .collect::>(); + let part_size_limit = (memory_size + num_parts - 1) / num_parts; + + for part_id in 0..num_parts { + // Compute proof with size and check that it doesn't exceed theoretical boundary for + // the path with full set of left siblings of maximal possible size. + let trie_recording = trie.recording_reads(); + let left_nibbles_boundary = + trie_recording.find_state_part_boundary(part_id, num_parts).unwrap(); + let left_key_boundary = nibbles_to_bytes(&left_nibbles_boundary); + if part_id != 0 { + assert_matches!(trie.get(&left_key_boundary), Ok(Some(_))); + } + let PartialState::Nodes(proof_nodes) = + trie_recording.recorded_storage().unwrap().nodes; + let proof_size = proof_nodes.iter().map(|node| node.len()).sum::() as u64; + assert!( + proof_size <= max_proof_overhead, + "For part {}/{} left boundary proof size {} exceeds limit {}", + part_id, + num_parts, + proof_size, + max_proof_overhead + ); - println!("Node counts of parts: {:?}", part_nodecounts_vec); - println!("Sizes of parts: {:?}", sizes_vec); - println!("Max size we allow: {}", approximate_size_per_part + max_part_overhead); - for size in sizes_vec { - assert!((size as u64) < approximate_size_per_part + max_part_overhead); + let PartialState::Nodes(part_nodes) = + trie.get_trie_nodes_for_part(PartId::new(part_id, num_parts)).unwrap(); + // TODO (#8997): it's a bit weird that raw lengths are compared to + // config values. Consider better defined assertion. + let total_size = part_nodes.iter().map(|node| node.len()).sum::() as u64; + assert!( + total_size <= part_size_limit + proof_size + max_part_overhead, + "Part {}/{} is too big. Size: {}, size limit: {}", + part_id, + num_parts, + total_size, + part_size_limit + proof_size + max_part_overhead, + ); } } } @@ -555,6 +661,9 @@ mod tests { run_test_parts_not_huge(construct_trie_for_big_parts_1, 100_000); } + /// TODO (#8997): consider: + /// * adding more testcases for big and small key/value lengths, other trie structures; + /// * speeding this test up. #[test] fn test_parts_not_huge_2() { run_test_parts_not_huge(construct_trie_for_big_parts_2, 100_000); @@ -591,6 +700,7 @@ mod tests { for _ in 0..2000 { let tries = create_tries(); let trie_changes = gen_changes(&mut rng, 20); + println!("{:?}", trie_changes); let state_root = test_populate_trie( &tries, &Trie::EMPTY_ROOT, @@ -651,6 +761,27 @@ mod tests { } } + fn format_simple_trie_refcount_diff( + left: &[TrieRefcountChange], + right: &[TrieRefcountChange], + ) -> String { + let left_set: HashSet<_> = HashSet::from_iter(left.iter()); + let right_set: HashSet<_> = HashSet::from_iter(right.iter()); + format!( + "left: {:?} right: {:?}", + left_set.difference(&right_set), + right_set.difference(&left_set) + ) + } + + fn format_simple_trie_changes_diff(left: &TrieChanges, right: &TrieChanges) -> String { + format!( + "insertions diff: {}, deletions diff: {}", + format_simple_trie_refcount_diff(&left.insertions, &right.insertions), + format_simple_trie_refcount_diff(&left.deletions, &right.deletions) + ) + } + /// Helper function checking that two ways of combining state parts are identical: /// 1) Create partial storage over all nodes in state parts and traverse all /// nodes in the storage; @@ -673,12 +804,22 @@ mod tests { .trie_changes }) .collect::>(); + + println!("{:?}", changes); merge_trie_changes(changes) }; - assert_eq!(trie_changes, trie_changes_new); + assert_eq!( + trie_changes, + trie_changes_new, + "{}", + format_simple_trie_changes_diff(&trie_changes, &trie_changes_new) + ); trie_changes } + /// Check on random samples that state parts can be validated independently + /// from the entire trie. + /// TODO (#8997): add custom tests where incorrect parts don't pass validation. #[test] fn test_get_trie_nodes_for_part() { let mut rng = rand::thread_rng();