more state part tests
Longarithm committed May 11, 2023
1 parent 7f25c62 commit 362f011
Showing 3 changed files with 192 additions and 38 deletions.
6 changes: 6 additions & 0 deletions core/store/src/trie/iterator.rs
@@ -143,7 +143,9 @@ impl<'a> TrieIterator<'a> {
TrieNode::Empty => break,
TrieNode::Leaf(leaf_key, _) => {
let existing_key = NibbleSlice::from_encoded(leaf_key).0;
println!("see leaf {:?} {:?}", leaf_key, existing_key);
if !check_ext_key(&key, &existing_key) {
println!("yes");
self.key_nibbles.extend(existing_key.iter());
*status = CrumbStatus::Exiting;
}
@@ -356,6 +358,9 @@ impl<'a> TrieIterator<'a> {
}
IterStep::Continue => {}
IterStep::Value(hash) => {
if self.key_nibbles[prefix..] >= path_end[prefix..] {
break;
}
self.trie.storage.retrieve_raw_bytes(&hash)?;
nodes_list.push(TrieTraversalItem {
hash,
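
Editor's note: the early break added above relies on Rust's lexicographic ordering of slices; iteration stops as soon as the current key, viewed as nibbles, reaches the exclusive right boundary `path_end`. A minimal standalone illustration:

fn main() {
    let path_end: &[u8] = &[0x06, 0x03];
    let inside: &[u8] = &[0x06, 0x02, 0x0a];
    let past: &[u8] = &[0x06, 0x03, 0x00];
    assert!(inside < path_end); // still inside the part: keep iterating
    assert!(past >= path_end);  // reached the exclusive boundary: break
}
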
@@ -368,6 +373,7 @@ impl<'a> TrieIterator<'a> {
}
}

#[derive(Debug)]
enum IterStep {
Continue,
PopTrail,
7 changes: 7 additions & 0 deletions core/store/src/trie/mod.rs
@@ -23,6 +23,7 @@ pub use raw_node::{Children, RawTrieNode, RawTrieNodeWithSize};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Write;
use std::hash::{Hash, Hasher};
use std::rc::Rc;
use std::str;

@@ -353,6 +354,12 @@ impl TrieRefcountChange {
}
}

impl Hash for TrieRefcountChange {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write(&self.trie_node_or_value_hash.0);
state.write_u32(self.rc.into());
}
}
///
/// TrieChanges stores delta for refcount.
/// Multiple versions of the state work the following way:
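Editor's note: the manual `Hash` impl above exists so that `TrieRefcountChange` values can be collected into a `HashSet` and diffed, which the new test helpers in state_parts.rs below rely on. A self-contained sketch of the same pattern, using simplified stand-in types rather than the nearcore definitions:

use std::collections::HashSet;
use std::hash::{Hash, Hasher};

#[derive(PartialEq, Eq)]
struct RefcountChange {
    node_hash: [u8; 32], // stand-in for CryptoHash
    rc: u32,             // stand-in for the refcount
}

impl Hash for RefcountChange {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Feed both fields through the hasher, mirroring the impl above.
        state.write(&self.node_hash);
        state.write_u32(self.rc);
    }
}

fn main() {
    let a = RefcountChange { node_hash: [0; 32], rc: 1 };
    let b = RefcountChange { node_hash: [0; 32], rc: 2 };
    // Same node hash but different refcounts: two distinct set entries.
    let set: HashSet<&RefcountChange> = [&a, &b].into_iter().collect();
    assert_eq!(set.len(), 2);
}

Note that `Hash` must stay consistent with the derived `PartialEq`; both consider exactly the two fields.
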
217 changes: 179 additions & 38 deletions core/store/src/trie/state_parts.rs
@@ -68,7 +68,7 @@ impl Trie {
}
let root_node = self.retrieve_node(&self.root)?.1;
let total_size = root_node.memory_usage;
let size_start = (total_size + num_parts - 1) / num_parts * part_id;
let size_start = total_size / num_parts * part_id + part_id.min(total_size % num_parts);
self.find_node_in_dfs_order(&root_node, size_start)
}
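
Editor's note, not part of the commit: the old formula, `(total_size + num_parts - 1) / num_parts * part_id`, rounds the per-part size up, so for large `part_id` the boundary can overshoot `total_size`. The new formula spreads the remainder `total_size % num_parts` over the first parts one unit at a time, so part sizes differ by at most one and the last boundary lands exactly on `total_size`. A standalone sketch of the new formula:

// Hypothetical free function copying the new boundary computation.
fn size_start(total_size: u64, num_parts: u64, part_id: u64) -> u64 {
    total_size / num_parts * part_id + part_id.min(total_size % num_parts)
}

fn main() {
    let (total_size, num_parts) = (10u64, 3u64);
    let bounds: Vec<u64> =
        (0..=num_parts).map(|i| size_start(total_size, num_parts, i)).collect();
    // Part sizes are 4, 3, 3; the old formula would yield 0, 4, 8, 12.
    assert_eq!(bounds, vec![0, 4, 7, 10]);
}
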

@@ -115,29 +115,31 @@ impl Trie {
/// Tells into which child we should go to find a node in DFS order corresponding
/// to `memory_threshold`.
/// Accumulates `memory_skipped` as memory used by all skipped nodes.
/// Returns false if we already find desired node and should stop the process.
/// Returns false if we already found desired node and should stop the process.
fn find_child_in_dfs_order(
&self,
memory_threshold: u64,
node: &mut TrieNodeWithSize,
memory_skipped: &mut u64,
key_nibbles: &mut Vec<u8>,
) -> Result<bool, StorageError> {
let node_size = node.node.memory_usage_direct_no_memory();
if *memory_skipped + node_size <= memory_threshold {
*memory_skipped += node_size;
} else if node.node.has_value() {
return Ok(false);
}
*memory_skipped += node.node.memory_usage_direct_no_memory();

match &node.node {
TrieNode::Empty => Ok(false),
TrieNode::Leaf(key, _) => {
let (slice, _is_leaf) = NibbleSlice::from_encoded(key);
let (slice, _) = NibbleSlice::from_encoded(key);
key_nibbles.extend(slice.iter());

// A leaf must contain a value, so we found the boundary.
Ok(false)
}
TrieNode::Branch(children, _) => {
TrieNode::Branch(children, value_handle) => {
if *memory_skipped > memory_threshold && value_handle.is_some() {
// If we skipped enough memory and found some value, we found the boundary.
return Ok(false);
}

let mut iter = children.iter();
while let Some((index, child)) = iter.next() {
let child = if let NodeHandle::Hash(h) = child {
@@ -147,6 +149,7 @@ impl Trie {
};
if *memory_skipped + child.memory_usage > memory_threshold {
core::mem::drop(iter);
// println!("push {:?}", vec![index]);
key_nibbles.push(index);
*node = child;
return Ok(true);
@@ -174,7 +177,7 @@ impl Trie {
NodeHandle::InMemory(_) => unreachable!("only possible while mutating"),
NodeHandle::Hash(h) => self.retrieve_node(h)?.1,
};
let (slice, _is_leaf) = NibbleSlice::from_encoded(key);
let (slice, _) = NibbleSlice::from_encoded(key);
key_nibbles.extend(slice.iter());
*node = child;
Ok(true)
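
Editor's note: the descent that `find_child_in_dfs_order` implements is easier to see stripped of trie details. The sketch below uses hypothetical simplified types, not the nearcore API: starting at the root, accumulate the memory of every node and subtree that lies entirely before the threshold, and descend into the first child whose subtree crosses it.

// Simplified stand-in for TrieNodeWithSize: `memory_usage` is the
// precomputed total for the whole subtree.
struct Node {
    own_size: u64,
    memory_usage: u64, // own_size + sum of children's memory_usage
    children: Vec<Node>,
}

// Returns the child-index path to the first node, in DFS order, at which
// the accumulated skipped memory exceeds `memory_threshold`.
fn find_boundary_path(root: &Node, memory_threshold: u64) -> Vec<usize> {
    let mut node = root;
    let mut skipped = 0u64;
    let mut path = Vec::new();
    loop {
        skipped += node.own_size;
        let mut next = None;
        for (index, child) in node.children.iter().enumerate() {
            if skipped + child.memory_usage > memory_threshold {
                // This subtree crosses the threshold: descend into it.
                next = Some((index, child));
                break;
            }
            // The whole subtree fits below the threshold: skip it.
            skipped += child.memory_usage;
        }
        match next {
            Some((index, child)) => {
                path.push(index);
                node = child;
            }
            // No child crosses the threshold: the boundary is this node.
            None => return path,
        }
    }
}

fn main() {
    let leaf = |s| Node { own_size: s, memory_usage: s, children: vec![] };
    let root = Node { own_size: 1, memory_usage: 11, children: vec![leaf(4), leaf(6)] };
    // Skipping root (1) + first child (4) = 5; the second child crosses 8.
    assert_eq!(find_boundary_path(&root, 8), vec![1]);
}
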
@@ -245,6 +248,7 @@ impl Trie {
contract_codes: vec![],
});
}
println!("NODES: {:?}", part);
let trie = Trie::from_recorded_storage(PartialStorage { nodes: part }, *state_root);
let path_begin = trie.find_state_part_boundary(part_id.idx, part_id.total)?;
let path_end = trie.find_state_part_boundary(part_id.idx + 1, part_id.total)?;
@@ -295,9 +299,13 @@ impl Trie {
}
}

/// TODO (#8997): test set seems incomplete. Perhaps `get_trie_items_for_part`
/// should also belong to this file. We need to use it to check that state
/// parts are continuous and disjoint. Maybe it is checked in split_state.rs.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use assert_matches::assert_matches;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use rand::prelude::ThreadRng;
@@ -312,6 +320,85 @@ mod tests {
use super::*;
use near_primitives::shard_layout::ShardUId;

fn nibbles_to_bytes(nibbles: &[u8]) -> Vec<u8> {
assert_eq!(nibbles.len() % 2, 0);
let encoded = NibbleSlice::encode_nibbles(&nibbles, false);
encoded[1..].to_vec()
}
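
Editor's note: a nibble is half a byte, and `NibbleSlice::encode_nibbles` prepends a header byte, which is why the helper above drops `encoded[0]`. A dependency-free sketch of the same round-trip, with hypothetical helper names:

// Hypothetical helpers illustrating the byte <-> nibble round-trip.
fn bytes_to_nibbles(bytes: &[u8]) -> Vec<u8> {
    bytes.iter().flat_map(|b| [b >> 4, b & 0x0f]).collect()
}

fn pack_nibbles(nibbles: &[u8]) -> Vec<u8> {
    assert_eq!(nibbles.len() % 2, 0);
    nibbles.chunks(2).map(|pair| (pair[0] << 4) | pair[1]).collect()
}

fn main() {
    let key = b"alley".to_vec();
    // An even-length nibble path converts back to the original key bytes.
    assert_eq!(pack_nibbles(&bytes_to_nibbles(&key)), key);
}
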

/// Checks that sampling state part boundaries always gives valid state
/// keys, even if the trie contains intermediate nodes.
#[test]
fn boundary_is_state_key() {
// Trie should contain at least two intermediate branches, for strings
// "a" and "b".
let trie_changes = vec![
(b"after".to_vec(), Some(vec![1])),
(b"alley".to_vec(), Some(vec![2])),
(b"berry".to_vec(), Some(vec![3])),
(b"brave".to_vec(), Some(vec![4])),
];

// Number of state parts. Must be larger than the number of state items to ensure
// that boundaries are nontrivial.
let num_parts = 10u64;

let tries = create_tries();
let state_root =
test_populate_trie(&tries, &Trie::EMPTY_ROOT, ShardUId::single_shard(), trie_changes);
let trie = tries.get_trie_for_shard(ShardUId::single_shard(), state_root);

let nibbles_boundary = trie.find_state_part_boundary(0, num_parts).unwrap();
assert!(nibbles_boundary.is_empty());

// Check that all boundaries correspond to some state key by calling `Trie::get`.
// Note that some state parts can be trivial, which is not a concern.
for part_id in 1..num_parts {
let nibbles_boundary = trie.find_state_part_boundary(part_id, num_parts).unwrap();
let key_boundary = nibbles_to_bytes(&nibbles_boundary);
assert_matches!(trie.get(&key_boundary), Ok(Some(_)));
}

let nibbles_boundary = trie.find_state_part_boundary(num_parts, num_parts).unwrap();
assert_eq!(nibbles_boundary, LAST_STATE_PART_BOUNDARY);
}

/// Checks that in the degenerate case, when the trie is a single path, state
/// parts are still distributed evenly.
#[test]
fn single_path_trie() {
// Values should be big enough to ensure that node and key overhead are
// not significant.
let value_len = 1000usize;
// Corner case when trie is a single path from empty string to "aaaa".
let trie_changes = vec![
(b"a".to_vec(), Some(vec![1; value_len])),
(b"aa".to_vec(), Some(vec![2; value_len])),
(b"aaa".to_vec(), Some(vec![3; value_len])),
(b"aaaa".to_vec(), Some(vec![4; value_len])),
];
// We split the state into `num_keys + 1` parts for convenience of testing,
// because right boundaries are exclusive. This way the first part is
// empty and every other part contains exactly one key.
let num_parts = trie_changes.len() + 1;

let tries = create_tries();
let state_root = test_populate_trie(
&tries,
&Trie::EMPTY_ROOT,
ShardUId::single_shard(),
trie_changes.clone(),
);
let trie = tries.get_trie_for_shard(ShardUId::single_shard(), state_root);

for part_id in 1..num_parts {
let nibbles_boundary =
trie.find_state_part_boundary(part_id as u64, num_parts as u64).unwrap();
let key_boundary = nibbles_to_bytes(&nibbles_boundary);
assert_eq!(key_boundary, trie_changes[part_id - 1].0);
}
}

impl Trie {
/// Combines all parts and returns TrieChanges that can be applied to storage.
///
@@ -504,48 +591,67 @@ mod tests {
trie_changes
}

/// Helper function checking that, for the given trie generator, the size of
/// each part is approximately bounded by `total_size / num_parts`, with
/// overhead for the proof and for trie item irregularity.
/// TODO (#8997): run it on the largest keys (2 KB) and values (4 MB) allowed in mainnet.
fn run_test_parts_not_huge<F>(gen_trie_changes: F, big_value_length: u64)
where
F: FnOnce(&mut ThreadRng, u64, u64) -> Vec<(Vec<u8>, Option<Vec<u8>>)>,
{
let mut rng = rand::thread_rng();
let max_key_length = 50u64;
let max_key_length_in_nibbles = max_key_length * 2;
let max_node_serialized_size = 32 * 16 + 100; // DEVNOTE nodes can be pretty big
let max_node_children = 16;
let max_node_serialized_size = 32 * max_node_children + 100; // Full branch node overhead.
let max_proof_overhead =
max_key_length_in_nibbles * max_node_children * max_node_serialized_size;
let max_part_overhead =
max_key_length_in_nibbles * max_node_serialized_size * max_node_children * 2
+ big_value_length * 2;
println!("Max allowed overhead: {}", max_part_overhead);
big_value_length.max(max_key_length_in_nibbles * max_node_serialized_size * 2);
let trie_changes = gen_trie_changes(&mut rng, max_key_length, big_value_length);
println!("Number of nodes: {}", trie_changes.len());
let tries = create_tries();
let state_root =
test_populate_trie(&tries, &Trie::EMPTY_ROOT, ShardUId::single_shard(), trie_changes);
let trie = tries.get_trie_for_shard(ShardUId::single_shard(), state_root);
let memory_size = trie.retrieve_root_node().unwrap().memory_usage;
println!("Total memory size: {}", memory_size);
for num_parts in [2, 3, 5, 10, 50].iter().cloned() {
let approximate_size_per_part = memory_size / num_parts;
let parts = (0..num_parts)
.map(|part_id| {
trie.get_trie_nodes_for_part(PartId::new(part_id, num_parts)).unwrap()
})
.collect::<Vec<_>>();
let part_nodecounts_vec =
parts.iter().map(|PartialState::Nodes(nodes)| nodes.len()).collect::<Vec<_>>();
let sizes_vec = parts
.iter()
.map(|PartialState::Nodes(nodes)| {
nodes.iter().map(|node| node.len()).sum::<usize>()
})
.collect::<Vec<_>>();
let part_size_limit = (memory_size + num_parts - 1) / num_parts;

for part_id in 0..num_parts {
// Compute the proof size and check that it doesn't exceed the theoretical bound for
// a path with a full set of left siblings of maximal possible size.
let trie_recording = trie.recording_reads();
let left_nibbles_boundary =
trie_recording.find_state_part_boundary(part_id, num_parts).unwrap();
let left_key_boundary = nibbles_to_bytes(&left_nibbles_boundary);
if part_id != 0 {
assert_matches!(trie.get(&left_key_boundary), Ok(Some(_)));
}
let PartialState::Nodes(proof_nodes) =
trie_recording.recorded_storage().unwrap().nodes;
let proof_size = proof_nodes.iter().map(|node| node.len()).sum::<usize>() as u64;
assert!(
proof_size <= max_proof_overhead,
"For part {}/{} left boundary proof size {} exceeds limit {}",
part_id,
num_parts,
proof_size,
max_proof_overhead
);

println!("Node counts of parts: {:?}", part_nodecounts_vec);
println!("Sizes of parts: {:?}", sizes_vec);
println!("Max size we allow: {}", approximate_size_per_part + max_part_overhead);
for size in sizes_vec {
assert!((size as u64) < approximate_size_per_part + max_part_overhead);
let PartialState::Nodes(part_nodes) =
trie.get_trie_nodes_for_part(PartId::new(part_id, num_parts)).unwrap();
// TODO (#8997): it's a bit weird that raw lengths are compared to
// config values. Consider a better-defined assertion.
let total_size = part_nodes.iter().map(|node| node.len()).sum::<usize>() as u64;
assert!(
total_size <= part_size_limit + proof_size + max_part_overhead,
"Part {}/{} is too big. Size: {}, size limit: {}",
part_id,
num_parts,
total_size,
part_size_limit + proof_size + max_part_overhead,
);
}
}
}
Expand All @@ -555,6 +661,9 @@ mod tests {
run_test_parts_not_huge(construct_trie_for_big_parts_1, 100_000);
}

/// TODO (#8997): consider:
/// * adding more testcases for big and small key/value lengths, other trie structures;
/// * speeding this test up.
#[test]
fn test_parts_not_huge_2() {
run_test_parts_not_huge(construct_trie_for_big_parts_2, 100_000);
@@ -591,6 +700,7 @@ mod tests {
for _ in 0..2000 {
let tries = create_tries();
let trie_changes = gen_changes(&mut rng, 20);
println!("{:?}", trie_changes);
let state_root = test_populate_trie(
&tries,
&Trie::EMPTY_ROOT,
@@ -651,6 +761,27 @@ mod tests {
}
}

fn format_simple_trie_refcount_diff(
left: &[TrieRefcountChange],
right: &[TrieRefcountChange],
) -> String {
let left_set: HashSet<_> = HashSet::from_iter(left.iter());
let right_set: HashSet<_> = HashSet::from_iter(right.iter());
format!(
"left: {:?} right: {:?}",
left_set.difference(&right_set),
right_set.difference(&left_set)
)
}

fn format_simple_trie_changes_diff(left: &TrieChanges, right: &TrieChanges) -> String {
format!(
"insertions diff: {}, deletions diff: {}",
format_simple_trie_refcount_diff(&left.insertions, &right.insertions),
format_simple_trie_refcount_diff(&left.deletions, &right.deletions)
)
}

/// Helper function checking that two ways of combining state parts are identical:
/// 1) Create partial storage over all nodes in state parts and traverse all
/// nodes in the storage;
@@ -673,12 +804,22 @@ mod tests {
.trie_changes
})
.collect::<Vec<_>>();

println!("{:?}", changes);
merge_trie_changes(changes)
};
assert_eq!(trie_changes, trie_changes_new);
assert_eq!(
trie_changes,
trie_changes_new,
"{}",
format_simple_trie_changes_diff(&trie_changes, &trie_changes_new)
);
trie_changes
}

/// Checks on random samples that state parts can be validated independently
/// of the entire trie.
/// TODO (#8997): add custom tests where incorrect parts don't pass validation.
#[test]
fn test_get_trie_nodes_for_part() {
let mut rng = rand::thread_rng();
