diff --git a/core/store/src/trie/mem/mod.rs b/core/store/src/trie/mem/mod.rs index b44988ea885..55dc3676fd2 100644 --- a/core/store/src/trie/mem/mod.rs +++ b/core/store/src/trie/mem/mod.rs @@ -143,7 +143,7 @@ impl MemTries { pub fn update( &self, root: CryptoHash, - track_disk_changes: bool, + track_trie_changes: bool, ) -> Result { let root_id = if root == CryptoHash::default() { None @@ -163,7 +163,7 @@ impl MemTries { root_id, &self.arena.memory(), self.shard_uid.to_string(), - track_disk_changes, + track_trie_changes, )) } } diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs index fe82d68a4af..6360c2a3143 100644 --- a/core/store/src/trie/mem/updating.rs +++ b/core/store/src/trie/mem/updating.rs @@ -9,6 +9,7 @@ use near_primitives::hash::{hash, CryptoHash}; use near_primitives::state::FlatStateValue; use near_primitives::types::BlockHeight; use std::collections::HashMap; +use std::sync::Arc; /// An old node means a node in the current in-memory trie. An updated node means a /// node we're going to store in the in-memory trie but have not constructed there yet. @@ -43,6 +44,28 @@ pub enum UpdatedMemTrieNode { }, } +/// Keeps values and internal nodes accessed on updating memtrie. +pub(crate) struct TrieAccesses { + /// Hashes and encoded trie nodes. + pub nodes: HashMap>, + /// Hashes of accessed values - because values themselves are not + /// necessarily present in memtrie. + pub values: HashMap, +} + +/// Tracks intermediate trie changes, final version of which is to be committed +/// to disk after finishing trie update. +struct TrieChangesTracker { + /// Changes of reference count on disk for each impacted node. + refcount_changes: TrieRefcountDeltaMap, + /// All observed values and internal nodes. + /// Needed to prepare recorded storage. 
+ /// Note that negative `refcount_changes` does not fully cover it, as node + /// or value of the same hash can be removed and inserted for the same + /// update in different parts of trie! + accesses: TrieAccesses, +} + /// Structure to build an update to the in-memory trie. pub struct MemTrieUpdate<'a> { /// The original root before updates. It is None iff the original trie had no keys. @@ -53,8 +76,9 @@ pub struct MemTrieUpdate<'a> { /// (1) temporarily we take out the node from the slot to process it and put it back /// later; or (2) the node is deleted afterwards. pub updated_nodes: Vec>, - /// Refcount changes to on-disk trie nodes. - pub trie_refcount_changes: Option, + /// Tracks trie changes necessary to make on-disk updates and recorded + /// storage. + tracked_trie_changes: Option, } impl UpdatedMemTrieNode { @@ -97,15 +121,18 @@ impl<'a> MemTrieUpdate<'a> { root: Option, arena: &'a ArenaMemory, shard_uid: String, - track_disk_changes: bool, + track_trie_changes: bool, ) -> Self { let mut trie_update = Self { root, arena, shard_uid, updated_nodes: vec![], - trie_refcount_changes: if track_disk_changes { - Some(TrieRefcountDeltaMap::new()) + tracked_trie_changes: if track_trie_changes { + Some(TrieChangesTracker { + refcount_changes: TrieRefcountDeltaMap::new(), + accesses: TrieAccesses { nodes: HashMap::new(), values: HashMap::new() }, + }) } else { None }, @@ -145,8 +172,16 @@ impl<'a> MemTrieUpdate<'a> { match node { None => self.new_updated_node(UpdatedMemTrieNode::Empty), Some(node) => { - if let Some(trie_refcount_changes) = self.trie_refcount_changes.as_mut() { - trie_refcount_changes.subtract(node.as_ptr(self.arena).view().node_hash(), 1); + if let Some(tracked_trie_changes) = self.tracked_trie_changes.as_mut() { + let node_view = node.as_ptr(self.arena).view(); + let node_hash = node_view.node_hash(); + let raw_node_serialized = + borsh::to_vec(&node_view.to_raw_trie_node_with_size()).unwrap(); + tracked_trie_changes + .accesses + .nodes + 
.insert(node_hash, raw_node_serialized.into()); + tracked_trie_changes.refcount_changes.subtract(node_hash, 1); } self.new_updated_node(UpdatedMemTrieNode::from_existing_node_view( node.as_ptr(self.arena).view(), @@ -164,14 +199,16 @@ impl<'a> MemTrieUpdate<'a> { } fn add_refcount_to_value(&mut self, hash: CryptoHash, value: Option>) { - if let Some(trie_refcount_changes) = self.trie_refcount_changes.as_mut() { - trie_refcount_changes.add(hash, value.unwrap(), 1); + if let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() { + tracked_node_changes.refcount_changes.add(hash, value.unwrap(), 1); } } - fn subtract_refcount_for_value(&mut self, hash: CryptoHash) { - if let Some(trie_refcount_changes) = self.trie_refcount_changes.as_mut() { - trie_refcount_changes.subtract(hash, 1); + fn subtract_refcount_for_value(&mut self, value: FlatStateValue) { + if let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() { + let hash = value.to_value_ref().hash; + tracked_node_changes.accesses.values.insert(hash, value); + tracked_node_changes.refcount_changes.subtract(hash, 1); } } @@ -219,7 +256,7 @@ impl<'a> MemTrieUpdate<'a> { if partial.is_empty() { // This branch node is exactly where the value should be added. if let Some(value) = old_value { - self.subtract_refcount_for_value(value.to_value_ref().hash); + self.subtract_refcount_for_value(value); } self.place_node( node_id, @@ -250,7 +287,7 @@ impl<'a> MemTrieUpdate<'a> { let common_prefix = partial.common_prefix(&existing_key); if common_prefix == existing_key.len() && common_prefix == partial.len() { // We're at the exact leaf. Rewrite the value at this leaf. 
- self.subtract_refcount_for_value(old_value.to_value_ref().hash); + self.subtract_refcount_for_value(old_value); self.place_node( node_id, UpdatedMemTrieNode::Leaf { extension, value: flat_value }, @@ -389,7 +426,7 @@ impl<'a> MemTrieUpdate<'a> { } UpdatedMemTrieNode::Leaf { extension, value } => { if NibbleSlice::from_encoded(&extension).0 == partial { - self.subtract_refcount_for_value(value.to_value_ref().hash); + self.subtract_refcount_for_value(value); self.place_node(node_id, UpdatedMemTrieNode::Empty); break; } else { @@ -408,7 +445,7 @@ impl<'a> MemTrieUpdate<'a> { ); return; }; - self.subtract_refcount_for_value(value.unwrap().to_value_ref().hash); + self.subtract_refcount_for_value(value.unwrap()); self.place_node( node_id, UpdatedMemTrieNode::Branch { children: old_children, value: None }, @@ -779,31 +816,36 @@ impl<'a> MemTrieUpdate<'a> { } /// Converts the updates to trie changes as well as memtrie changes. - pub fn to_trie_changes(self) -> TrieChanges { - let Self { root, arena, shard_uid, trie_refcount_changes, updated_nodes } = self; - let mut trie_refcount_changes = - trie_refcount_changes.expect("Cannot to_trie_changes for memtrie changes only"); + pub(crate) fn to_trie_changes(self) -> (TrieChanges, TrieAccesses) { + let Self { root, arena, shard_uid, tracked_trie_changes, updated_nodes } = self; + let TrieChangesTracker { mut refcount_changes, accesses } = + tracked_trie_changes.expect("Cannot to_trie_changes for memtrie changes only"); let (mem_trie_changes, hashes_and_serialized) = Self::to_mem_trie_changes_internal(shard_uid, arena, updated_nodes); // We've accounted for the dereferenced nodes, as well as value addition/subtractions. // The only thing left is to increment refcount for all new nodes. 
for (node_hash, node_serialized) in hashes_and_serialized { - trie_refcount_changes.add(node_hash, node_serialized, 1); - } - let (insertions, deletions) = trie_refcount_changes.into_changes(); - - TrieChanges { - old_root: root.map(|root| root.as_ptr(arena).view().node_hash()).unwrap_or_default(), - new_root: mem_trie_changes - .node_ids_with_hashes - .last() - .map(|(_, hash)| *hash) - .unwrap_or_default(), - insertions, - deletions, - mem_trie_changes: Some(mem_trie_changes), + refcount_changes.add(node_hash, node_serialized, 1); } + let (insertions, deletions) = refcount_changes.into_changes(); + + ( + TrieChanges { + old_root: root + .map(|root| root.as_ptr(arena).view().node_hash()) + .unwrap_or_default(), + new_root: mem_trie_changes + .node_ids_with_hashes + .last() + .map(|(_, hash)| *hash) + .unwrap_or_default(), + insertions, + deletions, + mem_trie_changes: Some(mem_trie_changes), + }, + accesses, + ) } } @@ -917,7 +959,7 @@ mod tests { update.delete(&key); } } - update.to_trie_changes() + update.to_trie_changes().0 } fn make_memtrie_changes_only( diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs index 99000f2825c..41bcebb4c6e 100644 --- a/core/store/src/trie/mod.rs +++ b/core/store/src/trie/mod.rs @@ -1498,7 +1498,44 @@ impl Trie { None => trie_update.delete(&key), } } - Ok(trie_update.to_trie_changes()) + let (trie_changes, trie_accesses) = trie_update.to_trie_changes(); + + // Sanity check for tests: all modified trie items must be + // present among the ever-accessed trie items. + #[cfg(test)] + { + for t in trie_changes.deletions.iter() { + let hash = t.trie_node_or_value_hash; + assert!( + trie_accesses.values.contains_key(&hash) + || trie_accesses.nodes.contains_key(&hash), + "Hash {} is not present in trie accesses", + hash + ); + } + } + + // Retroactively record all accessed trie items which are + // required to process the trie update but were not recorded + // when processing lookups. 
+ // The main case is a branch with two children, one of which + // got removed, so we need to read another one and squash it + // together with parent. + if let Some(recorder) = &self.recorder { + for (node_hash, serialized_node) in trie_accesses.nodes { + recorder.borrow_mut().record(&node_hash, serialized_node); + } + for (value_hash, value) in trie_accesses.values { + let value = match value { + FlatStateValue::Ref(_) => { + self.storage.retrieve_raw_bytes(&value_hash)? + } + FlatStateValue::Inlined(value) => value.into(), + }; + recorder.borrow_mut().record(&value_hash, value); + } + } + Ok(trie_changes) } None => { let mut memory = NodesStorage::new(); diff --git a/core/store/src/trie/trie_recording.rs b/core/store/src/trie/trie_recording.rs index b63c0996f82..cc411db5b3f 100644 --- a/core/store/src/trie/trie_recording.rs +++ b/core/store/src/trie/trie_recording.rs @@ -44,12 +44,13 @@ mod trie_recording_tests { use crate::trie::mem::metrics::MEM_TRIE_NUM_LOOKUPS; use crate::trie::TrieNodesCount; use crate::{DBCol, Store, Trie}; + use borsh::BorshDeserialize; use near_primitives::hash::{hash, CryptoHash}; - use near_primitives::shard_layout::{get_block_shard_uid, get_block_shard_uid_rev, ShardUId}; + use near_primitives::shard_layout::{get_block_shard_uid, ShardUId}; use near_primitives::state::ValueRef; use near_primitives::types::chunk_extra::ChunkExtra; use near_primitives::types::StateRoot; - use rand::{thread_rng, Rng}; + use rand::{random, thread_rng, Rng}; use std::collections::{HashMap, HashSet}; use std::num::NonZeroU32; @@ -66,6 +67,8 @@ mod trie_recording_tests { /// The keys that we should be using to call get_optimized_ref() on the /// trie with. keys_to_get_ref: Vec>, + /// The keys to be updated after trie reads. 
+ updates: Vec<(Vec, Option>)>, state_root: StateRoot, } @@ -121,13 +124,26 @@ mod trie_recording_tests { } key }) - .partition::, _>(|_| thread_rng().gen()); + .partition::, _>(|_| random()); + let updates = trie_changes + .iter() + .map(|(key, _)| { + let value = if thread_rng().gen_bool(0.5) { + Some(vec![thread_rng().gen_range(0..10) as u8]) + } else { + None + }; + (key.clone(), value) + }) + .filter(|_| random()) + .collect::>(); PreparedTrie { store: tries_for_building.get_store(), shard_uid, data_in_trie, keys_to_get, keys_to_get_ref, + updates, state_root, } } @@ -146,7 +162,7 @@ mod trie_recording_tests { for result in store.iter_raw_bytes(DBCol::State) { let (key, value) = result.unwrap(); let (_, refcount) = decode_value_with_rc(&value); - let (key_hash, _) = get_block_shard_uid_rev(&key).unwrap(); + let key_hash: CryptoHash = CryptoHash::try_from_slice(&key[8..]).unwrap(); if !key_hashes_to_keep.contains(&key_hash) { update.decrement_refcount_by( DBCol::State, @@ -174,6 +190,7 @@ mod trie_recording_tests { data_in_trie, keys_to_get, keys_to_get_ref, + updates, state_root, } = prepare_trie(use_missing_keys); let tries = if use_in_memory_tries { @@ -206,6 +223,7 @@ mod trie_recording_tests { } let baseline_trie_nodes_count = trie.get_trie_nodes_count(); println!("Baseline trie nodes count: {:?}", baseline_trie_nodes_count); + trie.update(updates.iter().cloned()).unwrap(); // Now let's do this again while recording, and make sure that the counters // we get are exactly the same. @@ -223,6 +241,7 @@ mod trie_recording_tests { ); } assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count); + trie.update(updates.iter().cloned()).unwrap(); // Now, let's check that when doing the same lookups with the captured partial storage, // we still get the same counters. 
@@ -246,6 +265,7 @@ ); } assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count); + trie.update(updates.iter().cloned()).unwrap(); if use_in_memory_tries { // sanity check that we did indeed use in-memory tries. @@ -310,6 +330,7 @@ data_in_trie, keys_to_get, keys_to_get_ref, + updates, state_root, } = prepare_trie(use_missing_keys); let tries = if use_in_memory_tries { @@ -364,6 +385,7 @@ } let baseline_trie_nodes_count = trie.get_trie_nodes_count(); println!("Baseline trie nodes count: {:?}", baseline_trie_nodes_count); + trie.update(updates.iter().cloned()).unwrap(); // Let's do this again, but this time recording reads. We'll make sure // the counters are exactly the same even when we're recording. @@ -388,6 +410,7 @@ ); } assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count); + trie.update(updates.iter().cloned()).unwrap(); // Now, let's check that when doing the same lookups with the captured partial storage, // we still get the same counters. @@ -411,6 +434,7 @@ ); } assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count); + trie.update(updates.iter().cloned()).unwrap(); if use_in_memory_tries { // sanity check that we did indeed use in-memory tries. diff --git a/core/store/src/trie/trie_tests.rs b/core/store/src/trie/trie_tests.rs index 767c2d5e8ba..7a6cf242ef4 100644 --- a/core/store/src/trie/trie_tests.rs +++ b/core/store/src/trie/trie_tests.rs @@ -419,10 +419,14 @@ mod trie_storage_tests { assert_eq!(count_delta.mem_reads, 1); } - // TODO(#10769): Make this test pass. + // Checks that when branch restructuring is triggered while updating the + // trie, the impacted child is still recorded. + // + // Needed because when a branch has two children and one of them is + // removed, the branch can be converted to an extension, so reading the + // only remaining child is also required. 
#[test] - #[should_panic] - fn test_memtrie_discrepancy() { + fn test_memtrie_recorded_branch_restructuring() { init_test_logger(); let tries = TestTriesBuilder::new().build(); let shard_uid = ShardUId::single_shard();