From 2d5dd968832d4df7ff9f78a7eae7c834f93c8371 Mon Sep 17 00:00:00 2001 From: Saketh Are Date: Wed, 9 Oct 2024 15:51:45 +0100 Subject: [PATCH 1/3] fix(state sync): delete get_cached_state_parts (#12197) In testnet we observed `get_cached_state_parts` consistently taking ~58 seconds to run. Returning the list of cached state parts is a relic of the old decentralized state sync design. We no longer need this at all. --- chain/chain/src/chain.rs | 40 +---------------- chain/client/src/view_client_actor.rs | 30 +------------ core/primitives/src/state_sync.rs | 5 +-- .../src/tests/client/sync_state_nodes.rs | 43 +++---------------- 4 files changed, 11 insertions(+), 107 deletions(-) diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index aa6e24b9581..e46ff494642 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -35,7 +35,6 @@ use crate::{ Provenance, }; use crate::{metrics, DoomslugThresholdMode}; -use borsh::BorshDeserialize; use crossbeam_channel::{unbounded, Receiver, Sender}; use itertools::Itertools; use lru::LruCache; @@ -69,8 +68,8 @@ use near_primitives::sharding::{ }; use near_primitives::state_part::PartId; use near_primitives::state_sync::{ - get_num_state_parts, BitArray, CachedParts, ReceiptProofResponse, RootProof, - ShardStateSyncResponseHeader, ShardStateSyncResponseHeaderV2, StateHeaderKey, StatePartKey, + get_num_state_parts, ReceiptProofResponse, RootProof, ShardStateSyncResponseHeader, + ShardStateSyncResponseHeaderV2, StateHeaderKey, StatePartKey, }; use near_primitives::stateless_validation::state_witness::{ ChunkStateWitness, ChunkStateWitnessSize, @@ -3864,41 +3863,6 @@ impl Chain { Ok((make_snapshot, delete_snapshot)) } - - /// Returns a description of state parts cached for the given shard of the given epoch. - pub fn get_cached_state_parts( - &self, - sync_hash: CryptoHash, - shard_id: ShardId, - num_parts: u64, - ) -> Result { - let _span = tracing::debug_span!(target: "chain", "get_cached_state_parts").entered(); - // DBCol::StateParts is keyed by StatePartKey: (BlockHash || ShardId || PartId (u64)). - let lower_bound = StatePartKey(sync_hash, shard_id, 0); - let lower_bound = borsh::to_vec(&lower_bound)?; - let upper_bound = StatePartKey(sync_hash, shard_id + 1, 0); - let upper_bound = borsh::to_vec(&upper_bound)?; - let mut num_cached_parts = 0; - let mut bit_array = BitArray::new(num_parts); - for item in self.chain_store.store().iter_range( - DBCol::StateParts, - Some(&lower_bound), - Some(&upper_bound), - ) { - let key = item?.0; - let key = StatePartKey::try_from_slice(&key)?; - let part_id = key.2; - num_cached_parts += 1; - bit_array.set_bit(part_id); - } - Ok(if num_cached_parts == 0 { - CachedParts::NoParts - } else if num_cached_parts == num_parts { - CachedParts::AllParts - } else { - CachedParts::BitArray(bit_array) - }) - } } /// This method calculates the congestion info for the genesis chunks. It uses diff --git a/chain/client/src/view_client_actor.rs b/chain/client/src/view_client_actor.rs index b1151aba4e5..b5f12b26b9a 100644 --- a/chain/client/src/view_client_actor.rs +++ b/chain/client/src/view_client_actor.rs @@ -1358,17 +1358,6 @@ impl Handler for ViewClientActorInner { }; let state_response = match header { Some(header) => { - let num_parts = header.num_state_parts(); - let cached_parts = match self - .chain - .get_cached_state_parts(sync_hash, shard_id, num_parts) - { - Ok(cached_parts) => Some(cached_parts), - Err(err) => { - tracing::error!(target: "sync", ?err, ?sync_hash, shard_id, "Failed to get cached state parts"); - None - } - }; let header = match header { ShardStateSyncResponseHeader::V2(inner) => inner, _ => { @@ -1381,7 +1370,7 @@ impl Handler for ViewClientActorInner { ShardStateSyncResponse::V3(ShardStateSyncResponseV3 { header: Some(header), part: None, - cached_parts, + cached_parts: None, can_generate, }) } @@ -1444,26 +1433,11 @@ impl Handler for ViewClientActorInner { None } }; - let num_parts = part.as_ref().and_then(|_| match self.chain.get_state_response_header(shard_id, sync_hash) { - Ok(header) => Some(header.num_state_parts()), - Err(err) => { - tracing::error!(target: "sync", ?err, ?sync_hash, shard_id, "Failed to get num state parts"); - None - } - }); - let cached_parts = num_parts.and_then(|num_parts| - match self.chain.get_cached_state_parts(sync_hash, shard_id, num_parts) { - Ok(cached_parts) => Some(cached_parts), - Err(err) => { - tracing::error!(target: "sync", ?err, ?sync_hash, shard_id, "Failed to get cached state parts"); - None - } - }); let can_generate = part.is_some(); let state_response = ShardStateSyncResponse::V3(ShardStateSyncResponseV3 { header: None, part, - cached_parts, + cached_parts: None, can_generate, }); let info = diff --git a/core/primitives/src/state_sync.rs b/core/primitives/src/state_sync.rs index 53f3ca22f68..ceeddf337fb 100644 --- a/core/primitives/src/state_sync.rs +++ b/core/primitives/src/state_sync.rs @@ -180,11 +180,8 @@ pub struct ShardStateSyncResponseV2 { pub struct ShardStateSyncResponseV3 { pub header: Option, pub part: Option<(u64, Vec)>, - /// Parts that can be provided **cheaply**. - // Can be `None` only if both `header` and `part` are `None`. + // TODO(saketh): deprecate unused fields cached_parts and can_generate pub cached_parts: Option, - /// Whether the node can provide parts for this epoch of this shard. - /// Assumes that a node can either provide all state parts or no state parts. pub can_generate: bool, } diff --git a/integration-tests/src/tests/client/sync_state_nodes.rs b/integration-tests/src/tests/client/sync_state_nodes.rs index 008b76b1e5c..4c842263b08 100644 --- a/integration-tests/src/tests/client/sync_state_nodes.rs +++ b/integration-tests/src/tests/client/sync_state_nodes.rs @@ -18,7 +18,7 @@ use near_o11y::testonly::{init_integration_logger, init_test_logger}; use near_o11y::WithSpanContextExt; use near_primitives::shard_layout::ShardUId; use near_primitives::state_part::PartId; -use near_primitives::state_sync::{CachedParts, StatePartKey}; +use near_primitives::state_sync::StatePartKey; use near_primitives::transaction::SignedTransaction; use near_primitives::types::{BlockId, BlockReference, EpochId, EpochReference}; use near_primitives::utils::MaybeValidated; @@ -858,7 +858,6 @@ fn test_state_sync_headers() { None => return ControlFlow::Continue(()), }; let state_response = state_response_info.take_state_response(); - let cached_parts = state_response.cached_parts().clone(); let can_generate = state_response.can_generate(); assert!(state_response.part().is_none()); if let Some(_header) = state_response.take_header() { @@ -866,27 +865,14 @@ fn test_state_sync_headers() { tracing::info!( ?sync_hash, shard_id, - ?cached_parts, can_generate, "got header but cannot generate" ); return ControlFlow::Continue(()); } - tracing::info!( - ?sync_hash, - shard_id, - ?cached_parts, - can_generate, - "got header" - ); + tracing::info!(?sync_hash, shard_id, can_generate, "got header"); } else { - tracing::info!( - ?sync_hash, - shard_id, - ?cached_parts, - can_generate, - "got no header" - ); + tracing::info!(?sync_hash, shard_id, can_generate, "got no header"); return ControlFlow::Continue(()); } @@ -915,36 +901,19 @@ fn test_state_sync_headers() { let part = state_response.part().clone(); assert!(state_response.take_header().is_none()); if let Some((part_id, _part)) = part { - if !can_generate - || cached_parts != Some(CachedParts::AllParts) - || part_id != 0 - { + if !can_generate || cached_parts != None || part_id != 0 { tracing::info!( ?sync_hash, shard_id, - ?cached_parts, can_generate, part_id, "got part but shard info is unexpected" ); return ControlFlow::Continue(()); } - tracing::info!( - ?sync_hash, - shard_id, - ?cached_parts, - can_generate, - part_id, - "got part" - ); + tracing::info!(?sync_hash, shard_id, can_generate, part_id, "got part"); } else { - tracing::info!( - ?sync_hash, - shard_id, - ?cached_parts, - can_generate, - "got no part" - ); + tracing::info!(?sync_hash, shard_id, can_generate, "got no part"); return ControlFlow::Continue(()); } } From 99ecfa4e8b97eda5e369d82ef00cfc06e5291552 Mon Sep 17 00:00:00 2001 From: Waclaw Banasik Date: Thu, 10 Oct 2024 17:43:20 +0100 Subject: [PATCH 2/3] feat(resharding) - Make shard ids non-contiguous (#12181) This is part 1 of adding support for non-contiguous shard ids. The principle idea is to make ShardId into a newtype so that it's not possible to use it to index arrays with chunk data. In addition I'm adding ShardIndex type and a mapping between shard indices and shard ids so that it's possible to covert one to another as necessary. The TLDR of this approach is to make the types right, fix compiler errors and pray to the software gods that things work out. I am now giving up on trying to make the migration in a single PR. Instead I am introducing some temporary structures and methods that are compatible with both approaches. My current plan for the migration is as follows: 1) Switch to the new ShardId definition. 2) Fix some number of compilation errors (using the temporary objects) 3) Switch back to the old definition 4) PR, review, merge 5) Repeat 1-4 until there are no more errors. 6) Cleanup the temporary objects 7) Adjust some tests to use the new ShardLayout with non-contiguous shard ids. 8) Try to get rid of the mapping wherever possible There are a few common themes in this PR: * read the shard layout and convert shard id to shard index in order to use it to index some array or chunk data * replace enumerate with reading the shard id directly from the chunk header / other chunk data * replace using shard id by adding enumerate to get the shard index * add `?` to shard id in tracing logs because the newtype ShardId doesn't work without it must-review files: * shard_layout.rs * primitives-core/src/types.rs * shard_assignment.rs good-to-review files: * state_transition_data.rs --- chain/chain-primitives/src/error.rs | 2 +- chain/chain/src/blocks_delay_tracker.rs | 13 +- chain/chain/src/chain.rs | 157 ++++++++----- chain/chain/src/chain_update.rs | 6 +- chain/chain/src/garbage_collection.rs | 10 +- chain/chain/src/migrations.rs | 2 +- chain/chain/src/runtime/migrations.rs | 6 +- chain/chain/src/runtime/mod.rs | 16 +- chain/chain/src/runtime/tests.rs | 65 ++++-- chain/chain/src/state_snapshot_actor.rs | 4 +- .../stateless_validation/chunk_endorsement.rs | 16 +- .../stateless_validation/chunk_validation.rs | 32 ++- .../state_transition_data.rs | 44 ++-- chain/chain/src/store/latest_witnesses.rs | 4 +- chain/chain/src/store/mod.rs | 26 ++- chain/chain/src/store_validator/validate.rs | 3 +- chain/chain/src/test_utils/kv_runtime.rs | 87 +++++--- chain/chain/src/types.rs | 3 +- chain/chain/src/update_shard.rs | 4 +- chain/chunks/src/chunk_cache.rs | 9 +- chain/chunks/src/client.rs | 34 +-- chain/chunks/src/logic.rs | 6 +- chain/chunks/src/shards_manager_actor.rs | 22 +- chain/chunks/src/test_utils.rs | 4 +- chain/client-primitives/src/debug.rs | 4 +- chain/client-primitives/src/types.rs | 2 +- .../client/src/chunk_distribution_network.rs | 9 +- chain/client/src/client.rs | 67 +++--- chain/client/src/client_actor.rs | 4 +- chain/client/src/debug.rs | 28 ++- chain/client/src/info.rs | 4 +- .../chunk_endorsement/tracker_v1.rs | 2 +- .../chunk_validator/mod.rs | 4 +- .../orphan_witness_handling.rs | 6 +- .../chunk_validator/orphan_witness_pool.rs | 70 +++--- .../partial_witness_tracker.rs | 2 +- .../stateless_validation/shadow_validate.rs | 12 +- .../state_witness_producer.rs | 11 +- .../state_witness_tracker.rs | 4 +- chain/client/src/sync/external.rs | 40 ++-- chain/client/src/sync/state.rs | 34 +-- chain/client/src/sync_jobs_actor.rs | 5 +- chain/client/src/test_utils/client.rs | 6 +- chain/client/src/test_utils/setup.rs | 9 +- chain/client/src/test_utils/test_env.rs | 8 +- chain/client/src/test_utils/test_loop.rs | 4 +- chain/client/src/tests/bug_repros.rs | 44 ++-- chain/client/src/tests/catching_up.rs | 7 +- chain/client/src/tests/cross_shard_tx.rs | 44 +++- chain/client/src/tests/process_blocks.rs | 9 +- chain/client/src/tests/query_client.rs | 11 +- chain/client/src/view_client_actor.rs | 28 ++- chain/epoch-manager/src/adapter.rs | 32 ++- chain/epoch-manager/src/lib.rs | 36 ++- chain/epoch-manager/src/shard_assignment.rs | 143 ++++++++---- chain/epoch-manager/src/shard_tracker.rs | 31 ++- chain/epoch-manager/src/tests/mod.rs | 101 +++++---- .../epoch-manager/src/tests/random_epochs.rs | 8 +- chain/epoch-manager/src/types.rs | 27 ++- .../epoch-manager/src/validator_selection.rs | 26 ++- chain/indexer/src/streamer/mod.rs | 22 +- chain/jsonrpc-primitives/src/types/chunks.rs | 1 + .../jsonrpc/jsonrpc-tests/tests/rpc_query.rs | 14 +- chain/jsonrpc/src/api/chunks.rs | 4 +- .../network_protocol/proto_conv/handshake.rs | 4 +- .../proto_conv/peer_message.rs | 12 +- .../network/src/network_protocol/testonly.rs | 4 +- .../src/peer_manager/peer_manager_actor.rs | 2 +- .../src/peer_manager/tests/snapshot_hosts.rs | 19 +- chain/network/src/raw/tests.rs | 17 +- chain/network/src/snapshot_hosts/tests.rs | 59 +++-- core/primitives-core/src/types.rs | 158 ++++++++++++- core/primitives/src/block.rs | 15 +- core/primitives/src/congestion_info.rs | 35 +-- core/primitives/src/epoch_info.rs | 19 +- core/primitives/src/shard_layout.rs | 211 +++++++++++++----- core/primitives/src/sharding.rs | 2 +- .../chunk_endorsements_bitmap.rs | 28 ++- core/store/benches/finalize_bench.rs | 12 +- core/store/src/flat/storage.rs | 6 +- core/store/src/genesis/initialization.rs | 21 +- .../src/trie/prefetching_trie_storage.rs | 3 +- core/store/src/trie/shard_tries.rs | 6 +- core/store/src/trie/state_parts.rs | 4 +- core/store/src/trie/trie_storage.rs | 21 +- .../src/csv_to_json_configs.rs | 13 +- genesis-tools/genesis-csv-to-json/src/main.rs | 2 +- genesis-tools/genesis-populate/src/lib.rs | 22 +- integration-tests/src/runtime_utils.rs | 4 +- integration-tests/src/test_loop/builder.rs | 4 +- .../src/test_loop/tests/in_memory_tries.rs | 20 +- .../tests/view_requests_to_archival_node.rs | 24 +- .../src/tests/client/block_corruption.rs | 14 +- .../src/tests/client/challenges.rs | 16 +- nearcore/src/config.rs | 48 ++-- nearcore/src/config_validate.rs | 8 +- nearcore/src/entity_debug.rs | 12 +- nearcore/src/metrics.rs | 3 +- nearcore/src/state_sync.rs | 34 +-- runtime/runtime/src/balance_checker.rs | 70 +++--- runtime/runtime/src/congestion_control.rs | 8 +- runtime/runtime/src/lib.rs | 8 +- runtime/runtime/src/metrics.rs | 2 +- tools/state-viewer/src/epoch_info.rs | 8 +- tools/state-viewer/src/replay_headers.rs | 10 +- 105 files changed, 1593 insertions(+), 893 deletions(-) diff --git a/chain/chain-primitives/src/error.rs b/chain/chain-primitives/src/error.rs index f394a31addf..866eb237794 100644 --- a/chain/chain-primitives/src/error.rs +++ b/chain/chain-primitives/src/error.rs @@ -199,7 +199,7 @@ pub enum Error { InvalidBlockMerkleRoot, /// Invalid split shard ids. #[error("Invalid Split Shard Ids when resharding. shard_id: {0}, parent_shard_id: {1}")] - InvalidSplitShardsIds(u64, u64), + InvalidSplitShardsIds(ShardId, ShardId), /// Someone is not a validator. Usually happens in signature verification #[error("Not A Validator: {0}")] NotAValidator(String), diff --git a/chain/chain/src/blocks_delay_tracker.rs b/chain/chain/src/blocks_delay_tracker.rs index 32bd73b236d..fa0803718fa 100644 --- a/chain/chain/src/blocks_delay_tracker.rs +++ b/chain/chain/src/blocks_delay_tracker.rs @@ -2,6 +2,7 @@ use near_async::time::{Clock, Instant, Utc}; use near_epoch_manager::EpochManagerAdapter; use near_primitives::block::{Block, Tip}; use near_primitives::hash::CryptoHash; +use near_primitives::shard_layout::ShardLayout; use near_primitives::sharding::{ChunkHash, ShardChunkHeader}; use near_primitives::types::{BlockHeight, ShardId}; use near_primitives::views::{ @@ -289,7 +290,12 @@ impl BlocksDelayTracker { } } - pub fn finish_block_processing(&mut self, block_hash: &CryptoHash, new_head: Option) { + pub fn finish_block_processing( + &mut self, + shard_layout: &ShardLayout, + block_hash: &CryptoHash, + new_head: Option, + ) { if let Some(processed_block) = self.blocks.get_mut(&block_hash) { processed_block.processed_timestamp = Some(self.clock.now()); } @@ -297,10 +303,11 @@ impl BlocksDelayTracker { if let Some(processed_block) = self.blocks.get(&block_hash) { let chunks = processed_block.chunks.clone(); self.update_block_metrics(processed_block); - for (shard_id, chunk_hash) in chunks.into_iter().enumerate() { + for (shard_index, chunk_hash) in chunks.into_iter().enumerate() { if let Some(chunk_hash) = chunk_hash { if let Some(processed_chunk) = self.chunks.get(&chunk_hash) { - self.update_chunk_metrics(processed_chunk, shard_id as ShardId); + let shard_id = shard_layout.get_shard_id(shard_index); + self.update_chunk_metrics(processed_chunk, shard_id); } } } diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index e46ff494642..5b790cf7aa0 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -612,23 +612,26 @@ impl Chain { pub fn genesis_chunk_extra( &self, + shard_layout: &ShardLayout, shard_id: ShardId, genesis_protocol_version: ProtocolVersion, congestion_info: Option, ) -> Result { - let shard_index = shard_id as usize; + let shard_index = shard_layout.get_shard_index(shard_id); let state_root = *get_genesis_state_roots(self.chain_store.store())? .ok_or_else(|| Error::Other("genesis state roots do not exist in the db".to_owned()))? .get(shard_index) .ok_or_else(|| { - Error::Other(format!("genesis state root does not exist for shard {shard_index}")) + Error::Other(format!("genesis state root does not exist for shard id {shard_id} shard index {shard_index}")) })?; let gas_limit = self .genesis .chunks() .get(shard_index) .ok_or_else(|| { - Error::Other(format!("genesis chunk does not exist for shard {shard_index}")) + Error::Other(format!( + "genesis chunk does not exist for shard {shard_id} shard index {shard_index}" + )) })? .gas_limit(); Ok(Self::create_genesis_chunk_extra( @@ -780,13 +783,16 @@ impl Chain { Ok(None) } else { debug!(target: "chain", "Downloading state for {:?}, I'm {:?}", shards_to_state_sync, me); + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; let state_sync_info = StateSyncInfo { epoch_tail_hash: *block.header().hash(), shards: shards_to_state_sync .iter() .map(|shard_id| { - let chunk = &prev_block.chunks()[*shard_id as usize]; + let shard_index = shard_layout.get_shard_index(*shard_id); + let chunk = &prev_block.chunks()[shard_index]; ShardInfo(*shard_id, chunk.chunk_hash()) }) .collect(), @@ -812,14 +818,18 @@ impl Chain { genesis_block: &Block, block: &Block, ) -> Result<(), Error> { - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + let epoch_id = block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; + + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); if chunk_header.height_created() == genesis_block.header().height() { // Special case: genesis chunks can be in non-genesis blocks and don't have a signature // We must verify that content matches and signature is empty. // TODO: this code will not work when genesis block has different number of chunks as the current block // https://github.com/near/nearcore/issues/4908 let chunks = genesis_block.chunks(); - let genesis_chunk = chunks.get(shard_id); + let genesis_chunk = chunks.get(shard_index); let genesis_chunk = genesis_chunk.ok_or_else(|| { Error::InvalidChunk(format!( "genesis chunk not found for shard {}, genesis block has {} chunks", @@ -841,7 +851,7 @@ impl Chain { ))); } } else if chunk_header.height_created() == block.header().height() { - if chunk_header.shard_id() != shard_id as ShardId { + if chunk_header.shard_id() != shard_id { return Err(Error::InvalidShardId(chunk_header.shard_id())); } if !epoch_manager.verify_chunk_header_signature( @@ -1301,15 +1311,19 @@ impl Chain { } let mut missing = vec![]; let block_height = block.header().height(); - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); // Check if any chunks are invalid in this block. if let Some(encoded_chunk) = self.chain_store.is_invalid_chunk(&chunk_header.chunk_hash())? { let merkle_paths = Block::compute_chunk_headers_root(block.chunks().iter()).1; - let merkle_proof = merkle_paths - .get(shard_id) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))?; + let merkle_proof = + merkle_paths.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?; let chunk_proof = ChunkProofs { block_header: borsh::to_vec(&block.header()).expect("Failed to serialize"), merkle_proof: merkle_proof.clone(), @@ -1319,7 +1333,6 @@ impl Chain { }; return Err(Error::InvalidChunkProofs(Box::new(chunk_proof))); } - let shard_id = shard_id as ShardId; if chunk_header.is_new_chunk(block_height) { let chunk_hash = chunk_header.chunk_hash(); @@ -1897,15 +1910,20 @@ impl Chain { height = block.header().height()) .entered(); + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let prev_head = self.chain_store.head()?; let is_caught_up = block_preprocess_info.is_caught_up; let provenance = block_preprocess_info.provenance.clone(); let block_start_processing_time = block_preprocess_info.block_start_processing_time; // TODO(#8055): this zip relies on the ordering of the apply_results. + // TODO(wacban): do the above todo for (shard_id, apply_result) in apply_results.iter() { + let shard_index = shard_layout.get_shard_index(*shard_id); if let Err(err) = apply_result { if err.is_bad_data() { - let chunk = block.chunks()[*shard_id as usize].clone(); + let chunk = block.chunks()[shard_index].clone(); block_processing_artifacts.invalid_chunks.push(chunk); } } @@ -1950,7 +1968,7 @@ impl Chain { // during catchup of this block. care_about_shard }; - tracing::debug!(target: "chain", shard_id, need_storage_update, "Updating storage"); + tracing::debug!(target: "chain", ?shard_id, need_storage_update, "Updating storage"); if need_storage_update { // TODO(#12019): consider adding to catchup flow. @@ -2004,7 +2022,12 @@ impl Chain { .as_seconds_f64() .max(0.0), ); - self.blocks_delay_tracker.finish_block_processing(&block_hash, new_head.clone()); + let shard_layout = self.epoch_manager.get_shard_layout(epoch_id)?; + self.blocks_delay_tracker.finish_block_processing( + &shard_layout, + &block_hash, + new_head.clone(), + ); timer.observe_duration(); let _timer = CryptoHashTimer::new_with_start( @@ -2449,12 +2472,16 @@ impl Chain { if sync_block_epoch_id == sync_prev_block.header().epoch_id() { return Err(sync_hash_not_first_hash(sync_hash)); } + + let epoch_id = sync_prev_block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + // Chunk header here is the same chunk header as at the `current` height. let sync_prev_hash = sync_prev_block.hash(); let chunks = sync_prev_block.chunks(); - let chunk_header = chunks - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))?; + let chunk_header = + chunks.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?; let (chunk_headers_root, chunk_proofs) = merklize( &sync_prev_block .chunks() @@ -2467,10 +2494,8 @@ impl Chain { assert_eq!(&chunk_headers_root, sync_prev_block.header().chunk_headers_root()); let chunk = self.get_chunk_clone_from_header(chunk_header)?; - let chunk_proof = chunk_proofs - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? - .clone(); + let chunk_proof = + chunk_proofs.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?.clone(); let block_header = self.get_block_header_on_chain_by_height(&sync_hash, chunk_header.height_included())?; @@ -2481,8 +2506,8 @@ impl Chain { Ok(prev_block) => { let prev_chunk_header = prev_block .chunks() - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? .clone(); let (prev_chunk_headers_root, prev_chunk_proofs) = merklize( &prev_block @@ -2496,8 +2521,8 @@ impl Chain { assert_eq!(&prev_chunk_headers_root, prev_block.header().chunk_headers_root()); let prev_chunk_proof = prev_chunk_proofs - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? .clone(); let prev_chunk_height_included = prev_chunk_header.height_included(); @@ -2548,18 +2573,18 @@ impl Chain { let ReceiptProof(receipts, shard_proof) = receipt_proof; let ShardProof { from_shard_id, to_shard_id: _, proof } = shard_proof; let receipts_hash = CryptoHash::hash_borsh(ReceiptList(shard_id, receipts)); - let from_shard_id = *from_shard_id as usize; + let from_shard_index = shard_layout.get_shard_index(*from_shard_id); - let root_proof = block.chunks()[from_shard_id].prev_outgoing_receipts_root(); + let root_proof = block.chunks()[from_shard_index].prev_outgoing_receipts_root(); root_proofs_cur - .push(RootProof(root_proof, block_receipts_proofs[from_shard_id].clone())); + .push(RootProof(root_proof, block_receipts_proofs[from_shard_index].clone())); // Make sure we send something reasonable. assert_eq!(block_header.prev_chunk_outgoing_receipts_root(), &block_receipts_root); assert!(verify_path(root_proof, proof, &receipts_hash)); assert!(verify_path( block_receipts_root, - &block_receipts_proofs[from_shard_id], + &block_receipts_proofs[from_shard_index], &root_proof, )); } @@ -2632,7 +2657,7 @@ impl Chain { let _span = tracing::debug_span!( target: "sync", "get_state_response_part", - shard_id, + ?shard_id, part_id, ?sync_hash) .entered(); @@ -2647,6 +2672,7 @@ impl Chain { .log_storage_error("block has already been checked for existence")?; let header = block.header(); let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(epoch_id)?; let shard_ids = self.epoch_manager.shard_ids(epoch_id)?; if !shard_ids.contains(&shard_id) { return Err(shard_id_out_of_bounds(shard_id)); @@ -2655,10 +2681,11 @@ impl Chain { if epoch_id == prev_block.header().epoch_id() { return Err(sync_hash_not_first_hash(sync_hash)); } + let shard_index = shard_layout.get_shard_index(shard_id); let state_root = prev_block .chunks() - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? .prev_state_root(); let prev_hash = *prev_block.hash(); let prev_prev_hash = *prev_block.header().prev_hash(); @@ -3105,9 +3132,14 @@ impl Chain { ) -> Result<(), Error> { if !validate_transactions_order(chunk.transactions()) { let merkle_paths = Block::compute_chunk_headers_root(block.chunks().iter()).1; + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_id = chunk.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_proof = ChunkProofs { block_header: borsh::to_vec(&block.header()).expect("Failed to serialize"), - merkle_proof: merkle_paths[chunk.shard_id() as usize].clone(), + merkle_proof: merkle_paths[shard_index].clone(), chunk: MaybeEncodedShardChunk::Decoded(chunk.clone()).into(), }; return Err(Error::InvalidChunkProofs(Box::new(chunk_proof))); @@ -3423,12 +3455,14 @@ impl Chain { block: &Block, chunk_header: &ShardChunkHeader, ) -> Result { - let chunk_shard_id = chunk_header.shard_id(); + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_id = chunk_header.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let prev_merkle_proofs = Block::compute_chunk_headers_root(prev_block.chunks().iter()).1; let merkle_proofs = Block::compute_chunk_headers_root(block.chunks().iter()).1; - let prev_chunk = self - .get_chunk_clone_from_header(&prev_block.chunks()[chunk_shard_id as usize].clone()) - .unwrap(); + let prev_chunk = + self.get_chunk_clone_from_header(&prev_block.chunks()[shard_index].clone()).unwrap(); // TODO (#6316): enable storage proof generation // let prev_chunk_header = &prev_block.chunks()[chunk_shard_id as usize]; @@ -3479,8 +3513,8 @@ impl Chain { Ok(ChunkState { prev_block_header: borsh::to_vec(&prev_block.header())?, block_header: borsh::to_vec(&block.header())?, - prev_merkle_proof: prev_merkle_proofs[chunk_shard_id as usize].clone(), - merkle_proof: merkle_proofs[chunk_shard_id as usize].clone(), + prev_merkle_proof: prev_merkle_proofs[shard_index].clone(), + merkle_proof: merkle_proofs[shard_index].clone(), prev_chunk, chunk_header: chunk_header.clone(), partial_state: PartialState::TrieValues(vec![]), @@ -3590,13 +3624,17 @@ impl Chain { let prev_chunk_headers = Chain::get_prev_chunk_headers(self.epoch_manager.as_ref(), prev_block)?; + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let mut maybe_jobs = vec![]; - for (shard_id, (chunk_header, prev_chunk_header)) in + for (shard_index, (chunk_header, prev_chunk_header)) in block.chunks().iter().zip(prev_chunk_headers.iter()).enumerate() { // XXX: This is a bit questionable -- sandbox state patching works // only for a single shard. This so far has been enough. let state_patch = state_patch.take(); + let shard_id = shard_layout.get_shard_id(shard_index); let storage_context = StorageContext { storage_data_source: StorageDataSource::Db, state_patch }; @@ -3606,7 +3644,7 @@ impl Chain { prev_block, chunk_header, prev_chunk_header, - shard_id as ShardId, + shard_id, mode, incoming_receipts, storage_context, @@ -3621,10 +3659,14 @@ impl Chain { Ok(None) => {} Err(err) => { if err.is_bad_data() { + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_header = block .chunks() - .get(shard_id) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? .clone(); invalid_chunks.push(chunk_header); } @@ -3674,7 +3716,7 @@ impl Chain { prev_chunk_header: &ShardChunkHeader, shard_id: ShardId, mode: ApplyChunksMode, - incoming_receipts: &HashMap>, + incoming_receipts: &HashMap>, storage_context: StorageContext, ) -> Result, Error> { let _span = tracing::debug_span!(target: "chain", "get_update_shard_job").entered(); @@ -3713,7 +3755,7 @@ impl Chain { ?err, prev_block_hash=?prev_hash, block_hash=?block.header().hash(), - shard_id, + ?shard_id, prev_chunk_height_included, ?prev_chunk_extra, ?chunk_header, @@ -3888,6 +3930,7 @@ fn get_genesis_congestion_infos_impl( let genesis_prev_hash = CryptoHash::default(); let genesis_epoch_id = epoch_manager.get_epoch_id_from_prev_block(&genesis_prev_hash)?; let genesis_protocol_version = epoch_manager.get_epoch_protocol_version(&genesis_epoch_id)?; + let genesis_shard_layout = epoch_manager.get_shard_layout(&genesis_epoch_id)?; // If congestion control is not enabled at the genesis block, we return None (congestion info) for each shard. if !ProtocolFeature::CongestionControl.enabled(genesis_protocol_version) { return Ok(std::iter::repeat(None).take(state_roots.len()).collect()); @@ -3900,8 +3943,8 @@ fn get_genesis_congestion_infos_impl( } let mut new_infos = vec![]; - for (shard_id, &state_root) in state_roots.iter().enumerate() { - let shard_id = shard_id as ShardId; + for (shard_index, &state_root) in state_roots.iter().enumerate() { + let shard_id = genesis_shard_layout.get_shard_id(shard_index); let congestion_info = get_genesis_congestion_info( runtime, genesis_protocol_version, @@ -4323,9 +4366,9 @@ impl Chain { let block = self.get_block(&block_hash)?; let chunks = block.chunks(); for &shard_id in shard_ids.iter() { - let chunk_header = &chunks - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_header = + &chunks.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?; if chunk_header.height_included() == block.header().height() { return Ok(Some((block_hash, shard_id))); } @@ -4475,11 +4518,12 @@ impl Chain { ) -> Result, Error> { let epoch_id = epoch_manager.get_epoch_id_from_prev_block(prev_block.hash())?; let shard_ids = epoch_manager.shard_ids(&epoch_id)?; + let prev_shard_ids = epoch_manager.get_prev_shard_ids(prev_block.hash(), shard_ids)?; - let chunks = prev_block.chunks(); + let prev_chunks = prev_block.chunks(); Ok(prev_shard_ids .into_iter() - .map(|shard_id| chunks.get(shard_id as usize).unwrap().clone()) + .map(|(_, shard_index)| prev_chunks.get(shard_index).unwrap().clone()) .collect()) } @@ -4488,11 +4532,12 @@ impl Chain { prev_block: &Block, shard_id: ShardId, ) -> Result { - let prev_shard_id = epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?; + let (prev_shard_id, prev_shard_index) = + epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?; Ok(prev_block .chunks() - .get(prev_shard_id as usize) - .ok_or(Error::InvalidShardId(shard_id))? + .get(prev_shard_index) + .ok_or(Error::InvalidShardId(prev_shard_id))? .clone()) } diff --git a/chain/chain/src/chain_update.rs b/chain/chain/src/chain_update.rs index 9d0a089d767..d1fd7bbf170 100644 --- a/chain/chain/src/chain_update.rs +++ b/chain/chain/src/chain_update.rs @@ -209,7 +209,7 @@ impl<'a> ChainUpdate<'a> { let prev_hash = block.header().prev_hash(); let results = apply_chunks_results.into_iter().map(|(shard_id, x)| { if let Err(err) = &x { - warn!(target: "chain", shard_id, hash = %block.hash(), %err, "Error in applying chunk for block"); + warn!(target: "chain", ?shard_id, hash = %block.hash(), %err, "Error in applying chunk for block"); } x }).collect::, Error>>()?; @@ -465,7 +465,7 @@ impl<'a> ChainUpdate<'a> { shard_state_header: ShardStateSyncResponseHeader, ) -> Result { let _span = - tracing::debug_span!(target: "sync", "chain_update_set_state_finalize", shard_id, ?sync_hash).entered(); + tracing::debug_span!(target: "sync", "chain_update_set_state_finalize", ?shard_id, ?sync_hash).entered(); let (chunk, incoming_receipts_proofs) = match shard_state_header { ShardStateSyncResponseHeader::V1(shard_state_header) => ( ShardChunk::V1(shard_state_header.chunk), @@ -596,7 +596,7 @@ impl<'a> ChainUpdate<'a> { sync_hash: CryptoHash, ) -> Result { let _span = - tracing::debug_span!(target: "sync", "set_state_finalize_on_height", height, shard_id) + tracing::debug_span!(target: "sync", "set_state_finalize_on_height", height, ?shard_id) .entered(); let block_header_result = self.chain_store_update.get_block_header_on_chain_by_height(&sync_hash, height); diff --git a/chain/chain/src/garbage_collection.rs b/chain/chain/src/garbage_collection.rs index 876617fb357..777d83edaef 100644 --- a/chain/chain/src/garbage_collection.rs +++ b/chain/chain/src/garbage_collection.rs @@ -585,9 +585,11 @@ impl<'a> ChainStoreUpdate<'a> { let block = self.get_block(&block_hash).expect("block data is not expected to be already cleaned"); let height = block.header().height(); + let epoch_id = block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).expect("epoch id must exist"); // 2. Delete shard_id-indexed data (Receipts, State Headers and Parts, etc.) - for shard_id in 0..block.header().chunk_mask().len() as ShardId { + for shard_id in shard_layout.shard_ids() { let block_shard_id = get_block_shard_id(&block_hash, shard_id); self.gc_outgoing_receipts(&block_hash, shard_id); self.gc_col(DBCol::IncomingReceipts, &block_shard_id); @@ -678,11 +680,11 @@ impl<'a> ChainStoreUpdate<'a> { self.get_block(&block_hash).expect("block data is not expected to be already cleaned"); let epoch_id = block.header().epoch_id(); - let head_height = block.header().height(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).expect("epoch id must exist"); // 1. Delete shard_id-indexed data (TrieChanges, Receipts, ChunkExtra, State Headers and Parts, FlatStorage data) - for shard_id in 0..block.header().chunk_mask().len() as ShardId { + for shard_id in shard_layout.shard_ids() { let shard_uid = epoch_manager.shard_id_to_uid(shard_id, epoch_id).unwrap(); let block_shard_id = get_block_shard_uid(&block_hash, &shard_uid); @@ -833,6 +835,8 @@ impl<'a> ChainStoreUpdate<'a> { for chunk_header in block.chunks().iter().filter(|h| h.height_included() == block.header().height()) { + // It is ok to use the shard id from the header because it is a new + // chunk. An old chunk may have the shard id from the parent shard. let shard_id = chunk_header.shard_id(); let outcome_ids = self.chain_store().get_outcomes_by_block_hash_and_shard_id(block_hash, shard_id)?; diff --git a/chain/chain/src/migrations.rs b/chain/chain/src/migrations.rs index f778ac4a245..afdbf7f3f43 100644 --- a/chain/chain/src/migrations.rs +++ b/chain/chain/src/migrations.rs @@ -29,7 +29,7 @@ pub fn check_if_block_is_first_with_chunk_of_version( if is_first_epoch_with_protocol_version(epoch_manager, prev_block_hash)? { // Compare only epochs because we already know that current epoch is the first one with current protocol version // convert shard id to shard id of previous epoch because number of shards may change - let shard_id = epoch_manager.get_prev_shard_ids(prev_block_hash, vec![shard_id])?[0]; + let (shard_id, _) = epoch_manager.get_prev_shard_ids(prev_block_hash, vec![shard_id])?[0]; let prev_epoch_id = chain_store.get_epoch_id_of_last_block_with_chunk( epoch_manager, prev_block_hash, diff --git a/chain/chain/src/runtime/migrations.rs b/chain/chain/src/runtime/migrations.rs index c2855ac3a59..827fc3f934f 100644 --- a/chain/chain/src/runtime/migrations.rs +++ b/chain/chain/src/runtime/migrations.rs @@ -33,6 +33,7 @@ mod tests { use near_mainnet_res::mainnet_restored_receipts; use near_mainnet_res::mainnet_storage_usage_delta; use near_primitives::hash::hash; + use near_primitives::types::new_shard_id_tmp; #[test] fn test_migration_data() { @@ -55,7 +56,10 @@ mod tests { "48ZMJukN7RzvyJSW9MJ5XmyQkQFfjy2ZxPRaDMMHqUcT" ); let mainnet_migration_data = load_migration_data(near_primitives::chains::MAINNET); - assert_eq!(mainnet_migration_data.restored_receipts.get(&0u64).unwrap().len(), 383); + assert_eq!( + mainnet_migration_data.restored_receipts.get(&new_shard_id_tmp(0)).unwrap().len(), + 383 + ); let testnet_migration_data = load_migration_data(near_primitives::chains::TESTNET); assert!(testnet_migration_data.restored_receipts.is_empty()); } diff --git a/chain/chain/src/runtime/mod.rs b/chain/chain/src/runtime/mod.rs index 3713aa3e8bf..c901466c41a 100644 --- a/chain/chain/src/runtime/mod.rs +++ b/chain/chain/src/runtime/mod.rs @@ -29,8 +29,8 @@ use near_primitives::state_part::PartId; use near_primitives::transaction::SignedTransaction; use near_primitives::trie_key::TrieKey; use near_primitives::types::{ - AccountId, Balance, BlockHeight, EpochHeight, EpochId, EpochInfoProvider, Gas, MerkleHash, - ShardId, StateChangeCause, StateRoot, StateRootNode, + new_shard_id_tmp, AccountId, Balance, BlockHeight, EpochHeight, EpochId, EpochInfoProvider, + Gas, MerkleHash, ShardId, StateChangeCause, StateRoot, StateRootNode, }; use near_primitives::version::{ProtocolFeature, ProtocolVersion}; use near_primitives::views::{ @@ -223,7 +223,7 @@ impl NightshadeRuntime { epoch_manager.get_epoch_id_from_prev_block(prev_hash).map_err(Error::from)?; let shard_version = epoch_manager.get_shard_layout(&epoch_id).map_err(Error::from)?.version(); - Ok(ShardUId { version: shard_version, shard_id: shard_id as u32 }) + Ok(ShardUId { version: shard_version, shard_id: new_shard_id_tmp(shard_id) as u32 }) } fn get_shard_uid_from_epoch_id( @@ -234,7 +234,7 @@ impl NightshadeRuntime { let epoch_manager = self.epoch_manager.read(); let shard_version = epoch_manager.get_shard_layout(epoch_id).map_err(Error::from)?.version(); - Ok(ShardUId { version: shard_version, shard_id: shard_id as u32 }) + Ok(ShardUId { version: shard_version, shard_id: new_shard_id_tmp(shard_id) as u32 }) } fn account_id_to_shard_uid( @@ -566,7 +566,7 @@ impl NightshadeRuntime { target: "runtime", "obtain_state_part", part_id = part_id.idx, - shard_id, + ?shard_id, %prev_hash, num_parts = part_id.total) .entered(); @@ -953,7 +953,7 @@ impl RuntimeAdapter for NightshadeRuntime { } } - #[instrument(target = "runtime", level = "info", skip_all, fields(shard_id = chunk.shard_id))] + #[instrument(target = "runtime", level = "info", skip_all, fields(shard_id = ?chunk.shard_id))] fn apply_chunk( &self, storage_config: RuntimeStorageConfig, @@ -1191,7 +1191,7 @@ impl RuntimeAdapter for NightshadeRuntime { target: "runtime", "obtain_state_part", part_id = part_id.idx, - shard_id, + ?shard_id, %prev_hash, ?state_root, num_parts = part_id.total) @@ -1369,7 +1369,7 @@ fn chunk_tx_gas_limit( protocol_version: u32, runtime_config: &RuntimeConfig, prev_block: &PrepareTransactionsBlockContext, - shard_id: u64, + shard_id: ShardId, gas_limit: u64, ) -> u64 { if !ProtocolFeature::CongestionControl.enabled(protocol_version) { diff --git a/chain/chain/src/runtime/tests.rs b/chain/chain/src/runtime/tests.rs index 494778e2b7d..10b2ed7f971 100644 --- a/chain/chain/src/runtime/tests.rs +++ b/chain/chain/src/runtime/tests.rs @@ -214,12 +214,13 @@ impl TestEnv { // TODO(congestion_control): pass down prev block info and read congestion info from there // For now, just use default. let prev_block_hash = self.head.last_block_hash; - let state_root = self.state_roots[shard_id as usize]; + let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash).unwrap(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); + let state_root = self.state_roots[shard_index]; let gas_limit = u64::MAX; let height = self.head.height + 1; let block_timestamp = 0; - let epoch_id = - self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash).unwrap_or_default(); let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id).unwrap(); let gas_price = self.runtime.genesis_config.min_gas_price; let congestion_info = if !ProtocolFeature::CongestionControl.enabled(protocol_version) { @@ -316,19 +317,21 @@ impl TestEnv { ) { let new_hash = hash(&[(self.head.height + 1) as u8]); let shard_ids = self.epoch_manager.shard_ids(&self.head.epoch_id).unwrap(); + let shard_layout = self.epoch_manager.get_shard_layout(&self.head.epoch_id).unwrap(); assert_eq!(transactions.len(), shard_ids.len()); assert_eq!(chunk_mask.len(), shard_ids.len()); let mut all_proposals = vec![]; let mut all_receipts = vec![]; for shard_id in shard_ids { + let shard_index = shard_layout.get_shard_index(shard_id); let (state_root, proposals, receipts) = self.update_runtime( shard_id, new_hash, - &transactions[shard_id as usize], + &transactions[shard_index], self.last_receipts.get(&shard_id).map_or(&[], |v| v.as_slice()), challenges_result.clone(), ); - self.state_roots[shard_id as usize] = state_root; + self.state_roots[shard_index] = state_root; all_receipts.extend(receipts); all_proposals.append(&mut proposals.clone()); self.last_shard_proposals.insert(shard_id, proposals); @@ -391,9 +394,11 @@ impl TestEnv { &self.head.epoch_id, ) .unwrap(); + let shard_layout = self.epoch_manager.get_shard_layout(&self.head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &self.head.epoch_id).unwrap(); self.runtime - .view_account(&shard_uid, self.state_roots[shard_id as usize], account_id) + .view_account(&shard_uid, self.state_roots[shard_index], account_id) .unwrap() .into() } @@ -727,9 +732,12 @@ fn test_state_sync() { let block_hash = hash(&[env.head.height as u8]); let state_part = env .runtime - .obtain_state_part(0, &block_hash, &env.state_roots[0], PartId::new(0, 1)) + .obtain_state_part(new_shard_id_tmp(0), &block_hash, &env.state_roots[0], PartId::new(0, 1)) + .unwrap(); + let root_node = env + .runtime + .get_state_root_node(new_shard_id_tmp(0), &block_hash, &env.state_roots[0]) .unwrap(); - let root_node = env.runtime.get_state_root_node(0, &block_hash, &env.state_roots[0]).unwrap(); let mut new_env = TestEnv::new(vec![validators], 2, false); for i in 1..=2 { let prev_hash = hash(&[new_env.head.height as u8]); @@ -786,7 +794,13 @@ fn test_state_sync() { let epoch_id = &new_env.head.epoch_id; new_env .runtime - .apply_state_part(0, &env.state_roots[0], PartId::new(0, 1), &state_part, epoch_id) + .apply_state_part( + new_shard_id_tmp(0), + &env.state_roots[0], + PartId::new(0, 1), + &state_part, + epoch_id, + ) .unwrap(); new_env.state_roots[0] = env.state_roots[0]; for _ in 3..=5 { @@ -827,9 +841,9 @@ fn test_get_validator_info() { let height = env.head.height; let em = env.runtime.epoch_manager.read(); let bp = em.get_block_producer_info(&epoch_id, height).unwrap(); - let cp = em.get_chunk_producer_info(&epoch_id, height, 0).unwrap(); + let cp = em.get_chunk_producer_info(&epoch_id, height, new_shard_id_tmp(0)).unwrap(); let stateless_validators = - em.get_chunk_validator_assignments(&epoch_id, 0, height).ok(); + em.get_chunk_validator_assignments(&epoch_id, new_shard_id_tmp(0), height).ok(); if let Some(vs) = stateless_validators { if vs.contains(&validators[0]) { @@ -876,7 +890,7 @@ fn test_get_validator_info() { public_key: block_producers[0].public_key(), is_slashed: false, stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], num_produced_blocks: expected_blocks[0], num_expected_blocks: expected_blocks[0], num_produced_chunks: expected_chunks[0], @@ -893,7 +907,7 @@ fn test_get_validator_info() { public_key: block_producers[1].public_key(), is_slashed: false, stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], num_produced_blocks: expected_blocks[1], num_expected_blocks: expected_blocks[1], num_produced_chunks: expected_chunks[1], @@ -911,13 +925,13 @@ fn test_get_validator_info() { account_id: "test1".parse().unwrap(), public_key: block_producers[0].public_key(), stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], }, NextEpochValidatorInfo { account_id: "test2".parse().unwrap(), public_key: block_producers[1].public_key(), stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], }, ]; let response = env @@ -988,7 +1002,7 @@ fn test_get_validator_info() { account_id: "test2".parse().unwrap(), public_key: block_producers[1].public_key(), stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], }] ); assert!(response.current_proposals.is_empty()); @@ -1461,13 +1475,13 @@ fn test_flat_state_usage() { let env = TestEnv::new(vec![vec!["test1".parse().unwrap()]], 4, false); let trie = env .runtime - .get_trie_for_shard(0, &env.head.prev_block_hash, Trie::EMPTY_ROOT, true) + .get_trie_for_shard(new_shard_id_tmp(0), &env.head.prev_block_hash, Trie::EMPTY_ROOT, true) .unwrap(); assert!(trie.has_flat_storage_chunk_view()); let trie = env .runtime - .get_view_trie_for_shard(0, &env.head.prev_block_hash, Trie::EMPTY_ROOT) + .get_view_trie_for_shard(new_shard_id_tmp(0), &env.head.prev_block_hash, Trie::EMPTY_ROOT) .unwrap(); assert!(!trie.has_flat_storage_chunk_view()); } @@ -1505,9 +1519,14 @@ fn test_trie_and_flat_state_equality() { // - using view state, which should never use flat state let head_prev_block_hash = env.head.prev_block_hash; let state_root = env.state_roots[0]; - let state = env.runtime.get_trie_for_shard(0, &head_prev_block_hash, state_root, true).unwrap(); - let view_state = - env.runtime.get_view_trie_for_shard(0, &head_prev_block_hash, state_root).unwrap(); + let state = env + .runtime + .get_trie_for_shard(new_shard_id_tmp(0), &head_prev_block_hash, state_root, true) + .unwrap(); + let view_state = env + .runtime + .get_view_trie_for_shard(new_shard_id_tmp(0), &head_prev_block_hash, state_root) + .unwrap(); let trie_key = TrieKey::Account { account_id: validators[1].clone() }; let key = trie_key.to_vec(); @@ -1651,7 +1670,7 @@ fn prepare_transactions( transaction_groups: &mut dyn TransactionGroupIterator, storage_config: RuntimeStorageConfig, ) -> Result { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let block = chain.get_block(&env.head.prev_block_hash).unwrap(); let congestion_info = block.block_congestion_info(); @@ -1773,7 +1792,7 @@ fn test_prepare_transactions_empty_storage_proof() { #[test] #[cfg_attr(not(feature = "test_features"), ignore)] fn test_storage_proof_garbage() { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let signer = create_test_signer("test1"); let env = TestEnv::new(vec![vec![signer.validator_id().clone()]], 100, false); let garbage_size_mb = 50usize; diff --git a/chain/chain/src/state_snapshot_actor.rs b/chain/chain/src/state_snapshot_actor.rs index 27756e4796f..c56d39016d3 100644 --- a/chain/chain/src/state_snapshot_actor.rs +++ b/chain/chain/src/state_snapshot_actor.rs @@ -5,7 +5,7 @@ use near_performance_metrics_macros::perf; use near_primitives::block::Block; use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::ShardUId; -use near_primitives::types::{EpochHeight, ShardId}; +use near_primitives::types::EpochHeight; use near_store::flat::FlatStorageManager; use near_store::ShardTries; use std::sync::Arc; @@ -92,7 +92,7 @@ impl StateSnapshotActor { NetworkRequests::SnapshotHostInfo { sync_hash: prev_block_hash, epoch_height, - shards: res_shard_uids.iter().map(|uid| uid.shard_id as ShardId).collect(), + shards: res_shard_uids.iter().map(|uid| uid.shard_id.into()).collect(), }, )); } diff --git a/chain/chain/src/stateless_validation/chunk_endorsement.rs b/chain/chain/src/stateless_validation/chunk_endorsement.rs index 1531e7d65eb..d2f0c1eb28a 100644 --- a/chain/chain/src/stateless_validation/chunk_endorsement.rs +++ b/chain/chain/src/stateless_validation/chunk_endorsement.rs @@ -43,8 +43,10 @@ pub fn validate_chunk_endorsements_in_block( } let epoch_id = epoch_manager.get_epoch_id_from_prev_block(block.header().prev_hash())?; + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; for (chunk_header, signatures) in block.chunks().iter().zip(block.chunk_endorsements()) { let shard_id = chunk_header.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); // For old chunks, we optimize the block by not including the chunk endorsements. if chunk_header.height_included() != block.header().height() { if !signatures.is_empty() { @@ -117,14 +119,14 @@ pub fn validate_chunk_endorsements_in_block( // Validate the chunk endorsements bitmap (if present) in the block header against the endorsement signatures in the body. if let Some(endorsements_bitmap) = endorsements_bitmap { // Bitmap's length must be equal to the min bytes needed to encode one bit per validator assignment. - if endorsements_bitmap.len(shard_id).unwrap() != signatures.len().div_ceil(8) * 8 { + if endorsements_bitmap.len(shard_index).unwrap() != signatures.len().div_ceil(8) * 8 { return Err(Error::InvalidChunkEndorsementBitmap(format!( "Bitmap's length {} is inconsistent with the number of signatures {} for shard {} ", - endorsements_bitmap.len(shard_id).unwrap(), signatures.len(), shard_id, + endorsements_bitmap.len(shard_index).unwrap(), signatures.len(), shard_id, ))); } // Bits in the bitmap must match the existence of signature for the corresponding validator in the body. - for (bit, signature) in endorsements_bitmap.iter(shard_id).zip(signatures.iter()) { + for (bit, signature) in endorsements_bitmap.iter(shard_index).zip(signatures.iter()) { if bit != signature.is_some() { return Err(Error::InvalidChunkEndorsementBitmap( format!("Chunk endorsement bit in header does not match endorsement in body. shard={}, bit={}, signature={}", @@ -132,7 +134,7 @@ pub fn validate_chunk_endorsements_in_block( } } // All extra positions after the assignments must be left as false. - for value in endorsements_bitmap.iter(shard_id).skip(signatures.len()) { + for value in endorsements_bitmap.iter(shard_index).skip(signatures.len()) { if value { return Err(Error::InvalidChunkEndorsementBitmap( format!("Extra positions in the bitmap after {} validator assignments are not all false for shard {}", @@ -157,6 +159,7 @@ pub fn validate_chunk_endorsements_in_header( ))); }; let epoch_id = epoch_manager.get_epoch_id_from_prev_block(header.prev_hash())?; + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; let shard_ids = epoch_manager.get_shard_layout(&epoch_id)?.shard_ids().collect_vec(); if chunk_endorsements.num_shards() != shard_ids.len() { return Err(Error::InvalidChunkEndorsementBitmap( @@ -165,12 +168,13 @@ pub fn validate_chunk_endorsements_in_header( } let chunk_mask = header.chunk_mask(); for shard_id in shard_ids.into_iter() { + let shard_index = shard_layout.get_shard_index(shard_id); // For old chunks, we optimize the block and its header by not including the chunk endorsements and // corresponding bitmaps. Thus, we expect that the bitmap is empty for shard with no new chunk. - if chunk_mask[shard_id as usize] != (chunk_endorsements.len(shard_id).unwrap() > 0) { + if chunk_mask[shard_index] != (chunk_endorsements.len(shard_index).unwrap() > 0) { return Err(Error::InvalidChunkEndorsementBitmap(format!( "Bitmap must be non-empty iff shard {} has new chunk in the block. Chunk mask={}, Bitmap length={}", - shard_id, chunk_mask[shard_id as usize], chunk_endorsements.len(shard_id).unwrap(), + shard_id, chunk_mask[shard_index], chunk_endorsements.len(shard_index).unwrap(), ))); } } diff --git a/chain/chain/src/stateless_validation/chunk_validation.rs b/chain/chain/src/stateless_validation/chunk_validation.rs index 8100d159e16..c796696f989 100644 --- a/chain/chain/src/stateless_validation/chunk_validation.rs +++ b/chain/chain/src/stateless_validation/chunk_validation.rs @@ -53,6 +53,8 @@ impl MainTransition { pub fn shard_id(&self) -> ShardId { match self { Self::Genesis { shard_id, .. } => *shard_id, + // It is ok to use the shard id from the header because it is a new + // chunk. An old chunk may have the shard id from the parent shard. Self::NewChunk(data) => data.chunk_header.shard_id(), } } @@ -112,7 +114,13 @@ pub fn pre_validate_chunk_state_witness( runtime_adapter: &dyn RuntimeAdapter, ) -> Result { let store = chain.chain_store(); + let epoch_id = state_witness.epoch_id; + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; + + // It is ok to use the shard id from the header because it is a new + // chunk. An old chunk may have the shard id from the parent shard. let shard_id = state_witness.chunk_header.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); // First, go back through the blockchain history to locate the last new chunk // and last last new chunk for the shard. @@ -128,7 +136,7 @@ pub fn pre_validate_chunk_state_witness( loop { let block = store.get_block(&block_hash)?; let chunks = block.chunks(); - let Some(chunk) = chunks.get(shard_id as usize) else { + let Some(chunk) = chunks.get(shard_index) else { return Err(Error::InvalidChunkStateWitness(format!( "Shard {} does not exist in block {:?}", shard_id, block_hash @@ -167,8 +175,7 @@ pub fn pre_validate_chunk_state_witness( let last_chunk_block = blocks_after_last_last_chunk.first().ok_or_else(|| { Error::Other("blocks_after_last_last_chunk is empty, this should be impossible!".into()) })?; - let last_new_chunk_tx_root = - last_chunk_block.chunks().get(shard_id as usize).unwrap().tx_root(); + let last_new_chunk_tx_root = last_chunk_block.chunks().get(shard_index).unwrap().tx_root(); if last_new_chunk_tx_root != tx_root_from_state_witness { return Err(Error::InvalidChunkStateWitness(format!( "Transaction root {:?} does not match expected transaction root {:?}", @@ -216,17 +223,22 @@ pub fn pre_validate_chunk_state_witness( let main_transition_params = if last_chunk_block.header().is_genesis() { let epoch_id = last_chunk_block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; let congestion_info = last_chunk_block .block_congestion_info() .get(&shard_id) .map(|info| info.congestion_info); let genesis_protocol_version = epoch_manager.get_epoch_protocol_version(&epoch_id)?; - let chunk_extra = - chain.genesis_chunk_extra(shard_id, genesis_protocol_version, congestion_info)?; + let chunk_extra = chain.genesis_chunk_extra( + &shard_layout, + shard_id, + genesis_protocol_version, + congestion_info, + )?; MainTransition::Genesis { chunk_extra, block_hash: *last_chunk_block.hash(), shard_id } } else { MainTransition::NewChunk(NewChunkData { - chunk_header: last_chunk_block.chunks().get(shard_id as usize).unwrap().clone(), + chunk_header: last_chunk_block.chunks().get(shard_index).unwrap().clone(), transactions: state_witness.transactions.clone(), receipts: receipts_to_apply, block: Chain::get_apply_chunk_block_context( @@ -527,7 +539,7 @@ impl Chain { let height_created = witness.chunk_header.height_created(); let chunk_hash = witness.chunk_header.chunk_hash(); let parent_span = tracing::debug_span!( - target: "chain", "shadow_validate", shard_id, height_created); + target: "chain", "shadow_validate", ?shard_id, height_created); let (encoded_witness, raw_witness_size) = { let shard_id_label = shard_id.to_string(); let encode_timer = @@ -554,7 +566,7 @@ impl Chain { pre_validate_chunk_state_witness(&witness, &self, epoch_manager, runtime_adapter)?; tracing::debug!( parent: &parent_span, - shard_id, + ?shard_id, ?chunk_hash, witness_size = encoded_witness.size_bytes(), raw_witness_size, @@ -580,7 +592,7 @@ impl Chain { Ok(()) => { tracing::debug!( parent: &parent_span, - shard_id, + ?shard_id, ?chunk_hash, validation_elapsed = ?validation_start.elapsed(), "completed shadow chunk validation" @@ -592,7 +604,7 @@ impl Chain { tracing::error!( parent: &parent_span, ?err, - shard_id, + ?shard_id, ?chunk_hash, "shadow chunk validation failed" ); diff --git a/chain/chain/src/stateless_validation/state_transition_data.rs b/chain/chain/src/stateless_validation/state_transition_data.rs index 9bf05e23753..f57d0f35c47 100644 --- a/chain/chain/src/stateless_validation/state_transition_data.rs +++ b/chain/chain/src/stateless_validation/state_transition_data.rs @@ -23,8 +23,11 @@ impl Chain { return Ok(()); } let final_block = chain_store.get_block(&final_block_hash)?; - let final_block_chunk_created_heights = - final_block.chunks().iter().map(|chunk| chunk.height_created()).collect::>(); + let final_block_chunk_created_heights = final_block + .chunks() + .iter() + .map(|chunk| (chunk.shard_id(), chunk.height_created())) + .collect::>(); clear_before_last_final_block(chain_store, &final_block_chunk_created_heights)?; Ok(()) } @@ -38,7 +41,7 @@ impl Chain { /// TODO(resharding): this doesn't work after shard layout change fn clear_before_last_final_block( chain_store: &ChainStore, - last_final_block_chunk_created_heights: &[BlockHeight], + last_final_block_chunk_created_heights: &[(ShardId, BlockHeight)], ) -> Result<(), Error> { let mut start_heights = if let Some(start_heights) = chain_store @@ -56,10 +59,7 @@ fn clear_before_last_final_block( "garbage collecting state transition data" ); let mut store_update = chain_store.store().store_update(); - for (shard_index, &last_final_block_height) in - last_final_block_chunk_created_heights.iter().enumerate() - { - let shard_id = shard_index as ShardId; + for &(shard_id, last_final_block_height) in last_final_block_chunk_created_heights.iter() { let start_height = *start_heights.get(&shard_id).unwrap_or(&last_final_block_height); let mut potentially_deleted_count = 0; for height in start_height..last_final_block_height { @@ -72,7 +72,7 @@ fn clear_before_last_final_block( } tracing::debug!( target: "state_transition_data", - shard_id, + ?shard_id, start_height, potentially_deleted_count, "garbage collected state transition data for shard" @@ -116,7 +116,7 @@ mod tests { use near_primitives::block_header::{BlockHeader, BlockHeaderInnerLite, BlockHeaderV4}; use near_primitives::hash::{hash, CryptoHash}; use near_primitives::stateless_validation::stored_chunk_state_transition_data::StoredChunkStateTransitionData; - use near_primitives::types::{BlockHeight, EpochId, ShardId}; + use near_primitives::types::{new_shard_id_tmp, BlockHeight, EpochId, ShardId}; use near_primitives::utils::{get_block_shard_id, get_block_shard_id_rev, index_to_bytes}; use near_store::db::STATE_TRANSITION_START_HEIGHTS; use near_store::test_utils::create_test_store; @@ -127,7 +127,7 @@ mod tests { #[test] fn initial_state_transition_data_gc() { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let block_at_1 = hash(&[1]); let block_at_2 = hash(&[2]); let block_at_3 = hash(&[3]); @@ -136,8 +136,9 @@ mod tests { for (hash, height) in [(block_at_1, 1), (block_at_2, 2), (block_at_3, 3)] { save_state_transition_data(&store, hash, height, shard_id); } - clear_before_last_final_block(&create_chain_store(&store), &[final_height]).unwrap(); - check_start_heights(&store, vec![final_height]); + clear_before_last_final_block(&create_chain_store(&store), &[(shard_id, final_height)]) + .unwrap(); + check_start_heights(&store, vec![(shard_id, final_height)]); check_existing_state_transition_data( &store, vec![(block_at_2, shard_id), (block_at_3, shard_id)], @@ -145,34 +146,27 @@ mod tests { } #[test] fn multiple_state_transition_data_gc() { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let store = create_test_store(); let chain_store = create_chain_store(&store); save_state_transition_data(&store, hash(&[1]), 1, shard_id); save_state_transition_data(&store, hash(&[2]), 2, shard_id); - clear_before_last_final_block(&chain_store, &[2]).unwrap(); + clear_before_last_final_block(&chain_store, &[(shard_id, 2)]).unwrap(); let block_at_3 = hash(&[3]); let final_height = 3; save_state_transition_data(&store, block_at_3, final_height, shard_id); - clear_before_last_final_block(&chain_store, &[3]).unwrap(); - check_start_heights(&store, vec![final_height]); + clear_before_last_final_block(&chain_store, &[(shard_id, 3)]).unwrap(); + check_start_heights(&store, vec![(shard_id, final_height)]); check_existing_state_transition_data(&store, vec![(block_at_3, shard_id)]); } #[track_caller] - fn check_start_heights(store: &Store, expected: Vec) { + fn check_start_heights(store: &Store, expected: Vec<(ShardId, BlockHeight)>) { let start_heights = store .get_ser::(DBCol::Misc, STATE_TRANSITION_START_HEIGHTS) .unwrap() .unwrap(); - assert_eq!( - start_heights, - expected - .into_iter() - .enumerate() - .map(|(i, h)| (i as ShardId, h)) - .collect::>() - ); + assert_eq!(start_heights, expected.into_iter().collect::>()); } #[track_caller] diff --git a/chain/chain/src/store/latest_witnesses.rs b/chain/chain/src/store/latest_witnesses.rs index 13020e86145..2b61626b749 100644 --- a/chain/chain/src/store/latest_witnesses.rs +++ b/chain/chain/src/store/latest_witnesses.rs @@ -113,7 +113,7 @@ impl ChainStore { target: "client", "save_latest_chunk_state_witness", witness_height = witness.chunk_header.height_created(), - witness_shard = witness.chunk_header.shard_id(), + witness_shard = ?witness.chunk_header.shard_id(), ) .entered(); @@ -173,7 +173,7 @@ impl ChainStore { OsRng.fill_bytes(&mut random_uuid); let key = LatestWitnessesKey { height: witness.chunk_header.height_created(), - shard_id: witness.chunk_header.shard_id(), + shard_id: witness.chunk_header.shard_id().into(), epoch_id: witness.epoch_id, witness_size: serialized_witness_size, random_uuid, diff --git a/chain/chain/src/store/mod.rs b/chain/chain/src/store/mod.rs index 877b0c1f375..9ccfbf373ee 100644 --- a/chain/chain/src/store/mod.rs +++ b/chain/chain/src/store/mod.rs @@ -243,8 +243,8 @@ pub trait ChainStoreAccess { target: "chain", version = shard_layout.version(), prev_version = prev_shard_layout.version(), - shard_id, - parent_shard_id, + ?shard_id, + ?parent_shard_id, "crossing epoch boundary with shard layout change, updating shard id" ); shard_id = parent_shard_id; @@ -349,18 +349,22 @@ pub trait ChainStoreAccess { shard_id: ShardId, ) -> Result { let mut candidate_hash = *hash; + let block_header = self.get_block_header(&candidate_hash)?; + let shard_layout = epoch_manager.get_shard_layout(block_header.epoch_id())?; let mut shard_id = shard_id; + let mut shard_index = shard_layout.get_shard_index(shard_id); loop { let block_header = self.get_block_header(&candidate_hash)?; if *block_header .chunk_mask() - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? { break Ok(*block_header.epoch_id()); } candidate_hash = *block_header.prev_hash(); - shard_id = epoch_manager.get_prev_shard_ids(&candidate_hash, vec![shard_id])?[0]; + (shard_id, shard_index) = + epoch_manager.get_prev_shard_ids(&candidate_hash, vec![shard_id])?[0]; } } } @@ -370,7 +374,7 @@ pub trait ChainStoreAccess { /// incoming receipts and the shard layout changed. fn filter_incoming_receipts_for_shard( target_shard_layout: &ShardLayout, - target_shard_id: u64, + target_shard_id: ShardId, receipt_proofs: Arc>, ) -> Vec { let mut filtered_receipt_proofs = vec![]; @@ -586,10 +590,10 @@ impl ChainStore { receipts: &mut Vec, protocol_version: ProtocolVersion, shard_layout: &ShardLayout, - shard_id: u64, - receipts_shard_id: u64, + shard_id: ShardId, + receipts_shard_id: ShardId, ) -> Result<(), Error> { - tracing::trace!(target: "resharding", ?protocol_version, shard_id, receipts_shard_id, "reassign_outgoing_receipts_for_resharding"); + tracing::trace!(target: "resharding", ?protocol_version, ?shard_id, ?receipts_shard_id, "reassign_outgoing_receipts_for_resharding"); // If simple nightshade v2 is enabled and stable use that. // Same reassignment of outgoing receipts works for simple nightshade v3 if checked_feature!("stable", SimpleNightshadeV2, protocol_version) { @@ -2173,9 +2177,9 @@ impl<'a> ChainStoreUpdate<'a> { source_store.get_chunk_extra(block_hash, &shard_uid)?.clone(), ); } - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let chunk_hash = chunk_header.chunk_hash(); - let shard_id = shard_id as u64; chain_store_update .chain_store_cache_update .chunks diff --git a/chain/chain/src/store_validator/validate.rs b/chain/chain/src/store_validator/validate.rs index 743a293e499..61163f7f10c 100644 --- a/chain/chain/src/store_validator/validate.rs +++ b/chain/chain/src/store_validator/validate.rs @@ -578,8 +578,9 @@ pub(crate) fn trie_changes_chunk_extra_exists( // 5. There should be ShardChunk with ShardId `shard_id` let shard_id = shard_uid.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let chunks = block.chunks(); - if let Some(chunk_header) = chunks.get(shard_id as usize) { + if let Some(chunk_header) = chunks.get(shard_index) { // if the chunk is not a new chunk, skip the check if chunk_header.height_included() != block.header().height() { return Ok(()); diff --git a/chain/chain/src/test_utils/kv_runtime.rs b/chain/chain/src/test_utils/kv_runtime.rs index cc2e76bdca6..a4ed03ac276 100644 --- a/chain/chain/src/test_utils/kv_runtime.rs +++ b/chain/chain/src/test_utils/kv_runtime.rs @@ -42,8 +42,8 @@ use near_primitives::transaction::{ }; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::types::{ - AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, Nonce, NumShards, - ShardId, StateRoot, StateRootNode, ValidatorInfoIdentifier, + shard_id_as_u32, AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, Nonce, + NumShards, ShardId, ShardIndex, StateRoot, StateRootNode, ValidatorInfoIdentifier, }; use near_primitives::version::{ProtocolFeature, ProtocolVersion, PROTOCOL_VERSION}; use near_primitives::views::{ @@ -170,14 +170,15 @@ impl MockEpochManager { }) .collect(); - let validators_per_shard = block_producers.len() as ShardId / vs.validator_groups; - let coef = block_producers.len() as ShardId / vs.num_shards; + let validators_per_shard = block_producers.len() / vs.validator_groups as usize; + let coef = block_producers.len() / vs.num_shards as usize; let chunk_producers: Vec> = (0..vs.num_shards) - .map(|shard_id| { - let offset = (shard_id * coef / validators_per_shard * validators_per_shard) - as usize; - block_producers[offset..offset + validators_per_shard as usize].to_vec() + .map(|shard_index| { + let shard_index = shard_index as usize; + let offset = + shard_index * coef / validators_per_shard * validators_per_shard; + block_producers[offset..offset + validators_per_shard].to_vec() }) .collect(); @@ -289,8 +290,8 @@ impl MockEpochManager { &self.validators_by_valset[valset].block_producers } - fn get_chunk_producers(&self, valset: usize, shard_id: ShardId) -> Vec { - self.validators_by_valset[valset].chunk_producers[shard_id as usize].clone() + fn get_chunk_producers(&self, valset: usize, shard_index: ShardIndex) -> Vec { + self.validators_by_valset[valset].chunk_producers[shard_index].clone() } fn get_valset_for_epoch(&self, epoch_id: &EpochId) -> Result { @@ -423,8 +424,8 @@ impl EpochManagerAdapter for MockEpochManager { self.hash_to_valset.write().unwrap().contains_key(epoch_id) } - fn shard_ids(&self, _epoch_id: &EpochId) -> Result, EpochError> { - Ok((0..self.num_shards).collect()) + fn shard_ids(&self, epoch_id: &EpochId) -> Result, EpochError> { + Ok(self.get_shard_layout(epoch_id)?.shard_ids().collect()) } fn num_total_parts(&self) -> usize { @@ -463,7 +464,7 @@ impl EpochManagerAdapter for MockEpochManager { shard_id: ShardId, _epoch_id: &EpochId, ) -> Result { - Ok(ShardUId { version: 0, shard_id: shard_id as u32 }) + Ok(ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) }) } fn get_block_info(&self, _hash: &CryptoHash) -> Result, EpochError> { @@ -587,18 +588,33 @@ impl EpochManagerAdapter for MockEpochManager { fn get_prev_shard_ids( &self, - _prev_hash: &CryptoHash, + prev_hash: &CryptoHash, shard_ids: Vec, - ) -> Result, Error> { - Ok(shard_ids) + ) -> Result, Error> { + let mut prev_shard_ids = vec![]; + let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; + for shard_id in shard_ids { + // This is not correct if there was a resharding event in between + // the previous and current block. + let prev_shard_id = shard_id; + let prev_shard_index = shard_layout.get_shard_index(prev_shard_id); + prev_shard_ids.push((prev_shard_id, prev_shard_index)); + } + + Ok(prev_shard_ids) } fn get_prev_shard_id( &self, - _prev_hash: &CryptoHash, + prev_hash: &CryptoHash, shard_id: ShardId, - ) -> Result { - Ok(shard_id) + ) -> Result<(ShardId, ShardIndex), Error> { + let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; + // This is not correct if there was a resharding event in between + // the previous and current block. + let prev_shard_id = shard_id; + let prev_shard_index = shard_layout.get_shard_index(prev_shard_id); + Ok((prev_shard_id, prev_shard_index)) } fn get_shard_layout_from_prev_block( @@ -728,8 +744,10 @@ impl EpochManagerAdapter for MockEpochManager { shard_id: ShardId, ) -> Result { let valset = self.get_valset_for_epoch(epoch_id)?; - let chunk_producers = self.get_chunk_producers(valset, shard_id); - let index = (shard_id + height + 1) as usize % chunk_producers.len(); + let shard_layout = self.get_shard_layout(epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers = self.get_chunk_producers(valset, shard_index); + let index = (shard_index + height as usize + 1) % chunk_producers.len(); Ok(chunk_producers[index].account_id().clone()) } @@ -977,7 +995,9 @@ impl EpochManagerAdapter for MockEpochManager { // we check if we care about a shard. Please do not remove the unwrap, fix the logic of // the calling function. let epoch_valset = self.get_valset_for_epoch(&epoch_id).unwrap(); - let chunk_producers = self.get_chunk_producers(epoch_valset, shard_id); + let shard_layout = self.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers = self.get_chunk_producers(epoch_valset, shard_index); for validator in chunk_producers { if validator.account_id() == account_id { return Ok(true); @@ -996,7 +1016,9 @@ impl EpochManagerAdapter for MockEpochManager { // we check if we care about a shard. Please do not remove the unwrap, fix the logic of // the calling function. let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap(); - let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_id); + let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_index); for validator in chunk_producers { if validator.account_id() == account_id { return Ok(true); @@ -1015,8 +1037,12 @@ impl EpochManagerAdapter for MockEpochManager { // we check if we care about a shard. Please do not remove the unwrap, fix the logic of // the calling function. let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap(); - let chunk_producers = self - .get_chunk_producers((epoch_valset.1 + 1) % self.validators_by_valset.len(), shard_id); + let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers = self.get_chunk_producers( + (epoch_valset.1 + 1) % self.validators_by_valset.len(), + shard_index, + ); for validator in chunk_producers { if validator.account_id() == account_id { return Ok(true); @@ -1077,9 +1103,10 @@ impl RuntimeAdapter for KeyValueRuntime { state_root: StateRoot, _use_flat_storage: bool, ) -> Result { - Ok(self - .tries - .get_trie_for_shard(ShardUId { version: 0, shard_id: shard_id as u32 }, state_root)) + Ok(self.tries.get_trie_for_shard( + ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) }, + state_root, + )) } fn get_flat_storage_manager(&self) -> near_store::flat::FlatStorageManager { @@ -1093,7 +1120,7 @@ impl RuntimeAdapter for KeyValueRuntime { state_root: StateRoot, ) -> Result { Ok(self.tries.get_view_trie_for_shard( - ShardUId { version: 0, shard_id: shard_id as u32 }, + ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) }, state_root, )) } @@ -1278,7 +1305,7 @@ impl RuntimeAdapter for KeyValueRuntime { Ok(ApplyChunkResult { trie_changes: WrappedTrieChanges::new( self.get_tries(), - ShardUId { version: 0, shard_id: shard_id as u32 }, + ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) }, TrieChanges::empty(state_root), Default::default(), block.block_hash, diff --git a/chain/chain/src/types.rs b/chain/chain/src/types.rs index 7f18ea65c61..1063ec140d2 100644 --- a/chain/chain/src/types.rs +++ b/chain/chain/src/types.rs @@ -540,6 +540,7 @@ mod tests { use near_primitives::merkle::verify_path; use near_primitives::test_utils::{create_test_signer, TestBlockBuilder}; use near_primitives::transaction::{ExecutionMetadata, ExecutionOutcome, ExecutionStatus}; + use near_primitives::types::new_shard_id_tmp; use near_primitives::version::PROTOCOL_VERSION; use std::sync::Arc; @@ -547,7 +548,7 @@ mod tests { #[test] fn test_block_produce() { - let shard_ids: Vec<_> = (0..32).collect(); + let shard_ids: Vec<_> = (0..32).map(new_shard_id_tmp).collect(); let genesis_chunks = genesis_chunks( vec![Trie::EMPTY_ROOT], vec![Default::default(); shard_ids.len()], diff --git a/chain/chain/src/update_shard.rs b/chain/chain/src/update_shard.rs index b1d760c9547..5f171bcc7db 100644 --- a/chain/chain/src/update_shard.rs +++ b/chain/chain/src/update_shard.rs @@ -133,7 +133,7 @@ pub fn apply_new_chunk( target: "chain", parent: parent_span, "apply_new_chunk", - shard_id, + ?shard_id, ?apply_reason) .entered(); let gas_limit = chunk_header.gas_limit(); @@ -182,7 +182,7 @@ pub fn apply_old_chunk( target: "chain", parent: parent_span, "apply_old_chunk", - shard_id, + ?shard_id, ?apply_reason) .entered(); diff --git a/chain/chunks/src/chunk_cache.rs b/chain/chunks/src/chunk_cache.rs index 2ac1f0b0e70..4c830378857 100644 --- a/chain/chunks/src/chunk_cache.rs +++ b/chain/chunks/src/chunk_cache.rs @@ -274,12 +274,13 @@ mod tests { use near_crypto::KeyType; use near_primitives::hash::CryptoHash; use near_primitives::sharding::{PartialEncodedChunkV2, ShardChunkHeader, ShardChunkHeaderV2}; + use near_primitives::types::{new_shard_id_tmp, ShardId}; use near_primitives::validator_signer::InMemoryValidatorSigner; use crate::chunk_cache::EncodedChunksCache; use crate::shards_manager_actor::ChunkRequestInfo; - fn create_chunk_header(height: u64, shard_id: u64) -> ShardChunkHeader { + fn create_chunk_header(height: u64, shard_id: ShardId) -> ShardChunkHeader { let signer = InMemoryValidatorSigner::from_random("test".parse().unwrap(), KeyType::ED25519); ShardChunkHeader::V2(ShardChunkHeaderV2::new( @@ -303,8 +304,8 @@ mod tests { #[test] fn test_incomplete_chunks() { let mut cache = EncodedChunksCache::new(); - let header0 = create_chunk_header(1, 0); - let header1 = create_chunk_header(1, 1); + let header0 = create_chunk_header(1, new_shard_id_tmp(0)); + let header1 = create_chunk_header(1, new_shard_id_tmp(1)); cache.get_or_insert_from_header(&header0); cache.merge_in_partial_encoded_chunk(&PartialEncodedChunkV2 { header: header1.clone(), @@ -327,7 +328,7 @@ mod tests { #[test] fn test_cache_removal() { let mut cache = EncodedChunksCache::new(); - let header = create_chunk_header(1, 0); + let header = create_chunk_header(1, new_shard_id_tmp(0)); let partial_encoded_chunk = PartialEncodedChunkV2 { header: header, parts: vec![], prev_outgoing_receipts: vec![] }; cache.merge_in_partial_encoded_chunk(&partial_encoded_chunk); diff --git a/chain/chunks/src/client.rs b/chain/chunks/src/client.rs index 8e00394519f..a898bacabdd 100644 --- a/chain/chunks/src/client.rs +++ b/chain/chunks/src/client.rs @@ -161,7 +161,7 @@ mod tests { hash::CryptoHash, shard_layout::{account_id_to_shard_uid, ShardLayout}, transaction::SignedTransaction, - types::AccountId, + types::{new_shard_id_tmp, shard_id_as_u32, AccountId, ShardId}, }; use near_store::ShardUId; use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; @@ -171,11 +171,12 @@ mod tests { #[test] fn test_random_seed_with_shard_id() { - let seed0 = ShardedTransactionPool::random_seed(&TEST_SEED, 0); - let seed10 = ShardedTransactionPool::random_seed(&TEST_SEED, 10); - let seed256 = ShardedTransactionPool::random_seed(&TEST_SEED, 256); - let seed1000 = ShardedTransactionPool::random_seed(&TEST_SEED, 1000); - let seed1000000 = ShardedTransactionPool::random_seed(&TEST_SEED, 1_000_000); + let seed0 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(0)); + let seed10 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(10)); + let seed256 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(256)); + let seed1000 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(1000)); + let seed1000000 = + ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(1_000_000)); assert_ne!(seed0, seed10); assert_ne!(seed0, seed256); assert_ne!(seed0, seed1000); @@ -196,12 +197,13 @@ mod tests { let mut pool = ShardedTransactionPool::new(TEST_SEED, None); - let mut shard_id_to_accounts = HashMap::new(); - shard_id_to_accounts.insert(0, vec!["aaa", "abcd", "a-a-a-a-a"]); - shard_id_to_accounts.insert(1, vec!["aurora"]); - shard_id_to_accounts.insert(2, vec!["aurora-0", "bob", "kkk"]); + let mut shard_id_to_accounts: HashMap = HashMap::new(); + shard_id_to_accounts.insert(new_shard_id_tmp(0), vec!["aaa", "abcd", "a-a-a-a-a"]); + shard_id_to_accounts.insert(new_shard_id_tmp(1), vec!["aurora"]); + shard_id_to_accounts.insert(new_shard_id_tmp(2), vec!["aurora-0", "bob", "kkk"]); // this shard is split, make sure there are accounts for both shards 3' and 4' - shard_id_to_accounts.insert(3, vec!["mmm", "rrr", "sweat", "ttt", "www", "zzz"]); + shard_id_to_accounts + .insert(new_shard_id_tmp(3), vec!["mmm", "rrr", "sweat", "ttt", "www", "zzz"]); let deposit = 222; @@ -234,8 +236,10 @@ mod tests { CryptoHash::default(), ); - let shard_uid = - ShardUId { shard_id: signer_shard_id as u32, version: old_shard_layout.version() }; + let shard_uid = ShardUId { + shard_id: shard_id_as_u32(signer_shard_id), + version: old_shard_layout.version(), + }; pool.insert_transaction(shard_uid, tx); } @@ -250,7 +254,7 @@ mod tests { { let shard_ids: Vec<_> = new_shard_layout.shard_ids().collect(); for &shard_id in shard_ids.iter() { - let shard_id = shard_id as u32; + let shard_id = shard_id_as_u32(shard_id); let shard_uid = ShardUId { shard_id, version: new_shard_layout.version() }; let pool = pool.pool_for_shard(shard_uid); let pool_len = pool.len(); @@ -260,7 +264,7 @@ mod tests { let mut total = 0; for shard_id in shard_ids { - let shard_id = shard_id as u32; + let shard_id = shard_id_as_u32(shard_id); let shard_uid = ShardUId { shard_id, version: new_shard_layout.version() }; let mut pool_iter = pool.get_pool_iterator(shard_uid).unwrap(); while let Some(group) = pool_iter.next() { diff --git a/chain/chunks/src/logic.rs b/chain/chunks/src/logic.rs index 5f483094501..daf09dcfa24 100644 --- a/chain/chunks/src/logic.rs +++ b/chain/chunks/src/logic.rs @@ -110,8 +110,8 @@ pub fn make_outgoing_receipts_proofs( let mut receipts_by_shard = Chain::group_receipts_by_shard(outgoing_receipts.to_vec(), &shard_layout); - let it = proofs.into_iter().enumerate().map(move |(proof_shard_id, proof)| { - let proof_shard_id = proof_shard_id as u64; + let it = proofs.into_iter().enumerate().map(move |(proof_shard_index, proof)| { + let proof_shard_id = shard_layout.get_shard_id(proof_shard_index); let receipts = receipts_by_shard.remove(&proof_shard_id).unwrap_or_else(Vec::new); let shard_proof = ShardProof { from_shard_id: shard_id, to_shard_id: proof_shard_id, proof }; @@ -174,7 +174,7 @@ pub fn decode_encoded_chunk( target: "chunks", "decode_encoded_chunk", height_included = encoded_chunk.cloned_header().height_included(), - shard_id = encoded_chunk.cloned_header().shard_id(), + shard_id = ?encoded_chunk.cloned_header().shard_id(), ?chunk_hash) .entered(); diff --git a/chain/chunks/src/shards_manager_actor.rs b/chain/chunks/src/shards_manager_actor.rs index 6829b9f3ee2..700261f9a06 100644 --- a/chain/chunks/src/shards_manager_actor.rs +++ b/chain/chunks/src/shards_manager_actor.rs @@ -524,7 +524,7 @@ impl ShardsManagerActor { debug!( target: "chunks", ?part_ords, - shard_id, + ?shard_id, ?target_account, prefer_peer, "Requesting parts", @@ -684,18 +684,18 @@ impl ShardsManagerActor { target: "chunks", "request_chunk_single", ?chunk_hash, - shard_id, + ?shard_id, height_created = height) .entered(); if self.requested_partial_encoded_chunks.contains_key(&chunk_hash) { - debug!(target: "chunks", height, shard_id, ?chunk_hash, "Not requesting chunk, already being requested."); + debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Not requesting chunk, already being requested."); return; } if let Some(entry) = self.encoded_chunks.get(&chunk_header.chunk_hash()) { if entry.complete { - debug!(target: "chunks", height, shard_id, ?chunk_hash, "Not requesting chunk, already complete."); + debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Not requesting chunk, already complete."); return; } } else { @@ -703,7 +703,7 @@ impl ShardsManagerActor { // However, if the chunk had just been processed and marked as complete, it might have // been removed from the cache if it is out of horizon. So in this case, the chunk is // already complete and we don't need to request anything. - debug!(target: "chunks", height, shard_id, ?chunk_hash, "Not requesting chunk, already complete and GC-ed."); + debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Not requesting chunk, already complete and GC-ed."); return; } @@ -721,7 +721,7 @@ impl ShardsManagerActor { ); if mark_only { - debug!(target: "chunks", height, shard_id, ?chunk_hash, "Marked the chunk as being requested but did not send the request yet."); + debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Marked the chunk as being requested but did not send the request yet."); return; } @@ -749,7 +749,7 @@ impl ShardsManagerActor { // we want to give some time for any `PartialEncodedChunkForward` messages to arrive // before we send requests. if !should_wait_for_chunk_forwarding || fetch_from_archival || old_block { - debug!(target: "chunks", height, shard_id, ?chunk_hash, "Requesting."); + debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Requesting."); let request_result = self.request_partial_encoded_chunk( height, &ancestor_hash, @@ -1108,7 +1108,7 @@ impl ShardsManagerActor { target: "chunks", "check_chunk_complete", height_included = chunk.cloned_header().height_included(), - shard_id = chunk.cloned_header().shard_id(), + shard_id = ?chunk.cloned_header().shard_id(), chunk_hash = ?chunk.chunk_hash()) .entered(); @@ -1471,7 +1471,7 @@ impl ShardsManagerActor { target: "chunks", "process_partial_encoded_chunk", ?chunk_hash, - shard_id = header.shard_id(), + shard_id = ?header.shard_id(), height_created = header.height_created(), height_included = header.height_included()) .entered(); @@ -2263,7 +2263,7 @@ mod test { use near_network::types::NetworkRequests; use near_primitives::block::Tip; use near_primitives::hash::{hash, CryptoHash}; - use near_primitives::types::EpochId; + use near_primitives::types::{new_shard_id_tmp, EpochId}; use near_primitives::validator_signer::EmptyValidatorSigner; use near_store::test_utils::create_test_store; use std::sync::Arc; @@ -2322,7 +2322,7 @@ mod test { height: 0, ancestor_hash: Default::default(), prev_block_hash: Default::default(), - shard_id: 0, + shard_id: new_shard_id_tmp(0), added, last_requested: added, }, diff --git a/chain/chunks/src/test_utils.rs b/chain/chunks/src/test_utils.rs index 4415b467a6d..8d659afa4c7 100644 --- a/chain/chunks/src/test_utils.rs +++ b/chain/chunks/src/test_utils.rs @@ -15,7 +15,7 @@ use near_primitives::sharding::{ ShardChunkHeader, }; use near_primitives::test_utils::create_test_signer; -use near_primitives::types::MerkleHash; +use near_primitives::types::{new_shard_id_tmp, MerkleHash}; use near_primitives::types::{AccountId, EpochId, ShardId}; use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION}; use near_store::adapter::chunk_store::ChunkStoreAdapter; @@ -92,7 +92,7 @@ impl ChunkTestFixture { let (mock_parent_hash, mock_height) = if orphan_chunk { (CryptoHash::hash_bytes(&[]), 2) } else { (mock_ancestor_hash, 1) }; // setting this to 2 instead of 0 so that when chunk producers - let mock_shard_id: ShardId = 0; + let mock_shard_id: ShardId = new_shard_id_tmp(0); let mock_epoch_id = epoch_manager.get_epoch_id_from_prev_block(&mock_ancestor_hash).unwrap(); let mock_chunk_producer = diff --git a/chain/client-primitives/src/debug.rs b/chain/client-primitives/src/debug.rs index aacf1128e06..29b7bd314cb 100644 --- a/chain/client-primitives/src/debug.rs +++ b/chain/client-primitives/src/debug.rs @@ -2,7 +2,7 @@ //! without backwards compatibility of JSON encoding. use crate::types::StatusError; use near_primitives::congestion_info::CongestionInfo; -use near_primitives::types::EpochId; +use near_primitives::types::{EpochId, ShardId}; use near_primitives::views::{ CatchupStatusView, ChainProcessingInfo, EpochValidatorInfo, RequestedStatePartsView, SyncStatusView, @@ -143,7 +143,7 @@ pub struct ProductionAtHeight { // None if we are not responsible for producing this block. pub block_production: Option, // Map from shard_id to chunk that we are responsible to produce at this height - pub chunk_production: HashMap, + pub chunk_production: HashMap, } // Information about the approvals that we received. diff --git a/chain/client-primitives/src/types.rs b/chain/client-primitives/src/types.rs index 2baeb99161d..b2d97aacd58 100644 --- a/chain/client-primitives/src/types.rs +++ b/chain/client-primitives/src/types.rs @@ -479,7 +479,7 @@ pub enum GetChunkError { #[error("Block either has never been observed on the node or has been garbage collected: {error_message}")] UnknownBlock { error_message: String }, #[error("Shard ID {shard_id} is invalid")] - InvalidShardId { shard_id: u64 }, + InvalidShardId { shard_id: ShardId }, #[error("Chunk with hash {chunk_hash:?} has never been observed on this node")] UnknownChunk { chunk_hash: ChunkHash }, // NOTE: Currently, the underlying errors are too broad, and while we tried to handle diff --git a/chain/client/src/chunk_distribution_network.rs b/chain/client/src/chunk_distribution_network.rs index b66b0dc2dce..e27fa436ab8 100644 --- a/chain/client/src/chunk_distribution_network.rs +++ b/chain/client/src/chunk_distribution_network.rs @@ -226,6 +226,7 @@ mod tests { PartialEncodedChunkV2, ShardChunkHeaderInner, ShardChunkHeaderInnerV3, ShardChunkHeaderV3, }, + types::new_shard_id_tmp, validator_signer::EmptyValidatorSigner, }; use std::{collections::HashMap, convert::Infallible, future::Future}; @@ -235,7 +236,7 @@ mod tests { fn test_request_chunks() { let (mock_sender, mut message_receiver) = mpsc::unbounded_channel(); let mut client = MockClient::default(); - let missing_chunk = mock_shard_chunk(0, 0); + let missing_chunk = mock_shard_chunk(0, 0u64.into()); let mut blocks_delay_tracker = BlocksDelayTracker::new(Clock::real()); let shards_manager = MockSender::new(mock_sender); let shards_manager_adapter = shards_manager.into_sender(); @@ -309,8 +310,8 @@ mod tests { // When chunks are known by the client, the shards manager // is told to process the chunk directly - let known_chunk_1 = mock_shard_chunk(1, 0); - let known_chunk_2 = mock_shard_chunk(2, 0); + let known_chunk_1 = mock_shard_chunk(1, new_shard_id_tmp(0)); + let known_chunk_2 = mock_shard_chunk(2, new_shard_id_tmp(0)); client.publish_chunk(&known_chunk_1).now_or_never(); client.publish_chunk(&known_chunk_2).now_or_never(); let blocks_missing_chunks = vec![BlockMissingChunks { @@ -392,7 +393,7 @@ mod tests { }); } - fn mock_shard_chunk(height: u64, shard_id: u64) -> PartialEncodedChunk { + fn mock_shard_chunk(height: u64, shard_id: ShardId) -> PartialEncodedChunk { let prev_block_hash = hash(&[height.to_le_bytes().as_slice(), shard_id.to_le_bytes().as_slice()].concat()); let mut mock_hashes = MockHashes::new(prev_block_hash); diff --git a/chain/client/src/client.rs b/chain/client/src/client.rs index 9a13abeba08..6bbff7d0994 100644 --- a/chain/client/src/client.rs +++ b/chain/client/src/client.rs @@ -150,7 +150,7 @@ pub struct Client { /// A mapping from a block for which a state sync is underway for the next epoch, and the object /// storing the current status of the state sync and blocks catch up pub catchup_state_syncs: - HashMap, BlocksCatchUpState)>, + HashMap, BlocksCatchUpState)>, /// Keeps track of information needed to perform the initial Epoch Sync pub epoch_sync: EpochSync, /// Keeps track of syncing headers. @@ -428,8 +428,9 @@ impl Client { block: &Block, ) -> Result<(), Error> { let epoch_id = self.epoch_manager.get_epoch_id(block.hash())?; - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { - let shard_id = shard_id as ShardId; + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id)?; if block.header().height() == chunk_header.height_included() { if cares_about_shard_this_or_next_epoch( @@ -458,8 +459,10 @@ impl Client { block: &Block, ) -> Result<(), Error> { let epoch_id = self.epoch_manager.get_epoch_id(block.hash())?; - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { - let shard_id = shard_id as ShardId; + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id)?; if block.header().height() == chunk_header.height_included() { @@ -726,7 +729,7 @@ impl Client { BlockProductionTracker::construct_chunk_collection_info( height, &epoch_id, - chunk_headers.len() as ShardId, + chunk_headers.len(), &new_chunks, self.epoch_manager.as_ref(), &self.chunk_inclusion_tracker, @@ -734,16 +737,18 @@ impl Client { ); // Collect new chunk headers and endorsements. + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; for (shard_id, chunk_hash) in new_chunks { + let shard_index = shard_layout.get_shard_index(shard_id); let (mut chunk_header, chunk_endorsement) = self.chunk_inclusion_tracker.get_chunk_header_and_endorsements(&chunk_hash)?; *chunk_header.height_included_mut() = height; *chunk_headers - .get_mut(shard_id as usize) + .get_mut(shard_index) .ok_or_else(|| near_chain_primitives::Error::InvalidShardId(shard_id))? = chunk_header; *chunk_endorsements - .get_mut(shard_id as usize) + .get_mut(shard_index) .ok_or_else(|| near_chain_primitives::Error::InvalidShardId(shard_id))? = chunk_endorsement; } @@ -831,7 +836,7 @@ impl Client { me = ?signer.as_ref().validator_id(), ?chunk_proposer, next_height, - shard_id, + ?shard_id, "Not producing chunk. Not chunk producer for next chunk."); return Ok(None); } @@ -863,7 +868,7 @@ impl Client { let prev_prev_hash = *self.chain.get_block_header(&prev_block_hash)?.prev_hash(); if !self.chain.prev_block_is_caught_up(&prev_prev_hash, &prev_block_hash)? { // See comment in similar snipped in `produce_block` - debug!(target: "client", shard_id, next_height, "Produce chunk: prev block is not caught up"); + debug!(target: "client", ?shard_id, next_height, "Produce chunk: prev block is not caught up"); return Err(Error::ChunkProducer( "State for the epoch is not downloaded yet, skipping chunk production" .to_string(), @@ -871,7 +876,7 @@ impl Client { } } - debug!(target: "client", me = ?validator_signer.validator_id(), next_height, shard_id, "Producing chunk"); + debug!(target: "client", me = ?validator_signer.validator_id(), next_height, ?shard_id, "Producing chunk"); let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, epoch_id)?; let chunk_extra = self @@ -879,9 +884,10 @@ impl Client { .get_chunk_extra(&prev_block_hash, &shard_uid) .map_err(|err| Error::ChunkProducer(format!("No chunk extra available: {}", err)))?; - let prev_shard_id = self.epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?; + let (prev_shard_id, prev_shard_index) = + self.epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?; let last_chunk_header = - prev_block.chunks().get(prev_shard_id as usize).cloned().ok_or_else(|| { + prev_block.chunks().get(prev_shard_index).cloned().ok_or_else(|| { Error::ChunkProducer(format!( "No last chunk in prev_block_hash {:?}, prev_shard_id: {}", prev_block_hash, prev_shard_id @@ -1022,7 +1028,7 @@ impl Client { chunk_extra: &ChunkExtra, ) -> Result { let Self { chain, sharded_tx_pool, runtime_adapter: runtime, .. } = self; - let shard_id = shard_uid.shard_id as ShardId; + let shard_id = shard_uid.shard_id(); let prepared_transactions = if let Some(mut iter) = sharded_tx_pool.get_pool_iterator(shard_uid) { @@ -1453,8 +1459,18 @@ impl Client { ) { let chunk_header = partial_chunk.cloned_header(); self.chain.blocks_delay_tracker.mark_chunk_completed(&chunk_header); + + // TODO(#10569) We would like a proper error handling here instead of `expect`. + let parent_hash = *chunk_header.prev_block_hash(); + let shard_layout = self + .epoch_manager + .get_shard_layout_from_prev_block(&parent_hash) + .expect("Could not obtain shard layout"); + + let shard_id = partial_chunk.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); self.block_production_info - .record_chunk_collected(partial_chunk.height_created(), partial_chunk.shard_id()); + .record_chunk_collected(partial_chunk.height_created(), shard_index); // TODO(#10569) We would like a proper error handling here instead of `expect`. persist_chunk(partial_chunk, shard_chunk, self.chain.mut_chain_store()) @@ -2210,7 +2226,7 @@ impl Client { validators.remove(account_id); } for validator in validators { - trace!(target: "client", me = ?signer.as_ref().map(|bp| bp.validator_id()), ?tx, ?validator, shard_id, "Routing a transaction"); + trace!(target: "client", me = ?signer.as_ref().map(|bp| bp.validator_id()), ?tx, ?validator, ?shard_id, "Routing a transaction"); // Send message to network to actually forward transaction. self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests( @@ -2392,7 +2408,7 @@ impl Client { // forward to current epoch validators, // possibly forward to next epoch validators if self.active_validator(shard_id, signer)? { - trace!(target: "client", account = ?me, shard_id, tx_hash = ?tx.get_hash(), is_forwarded, "Recording a transaction."); + trace!(target: "client", account = ?me, ?shard_id, tx_hash = ?tx.get_hash(), is_forwarded, "Recording a transaction."); metrics::TRANSACTION_RECEIVED_VALIDATOR.inc(); if !is_forwarded { @@ -2400,12 +2416,12 @@ impl Client { } Ok(ProcessTxResponse::ValidTx) } else if !is_forwarded { - trace!(target: "client", shard_id, tx_hash = ?tx.get_hash(), "Forwarding a transaction."); + trace!(target: "client", ?shard_id, tx_hash = ?tx.get_hash(), "Forwarding a transaction."); metrics::TRANSACTION_RECEIVED_NON_VALIDATOR.inc(); self.forward_tx(&epoch_id, tx, signer)?; Ok(ProcessTxResponse::RequestRouted) } else { - trace!(target: "client", shard_id, tx_hash = ?tx.get_hash(), "Non-validator received a forwarded transaction, dropping it."); + trace!(target: "client", ?shard_id, tx_hash = ?tx.get_hash(), "Non-validator received a forwarded transaction, dropping it."); metrics::TRANSACTION_RECEIVED_NON_VALIDATOR_FORWARDED.inc(); Ok(ProcessTxResponse::NoResponse) } @@ -2414,7 +2430,7 @@ impl Client { Ok(ProcessTxResponse::DoesNotTrackShard) } else if is_forwarded { // Received forwarded transaction but we are not tracking the shard - debug!(target: "client", ?me, shard_id, tx_hash = ?tx.get_hash(), "Received forwarded transaction but no tracking shard"); + debug!(target: "client", ?me, ?shard_id, tx_hash = ?tx.get_hash(), "Received forwarded transaction but no tracking shard"); Ok(ProcessTxResponse::NoResponse) } else { // We are not tracking this shard, so there is no way to validate this tx. Just rerouting. @@ -2497,7 +2513,7 @@ impl Client { debug!(target: "catchup", ?me, ?sync_hash, progress_per_shard = ?format_shard_sync_phase_per_shard(&shards_to_split, false), "Catchup"); let use_colour = matches!(self.config.log_summary_style, LogSummaryStyle::Colored); - let tracking_shards: Vec = + let tracking_shards: Vec = state_sync_info.shards.iter().map(|tuple| tuple.0).collect(); // Notify each shard to sync. if notify_state_sync { @@ -2583,7 +2599,7 @@ impl Client { sync_hash: CryptoHash, state_sync_info: &StateSyncInfo, me: &Option, - ) -> Result, Error> { + ) -> Result, Error> { let prev_hash = *self.chain.get_block(&sync_hash)?.header().prev_hash(); let need_to_reshard = self.epoch_manager.will_shard_layout_change(&prev_hash)?; @@ -2596,7 +2612,7 @@ impl Client { let shards_to_split = state_sync_info .shards .iter() - .filter_map(|ShardInfo(shard_id, _)| self.should_split_shard(shard_id, me, prev_hash)) + .filter_map(|ShardInfo(shard_id, _)| self.should_split_shard(*shard_id, me, prev_hash)) .collect(); Ok(shards_to_split) } @@ -2605,11 +2621,10 @@ impl Client { /// track it. fn should_split_shard( &mut self, - shard_id: &u64, + shard_id: ShardId, me: &Option, prev_hash: CryptoHash, - ) -> Option<(u64, ShardSyncDownload)> { - let shard_id = *shard_id; + ) -> Option<(ShardId, ShardSyncDownload)> { if self.shard_tracker.care_about_shard(me.as_ref(), &prev_hash, shard_id, true) { let shard_sync_download = ShardSyncDownload { downloads: vec![], diff --git a/chain/client/src/client_actor.rs b/chain/client/src/client_actor.rs index f43121d4a3d..ea909642fac 100644 --- a/chain/client/src/client_actor.rs +++ b/chain/client/src/client_actor.rs @@ -65,7 +65,7 @@ use near_primitives::block::Tip; use near_primitives::block_header::ApprovalType; use near_primitives::hash::CryptoHash; use near_primitives::network::{AnnounceAccount, PeerId}; -use near_primitives::types::{AccountId, BlockHeight, EpochId}; +use near_primitives::types::{AccountId, BlockHeight, EpochId, ShardId}; use near_primitives::unwrap_or_return; use near_primitives::utils::MaybeValidated; use near_primitives::validator_signer::ValidatorSigner; @@ -1912,7 +1912,7 @@ impl ClientActorInner { &mut self, epoch_id: EpochId, sync_hash: CryptoHash, - shards_to_sync: &Vec, + shards_to_sync: &Vec, ) { let shard_layout = self.client.epoch_manager.get_shard_layout(&epoch_id).expect("Cannot get shard layout"); diff --git a/chain/client/src/debug.rs b/chain/client/src/debug.rs index 47ca3fc1e4d..c3c59579848 100644 --- a/chain/client/src/debug.rs +++ b/chain/client/src/debug.rs @@ -21,7 +21,9 @@ use near_performance_metrics_macros::perf; use near_primitives::congestion_info::CongestionControl; use near_primitives::state_sync::get_num_state_parts; use near_primitives::stateless_validation::chunk_endorsement::ChunkEndorsement; -use near_primitives::types::{AccountId, BlockHeight, NumShards, ShardId, ValidatorInfoIdentifier}; +use near_primitives::types::{ + AccountId, BlockHeight, NumShards, ShardId, ShardIndex, ValidatorInfoIdentifier, +}; use near_primitives::{ hash::CryptoHash, state_sync::{ShardStateSyncResponseHeader, StateHeaderKey}, @@ -104,11 +106,11 @@ impl BlockProductionTracker { /// Record chunk collected after a block is produced if the block didn't include a chunk for the shard. /// If called before the block was produced, nothing happens. - pub(crate) fn record_chunk_collected(&mut self, height: BlockHeight, shard_id: ShardId) { + pub(crate) fn record_chunk_collected(&mut self, height: BlockHeight, shard_index: ShardIndex) { if let Some(block_production) = self.0.get_mut(&height) { let chunk_collections = &mut block_production.chunks_collection_time; // Check that chunk_collection is set and we haven't received this chunk yet. - if let Some(chunk_collection) = chunk_collections.get_mut(shard_id as usize) { + if let Some(chunk_collection) = chunk_collections.get_mut(shard_index) { if chunk_collection.received_time.is_none() { chunk_collection.received_time = Some(Clock::real().now_utc()); } @@ -121,13 +123,15 @@ impl BlockProductionTracker { pub(crate) fn construct_chunk_collection_info( block_height: BlockHeight, epoch_id: &EpochId, - num_shards: ShardId, + num_shards: usize, new_chunks: &HashMap, epoch_manager: &dyn EpochManagerAdapter, chunk_inclusion_tracker: &ChunkInclusionTracker, ) -> Result, Error> { let mut chunk_collection_info = vec![]; - for shard_id in 0..num_shards { + for shard_index in 0..num_shards { + let shard_layout = epoch_manager.get_shard_layout(epoch_id)?; + let shard_id = shard_layout.get_shard_id(shard_index); if let Some(chunk_hash) = new_chunks.get(&shard_id) { let (chunk_producer, received_time) = chunk_inclusion_tracker.get_chunk_producer_and_received_time(chunk_hash)?; @@ -228,6 +232,8 @@ impl ClientActorInner { let block = self.client.chain.get_block_by_height(epoch_start_height)?; let epoch_id = block.header().epoch_id(); + let shard_layout = self.client.epoch_manager.get_shard_layout(&epoch_id)?; + let (validators, chunk_only_producers) = self.get_producers_for_epoch(&epoch_id, ¤t_block)?; @@ -235,9 +241,10 @@ impl ClientActorInner { .chunks() .iter() .enumerate() - .map(|(shard_id, chunk)| { + .map(|(shard_index, chunk)| { + let shard_id = shard_layout.get_shard_id(shard_index); let state_root_node = self.client.runtime_adapter.get_state_root_node( - shard_id as u64, + shard_id, block.hash(), &chunk.prev_state_root(), ); @@ -252,9 +259,10 @@ impl ClientActorInner { }) .collect(); - let state_header_exists: Vec = (0..block.chunks().len()) + let state_header_exists: Vec = shard_layout + .shard_ids() .map(|shard_id| { - let key = borsh::to_vec(&StateHeaderKey(shard_id as u64, *block.hash())); + let key = borsh::to_vec(&StateHeaderKey(shard_id, *block.hash())); match key { Ok(key) => { matches!( @@ -490,7 +498,7 @@ impl ClientActorInner { }); DebugChunkStatus { - shard_id: chunk.shard_id(), + shard_id: chunk.shard_id().into(), chunk_hash: chunk.chunk_hash(), chunk_producer: self .client diff --git a/chain/client/src/info.rs b/chain/client/src/info.rs index 3f3d3d8e10e..ab62c3b9d3d 100644 --- a/chain/client/src/info.rs +++ b/chain/client/src/info.rs @@ -210,6 +210,7 @@ impl InfoHelper { let epoch_info = client.epoch_manager.get_epoch_info(&head.epoch_id); let blocks_in_epoch = client.config.epoch_length; let shard_ids = client.epoch_manager.shard_ids(&head.epoch_id).unwrap_or_default(); + let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap(); if let Ok(epoch_info) = epoch_info { metrics::VALIDATORS_CHUNKS_EXPECTED_IN_EPOCH.reset(); metrics::VALIDATORS_BLOCKS_EXPECTED_IN_EPOCH.reset(); @@ -250,10 +251,11 @@ impl InfoHelper { }); for shard_id in shard_ids { + let shard_index = shard_layout.get_shard_index(shard_id); let mut stake_per_cp = HashMap::::new(); stake_sum = 0; let chunk_producers_settlement = &epoch_info.chunk_producers_settlement(); - let chunk_producers = chunk_producers_settlement.get(shard_id as usize); + let chunk_producers = chunk_producers_settlement.get(shard_index); let Some(chunk_producers) = chunk_producers else { tracing::warn!(target: "stats", ?shard_id, ?chunk_producers_settlement, "invalid shard id, not found in the shard settlement"); continue; diff --git a/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs b/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs index 30dc76f41ff..44b2954c147 100644 --- a/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs +++ b/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs @@ -90,7 +90,7 @@ impl ChunkEndorsementTracker { chunk_header: &ShardChunkHeader, endorsement: ChunkEndorsementV1, ) -> Result<(), Error> { - let _span = tracing::debug_span!(target: "client", "process_chunk_endorsement", chunk_hash=?chunk_header.chunk_hash(), shard_id=chunk_header.shard_id()).entered(); + let _span = tracing::debug_span!(target: "client", "process_chunk_endorsement", chunk_hash=?chunk_header.chunk_hash(), shard_id=?chunk_header.shard_id()).entered(); // Validate the endorsement before locking the mutex to improve performance. if !self.epoch_manager.verify_chunk_endorsement(&chunk_header, &endorsement)? { tracing::error!(target: "client", ?endorsement, "Invalid chunk endorsement."); diff --git a/chain/client/src/stateless_validation/chunk_validator/mod.rs b/chain/client/src/stateless_validation/chunk_validator/mod.rs index 6395464efca..23e6fdf43bb 100644 --- a/chain/client/src/stateless_validation/chunk_validator/mod.rs +++ b/chain/client/src/stateless_validation/chunk_validator/mod.rs @@ -214,7 +214,7 @@ pub(crate) fn send_chunk_endorsement_to_block_producers( tracing::debug!( target: "client", chunk_hash=?chunk_hash, - shard_id=chunk_header.shard_id(), + shard_id=?chunk_header.shard_id(), ?block_producers, "send_chunk_endorsement", ); @@ -243,7 +243,7 @@ impl Client { tracing::debug!( target: "client", chunk_hash=?witness.chunk_header.chunk_hash(), - shard_id=witness.chunk_header.shard_id(), + shard_id=?witness.chunk_header.shard_id(), "process_chunk_state_witness", ); diff --git a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs index ed68fa977ad..96a55131d2b 100644 --- a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs +++ b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs @@ -35,7 +35,7 @@ impl Client { let _span = tracing::debug_span!(target: "client", "handle_orphan_state_witness", witness_height, - witness_shard, + ?witness_shard, witness_chunk = ?chunk_header.chunk_hash(), witness_prev_block = ?chunk_header.prev_block_hash(), ) @@ -63,7 +63,7 @@ impl Client { tracing::warn!( target: "client", witness_height, - witness_shard, + ?witness_shard, witness_chunk = ?chunk_header.chunk_hash(), witness_prev_block = ?chunk_header.prev_block_hash(), witness_size, @@ -87,7 +87,7 @@ impl Client { tracing::debug!( target: "client", witness_height = header.height_created(), - witness_shard = header.shard_id(), + witness_shard = ?header.shard_id(), witness_chunk = ?header.chunk_hash(), witness_prev_block = ?header.prev_block_hash(), "Processing an orphaned ChunkStateWitness, its previous block has arrived." diff --git a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs index 0429875b794..0a15da403e5 100644 --- a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs +++ b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs @@ -56,7 +56,7 @@ impl OrphanStateWitnessPool { tracing::debug!( target: "client", ejected_witness_height = header.height_created(), - ejected_witness_shard = header.shard_id(), + ejected_witness_shard = ?header.shard_id(), ejected_witness_chunk = ?header.chunk_hash(), ejected_witness_prev_block = ?header.prev_block_hash(), "Ejecting an orphaned ChunkStateWitness from the cache due to capacity limit. It will not be processed." @@ -101,7 +101,7 @@ impl OrphanStateWitnessPool { target: "client", final_height, ejected_witness_height = witness_height, - ejected_witness_shard = cache_key.shard_id, + ejected_witness_shard = ?cache_key.shard_id, ejected_witness_chunk = ?header.chunk_hash(), ejected_witness_prev_block = ?header.prev_block_hash(), "Ejecting an orphaned ChunkStateWitness from the cache because it's below \ @@ -180,7 +180,7 @@ mod tests { use near_primitives::hash::{hash, CryptoHash}; use near_primitives::sharding::{ShardChunkHeader, ShardChunkHeaderInner}; use near_primitives::stateless_validation::state_witness::ChunkStateWitness; - use near_primitives::types::{BlockHeight, ShardId}; + use near_primitives::types::{new_shard_id_tmp, BlockHeight, ShardId}; use super::OrphanStateWitnessPool; @@ -253,10 +253,10 @@ mod tests { fn basic() { let mut pool = OrphanStateWitnessPool::new(10); - let witness1 = make_witness(100, 1, block(99), 0); - let witness2 = make_witness(100, 2, block(99), 0); - let witness3 = make_witness(101, 1, block(100), 0); - let witness4 = make_witness(101, 2, block(100), 0); + let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0); + let witness2 = make_witness(100, new_shard_id_tmp(2), block(99), 0); + let witness3 = make_witness(101, new_shard_id_tmp(1), block(100), 0); + let witness4 = make_witness(101, new_shard_id_tmp(2), block(100), 0); pool.add_orphan_state_witness(witness1.clone(), 0); pool.add_orphan_state_witness(witness2.clone(), 0); @@ -280,8 +280,8 @@ mod tests { // The old witness is replaced when the awaited block is the same { - let witness1 = make_witness(100, 1, block(99), 0); - let witness2 = make_witness(100, 1, block(99), 1); + let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0); + let witness2 = make_witness(100, new_shard_id_tmp(1), block(99), 1); pool.add_orphan_state_witness(witness1, 0); pool.add_orphan_state_witness(witness2.clone(), 0); @@ -291,8 +291,8 @@ mod tests { // The old witness is replaced when the awaited block is different, waiting_for_block is cleaned as expected { - let witness3 = make_witness(102, 1, block(100), 0); - let witness4 = make_witness(102, 1, block(101), 0); + let witness3 = make_witness(102, new_shard_id_tmp(1), block(100), 0); + let witness4 = make_witness(102, new_shard_id_tmp(1), block(101), 0); pool.add_orphan_state_witness(witness3, 0); pool.add_orphan_state_witness(witness4.clone(), 0); @@ -311,9 +311,9 @@ mod tests { fn limited_capacity() { let mut pool = OrphanStateWitnessPool::new(2); - let witness1 = make_witness(102, 1, block(101), 0); - let witness2 = make_witness(101, 1, block(100), 0); - let witness3 = make_witness(101, 2, block(100), 0); + let witness1 = make_witness(102, new_shard_id_tmp(1), block(101), 0); + let witness2 = make_witness(101, new_shard_id_tmp(1), block(100), 0); + let witness3 = make_witness(101, new_shard_id_tmp(2), block(100), 0); pool.add_orphan_state_witness(witness1, 0); pool.add_orphan_state_witness(witness2.clone(), 0); @@ -337,7 +337,7 @@ mod tests { let mut pool = OrphanStateWitnessPool::new(10); let large_shard_id = ShardId::MAX; - let witness = make_witness(101, large_shard_id, block(99), 0); + let witness = make_witness(101, large_shard_id.into(), block(99), 0); pool.add_orphan_state_witness(witness.clone(), 0); let waiting_for_99 = pool.take_state_witnesses_waiting_for_block(&block(99)); @@ -351,10 +351,10 @@ mod tests { fn remove_below_height() { let mut pool = OrphanStateWitnessPool::new(10); - let witness1 = make_witness(100, 1, block(99), 0); - let witness2 = make_witness(101, 1, block(100), 0); - let witness3 = make_witness(102, 1, block(101), 0); - let witness4 = make_witness(103, 1, block(102), 0); + let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0); + let witness2 = make_witness(101, new_shard_id_tmp(1), block(100), 0); + let witness3 = make_witness(102, new_shard_id_tmp(1), block(101), 0); + let witness4 = make_witness(103, new_shard_id_tmp(1), block(102), 0); pool.add_orphan_state_witness(witness1, 0); pool.add_orphan_state_witness(witness2.clone(), 0); @@ -382,10 +382,10 @@ mod tests { #[test] fn destructor_doesnt_crash() { let mut pool = OrphanStateWitnessPool::new(10); - pool.add_orphan_state_witness(make_witness(100, 0, block(99), 0), 0); - pool.add_orphan_state_witness(make_witness(100, 2, block(99), 0), 0); - pool.add_orphan_state_witness(make_witness(100, 2, block(99), 0), 1); - pool.add_orphan_state_witness(make_witness(101, 0, block(100), 0), 0); + pool.add_orphan_state_witness(make_witness(100, new_shard_id_tmp(0), block(99), 0), 0); + pool.add_orphan_state_witness(make_witness(100, new_shard_id_tmp(2), block(99), 0), 0); + pool.add_orphan_state_witness(make_witness(100, new_shard_id_tmp(2), block(99), 0), 1); + pool.add_orphan_state_witness(make_witness(101, new_shard_id_tmp(0), block(100), 0), 0); std::mem::drop(pool); } @@ -395,24 +395,24 @@ mod tests { let mut pool = OrphanStateWitnessPool::new(5); // Witnesses for shards 0, 1, 2, 3 at height 1000, looking for block 99 - let witness0 = make_witness(100, 0, block(99), 0); - let witness1 = make_witness(100, 1, block(99), 0); - let witness2 = make_witness(100, 2, block(99), 0); - let witness3 = make_witness(100, 3, block(99), 0); + let witness0 = make_witness(100, new_shard_id_tmp(0), block(99), 0); + let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0); + let witness2 = make_witness(100, new_shard_id_tmp(2), block(99), 0); + let witness3 = make_witness(100, new_shard_id_tmp(3), block(99), 0); pool.add_orphan_state_witness(witness0, 0); pool.add_orphan_state_witness(witness1, 0); pool.add_orphan_state_witness(witness2, 0); pool.add_orphan_state_witness(witness3, 0); // Another witness on shard 1, height 100. Should replace witness1 - let witness5 = make_witness(100, 1, block(99), 1); + let witness5 = make_witness(100, new_shard_id_tmp(1), block(99), 1); pool.add_orphan_state_witness(witness5.clone(), 0); // Witnesses for shards 0, 1, 2, 3 at height 101, looking for block 100 - let witness6 = make_witness(101, 0, block(100), 0); - let witness7 = make_witness(101, 1, block(100), 0); - let witness8 = make_witness(101, 2, block(100), 0); - let witness9 = make_witness(101, 3, block(100), 0); + let witness6 = make_witness(101, new_shard_id_tmp(0), block(100), 0); + let witness7 = make_witness(101, new_shard_id_tmp(1), block(100), 0); + let witness8 = make_witness(101, new_shard_id_tmp(2), block(100), 0); + let witness9 = make_witness(101, new_shard_id_tmp(3), block(100), 0); pool.add_orphan_state_witness(witness6, 0); pool.add_orphan_state_witness(witness7.clone(), 0); pool.add_orphan_state_witness(witness8.clone(), 0); @@ -424,9 +424,9 @@ mod tests { assert_contents(looking_for_99, vec![witness5]); // Let's add a few more witnesses - let witness10 = make_witness(102, 1, block(101), 0); - let witness11 = make_witness(102, 4, block(100), 0); - let witness12 = make_witness(102, 1, block(77), 0); + let witness10 = make_witness(102, new_shard_id_tmp(1), block(101), 0); + let witness11 = make_witness(102, new_shard_id_tmp(4), block(100), 0); + let witness12 = make_witness(102, new_shard_id_tmp(1), block(77), 0); pool.add_orphan_state_witness(witness10, 0); pool.add_orphan_state_witness(witness11.clone(), 0); pool.add_orphan_state_witness(witness12.clone(), 0); diff --git a/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs b/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs index 9b608f1a033..4b3ff3a8e08 100644 --- a/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs +++ b/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs @@ -158,7 +158,7 @@ impl PartialEncodedStateWitnessTracker { tracing::error!( target: "client", ?err, - shard_id = key.shard_id, + shard_id = ?key.shard_id, height_created = key.height_created, "Failed to reed solomon decode witness parts. Maybe malicious or corrupt data." ); diff --git a/chain/client/src/stateless_validation/shadow_validate.rs b/chain/client/src/stateless_validation/shadow_validate.rs index d5c240f8a07..33df5655de8 100644 --- a/chain/client/src/stateless_validation/shadow_validate.rs +++ b/chain/client/src/stateless_validation/shadow_validate.rs @@ -17,11 +17,15 @@ impl Client { tracing::debug!(target: "client", ?block_hash, "shadow validation for block chunks"); let prev_block = self.chain.get_block(block.header().prev_hash())?; let prev_block_chunks = prev_block.chunks(); - for chunk in - block.chunks().iter().filter(|chunk| chunk.is_new_chunk(block.header().height())) + for (shard_index, chunk) in block + .chunks() + .iter() + .enumerate() + .filter(|(_, chunk)| chunk.is_new_chunk(block.header().height())) { let chunk = self.chain.get_chunk_clone_from_header(chunk)?; - let prev_chunk_header = prev_block_chunks.get(chunk.shard_id() as usize).unwrap(); + // TODO(resharding) This doesn't work if shard layout changes. + let prev_chunk_header = prev_block_chunks.get(shard_index).unwrap(); if let Err(err) = self.shadow_validate_chunk(prev_block.header(), prev_chunk_header, &chunk) { @@ -30,7 +34,7 @@ impl Client { tracing::error!( target: "client", ?err, - shard_id = chunk.shard_id(), + shard_id = ?chunk.shard_id(), ?block_hash, "shadow chunk validation failed" ); diff --git a/chain/client/src/stateless_validation/state_witness_producer.rs b/chain/client/src/stateless_validation/state_witness_producer.rs index 94ade12c855..ce714155d29 100644 --- a/chain/client/src/stateless_validation/state_witness_producer.rs +++ b/chain/client/src/stateless_validation/state_witness_producer.rs @@ -256,15 +256,14 @@ impl Client { let mut source_receipt_proofs = HashMap::new(); for receipt_proof_response in incoming_receipt_proofs { let from_block = self.chain.chain_store().get_block(&receipt_proof_response.0)?; + let shard_layout = + self.epoch_manager.get_shard_layout(from_block.header().epoch_id())?; for proof in receipt_proof_response.1.iter() { - let from_shard_id: usize = proof - .1 - .from_shard_id - .try_into() - .map_err(|_| Error::Other("Couldn't convert u64 to usize!".into()))?; + let from_shard_id = proof.1.from_shard_id; + let from_shard_index = shard_layout.get_shard_index(from_shard_id); let from_chunk_hash = from_block .chunks() - .get(from_shard_id) + .get(from_shard_index) .ok_or_else(|| Error::InvalidShardId(proof.1.from_shard_id))? .chunk_hash(); let insert_res = diff --git a/chain/client/src/stateless_validation/state_witness_tracker.rs b/chain/client/src/stateless_validation/state_witness_tracker.rs index 2297e98a6d6..c1d8face71f 100644 --- a/chain/client/src/stateless_validation/state_witness_tracker.rs +++ b/chain/client/src/stateless_validation/state_witness_tracker.rs @@ -153,7 +153,7 @@ mod state_witness_tracker_tests { use near_async::time::{Duration, FakeClock, Utc}; use near_primitives::hash::hash; use near_primitives::stateless_validation::state_witness::ChunkStateWitness; - use near_primitives::types::ShardId; + use near_primitives::types::new_shard_id_tmp; const NUM_VALIDATORS: usize = 3; @@ -205,7 +205,7 @@ mod state_witness_tracker_tests { } fn dummy_witness() -> ChunkStateWitness { - ChunkStateWitness::new_dummy(100, 2 as ShardId, hash("fake hash".as_bytes())) + ChunkStateWitness::new_dummy(100, new_shard_id_tmp(2), hash("fake hash".as_bytes())) } fn dummy_clock() -> FakeClock { diff --git a/chain/client/src/sync/external.rs b/chain/client/src/sync/external.rs index af8c08ec510..451a95e43a2 100644 --- a/chain/client/src/sync/external.rs +++ b/chain/client/src/sync/external.rs @@ -143,7 +143,7 @@ impl ExternalConnection { match self { ExternalConnection::S3 { bucket } => { bucket.put_object(&location, data).await?; - tracing::debug!(target: "state_sync_dump", shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to S3"); + tracing::debug!(target: "state_sync_dump", ?shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to S3"); Ok(()) } ExternalConnection::Filesystem { root_dir } => { @@ -157,7 +157,7 @@ impl ExternalConnection { .truncate(true) .open(&path)?; file.write_all(data)?; - tracing::debug!(target: "state_sync_dump", shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to a file"); + tracing::debug!(target: "state_sync_dump", ?shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to a file"); Ok(()) } ExternalConnection::GCS { gcs_client, bucket, .. } => { @@ -165,7 +165,7 @@ impl ExternalConnection { .object() .create(bucket, data.to_vec(), location, "application/octet-stream") .await?; - tracing::debug!(target: "state_sync_dump", shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to GCS"); + tracing::debug!(target: "state_sync_dump", ?shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to GCS"); Ok(()) } } @@ -194,7 +194,7 @@ impl ExternalConnection { ExternalConnection::S3 { bucket } => { let prefix = format!("{}/", directory_path); let list_results = bucket.list(prefix.clone(), Some("/".to_string())).await?; - tracing::debug!(target: "state_sync_dump", shard_id, ?directory_path, "List state parts in s3"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?directory_path, "List state parts in s3"); let mut file_names = vec![]; for res in list_results { for obj in res.contents { @@ -205,7 +205,7 @@ impl ExternalConnection { } ExternalConnection::Filesystem { root_dir } => { let path = root_dir.join(directory_path); - tracing::debug!(target: "state_sync_dump", shard_id, ?path, "List state parts in local directory"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?path, "List state parts in local directory"); std::fs::create_dir_all(&path)?; let mut file_names = vec![]; let files = std::fs::read_dir(&path)?; @@ -217,7 +217,7 @@ impl ExternalConnection { } ExternalConnection::GCS { gcs_client, bucket, .. } => { let prefix = format!("{}/", directory_path); - tracing::debug!(target: "state_sync_dump", shard_id, ?directory_path, "List state parts in GCS"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?directory_path, "List state parts in GCS"); Ok(gcs_client .object() .list( @@ -277,7 +277,7 @@ pub fn external_storage_location( chain_id: &str, epoch_id: &EpochId, epoch_height: u64, - shard_id: u64, + shard_id: ShardId, file_type: &StateFileType, ) -> String { format!( @@ -291,7 +291,7 @@ pub fn external_storage_location_directory( chain_id: &str, epoch_id: &EpochId, epoch_height: u64, - shard_id: u64, + shard_id: ShardId, obj_type: &StateFileType, ) -> String { location_prefix(chain_id, epoch_height, epoch_id, shard_id, obj_type) @@ -301,7 +301,7 @@ pub fn location_prefix( chain_id: &str, epoch_height: u64, epoch_id: &EpochId, - shard_id: u64, + shard_id: ShardId, obj_type: &StateFileType, ) -> String { match obj_type { @@ -410,6 +410,7 @@ mod test { ExternalConnection, StateFileType, }; use near_o11y::testonly::init_test_logger; + use near_primitives::types::new_shard_id_tmp; use rand::distributions::{Alphanumeric, DistString}; fn random_string(rand_len: usize) -> String { @@ -460,31 +461,38 @@ mod test { let file_type = StateFileType::StatePart { part_id: 0, num_parts: 1 }; // Before uploading we shouldn't see filename in the list of files. - let files = rt.block_on(async { connection.list_objects(0, &dir).await.unwrap() }); + let files = rt + .block_on(async { connection.list_objects(new_shard_id_tmp(0), &dir).await.unwrap() }); tracing::debug!("Files before upload: {:?}", files); assert_eq!(files.into_iter().filter(|x| *x == filename).collect::>().len(), 0); // Uploading the file. rt.block_on(async { - connection.put_file(file_type.clone(), &data, 0, &full_filename).await.unwrap() + connection + .put_file(file_type.clone(), &data, new_shard_id_tmp(0), &full_filename) + .await + .unwrap() }); // After uploading we should see filename in the list of files. - let files = rt.block_on(async { connection.list_objects(0, &dir).await.unwrap() }); + let files = rt + .block_on(async { connection.list_objects(new_shard_id_tmp(0), &dir).await.unwrap() }); tracing::debug!("Files after upload: {:?}", files); assert_eq!(files.into_iter().filter(|x| *x == filename).collect::>().len(), 1); // And the data should match generates data. - let download_data = rt - .block_on(async { connection.get_file(0, &full_filename, &file_type).await.unwrap() }); + let download_data = rt.block_on(async { + connection.get_file(new_shard_id_tmp(0), &full_filename, &file_type).await.unwrap() + }); assert_eq!(download_data, data); // Also try to download some data at nonexistent location and expect to fail. let filename = random_string(8); let full_filename = format!("{}/{}", dir, filename); - let download_data = - rt.block_on(async { connection.get_file(0, &full_filename, &file_type).await }); + let download_data = rt.block_on(async { + connection.get_file(new_shard_id_tmp(0), &full_filename, &file_type).await + }); assert!(download_data.is_err(), "{:?}", download_data); } } diff --git a/chain/client/src/sync/state.rs b/chain/client/src/sync/state.rs index d4dccac3e7e..48d74ff5236 100644 --- a/chain/client/src/sync/state.rs +++ b/chain/client/src/sync/state.rs @@ -49,7 +49,9 @@ use near_primitives::state_part::PartId; use near_primitives::state_sync::{ ShardStateSyncResponse, ShardStateSyncResponseHeader, StatePartKey, }; -use near_primitives::types::{AccountId, EpochHeight, EpochId, ShardId, StateRoot}; +use near_primitives::types::{ + shard_id_as_u32, AccountId, EpochHeight, EpochId, ShardId, StateRoot, +}; use near_store::DBCol; use rand::seq::SliceRandom; use rand::thread_rng; @@ -204,7 +206,7 @@ impl StateSync { &mut self, me: &Option, sync_hash: CryptoHash, - sync_status: &mut HashMap, + sync_status: &mut HashMap, chain: &mut Chain, epoch_manager: &dyn EpochManagerAdapter, highest_height_peers: &[HighestHeightPeerInfo], @@ -232,7 +234,7 @@ impl StateSync { for shard_id in tracking_shards { let version = prev_shard_layout.version(); - let shard_uid = ShardUId { version, shard_id: shard_id as u32 }; + let shard_uid = ShardUId { version, shard_id: shard_id_as_u32(shard_id) }; let mut download_timeout = false; let mut run_shard_state_download = false; let shard_sync_download = sync_status.entry(shard_id).or_insert_with(|| { @@ -344,7 +346,7 @@ impl StateSync { &mut self, chain: &mut Chain, sync_hash: CryptoHash, - shard_sync: &mut HashMap, + shard_sync: &mut HashMap, ) { for StateSyncGetFileResult { sync_hash: msg_sync_hash, shard_id, part_id, result } in self.state_parts_mpsc_rx.try_iter() @@ -539,7 +541,7 @@ impl StateSync { // Currently it is assumed that one of the direct peers of the node is able to generate // the shard header. let peer_id = possible_targets.choose(&mut thread_rng()).cloned().unwrap(); - tracing::debug!(target: "sync", ?peer_id, shard_id, ?sync_hash, ?possible_targets, "request_shard_header"); + tracing::debug!(target: "sync", ?peer_id, ?shard_id, ?sync_hash, ?possible_targets, "request_shard_header"); assert!(header_download.run_me.load(Ordering::SeqCst)); header_download.run_me.store(false, Ordering::SeqCst); header_download.state_requests_count += 1; @@ -668,7 +670,7 @@ impl StateSync { &mut self, me: &Option, sync_hash: CryptoHash, - sync_status: &mut HashMap, + sync_status: &mut HashMap, chain: &mut Chain, epoch_manager: &dyn EpochManagerAdapter, highest_height_peers: &[HighestHeightPeerInfo], @@ -722,7 +724,7 @@ impl StateSync { &mut self, shard_sync_download: &mut ShardSyncDownload, hash: CryptoHash, - shard_id: u64, + shard_id: ShardId, state_response: ShardStateSyncResponse, chain: &mut Chain, ) { @@ -1314,6 +1316,7 @@ mod test { use near_primitives::state_sync::{ CachedParts, ShardStateSyncResponseHeader, ShardStateSyncResponseV3, }; + use near_primitives::types::new_shard_id_tmp; use near_primitives::{test_utils::TestBlockBuilder, types::EpochId}; #[test] @@ -1356,7 +1359,8 @@ mod test { } let request_hash = &chain.head().unwrap().last_block_hash; - let state_sync_header = chain.get_state_response_header(0, *request_hash).unwrap(); + let state_sync_header = + chain.get_state_response_header(new_shard_id_tmp(0), *request_hash).unwrap(); let state_sync_header = match state_sync_header { ShardStateSyncResponseHeader::V1(_) => panic!("Invalid header"), ShardStateSyncResponseHeader::V2(internal) => internal, @@ -1370,7 +1374,7 @@ mod test { genesis_id: Default::default(), highest_block_height: chain.epoch_length + 10, highest_block_hash: Default::default(), - tracked_shards: vec![0], + tracked_shards: vec![new_shard_id_tmp(0)], archival: false, }; @@ -1383,7 +1387,7 @@ mod test { &mut chain, kv.as_ref(), &[highest_height_peer_info], - vec![0], + vec![new_shard_id_tmp(0)], &noop().into_sender(), &noop().into_sender(), &ActixArbiterHandleFutureSpawner(Arbiter::new().handle()), @@ -1398,7 +1402,7 @@ mod test { assert_eq!( NetworkRequests::StateRequestHeader { - shard_id: 0, + shard_id: new_shard_id_tmp(0), sync_hash: *request_hash, peer_id: peer_id.clone(), }, @@ -1406,7 +1410,7 @@ mod test { ); assert_eq!(1, new_shard_sync.len()); - let download = new_shard_sync.get(&0).unwrap(); + let download = new_shard_sync.get(&new_shard_id_tmp(0)).unwrap(); assert_eq!(download.status, ShardSyncStatus::StateDownloadHeader); @@ -1430,14 +1434,14 @@ mod test { }); state_sync.update_download_on_state_response_message( - &mut new_shard_sync.get_mut(&0).unwrap(), + &mut new_shard_sync.get_mut(&new_shard_id_tmp(0)).unwrap(), *request_hash, - 0, + new_shard_id_tmp(0), state_response, &mut chain, ); - let download = new_shard_sync.get(&0).unwrap(); + let download = new_shard_sync.get(&new_shard_id_tmp(0)).unwrap(); assert_eq!(download.status, ShardSyncStatus::StateDownloadHeader); // Download should be marked as done. assert_eq!(download.downloads[0].done, true); diff --git a/chain/client/src/sync_jobs_actor.rs b/chain/client/src/sync_jobs_actor.rs index 176151823ad..5d7e854fb3a 100644 --- a/chain/client/src/sync_jobs_actor.rs +++ b/chain/client/src/sync_jobs_actor.rs @@ -9,7 +9,6 @@ use near_chain::chain::{ use near_performance_metrics_macros::perf; use near_primitives::state_part::PartId; use near_primitives::state_sync::StatePartKey; -use near_primitives::types::ShardId; use near_store::adapter::StoreUpdateAdapter; use near_store::DBCol; @@ -73,7 +72,7 @@ impl SyncJobsActor { tracing::debug_span!(target: "sync", "apply_parts").entered(); let store = msg.runtime_adapter.store(); - let shard_id = msg.shard_uid.shard_id as ShardId; + let shard_id = msg.shard_uid.shard_id(); for part_id in 0..msg.num_parts { let key = borsh::to_vec(&StatePartKey(msg.sync_hash, shard_id, part_id))?; let part = store.get(DBCol::StateParts, &key)?.unwrap(); @@ -124,7 +123,7 @@ impl SyncJobsActor { // Unload mem-trie (in case it is still loaded) before we apply state parts. msg.runtime_adapter.get_tries().unload_mem_trie(&msg.shard_uid); - let shard_id = msg.shard_uid.shard_id as ShardId; + let shard_id = msg.shard_uid.shard_id(); match self.clear_flat_state(&msg) { Err(err) => { self.client_sender.send(ApplyStatePartsResponse { diff --git a/chain/client/src/test_utils/client.rs b/chain/client/src/test_utils/client.rs index f8e8063738b..639e946e116 100644 --- a/chain/client/src/test_utils/client.rs +++ b/chain/client/src/test_utils/client.rs @@ -22,7 +22,7 @@ use near_primitives::merkle::{merklize, PartialMerkleTree}; use near_primitives::sharding::{EncodedShardChunk, ShardChunk}; use near_primitives::stateless_validation::chunk_endorsement::ChunkEndorsementV1; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{BlockHeight, ShardId}; +use near_primitives::types::{new_shard_id_tmp, BlockHeight, ShardId}; use near_primitives::utils::MaybeValidated; use near_primitives::version::PROTOCOL_VERSION; use num_rational::Ratio; @@ -159,7 +159,7 @@ fn create_chunk_on_height_for_shard( } pub fn create_chunk_on_height(client: &mut Client, next_height: BlockHeight) -> ProduceChunkResult { - create_chunk_on_height_for_shard(client, next_height, 0) + create_chunk_on_height_for_shard(client, next_height, new_shard_id_tmp(0)) } pub fn create_chunk_with_transactions( @@ -190,7 +190,7 @@ pub fn create_chunk( last_block.header().epoch_id(), last_block.chunks()[0].clone(), next_height, - 0, + new_shard_id_tmp(0), signer.as_ref(), ) .unwrap() diff --git a/chain/client/src/test_utils/setup.rs b/chain/client/src/test_utils/setup.rs index c6c669dc8c5..c8490c9070d 100644 --- a/chain/client/src/test_utils/setup.rs +++ b/chain/client/src/test_utils/setup.rs @@ -56,7 +56,9 @@ use near_primitives::epoch_info::RngSeed; use near_primitives::hash::{hash, CryptoHash}; use near_primitives::network::PeerId; use near_primitives::test_utils::create_test_signer; -use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, NumBlocks, NumSeats}; +use near_primitives::types::{ + new_shard_id_tmp, AccountId, BlockHeightDelta, EpochId, NumBlocks, NumSeats, +}; use near_primitives::validator_signer::{EmptyValidatorSigner, ValidatorSigner}; use near_primitives::version::PROTOCOL_VERSION; use near_store::adapter::StoreAdapter; @@ -448,7 +450,10 @@ fn process_peer_manager_message_default( height: last_height[i], hash: CryptoHash::default(), }), - tracked_shards: vec![0, 1, 2, 3], + tracked_shards: vec![0, 1, 2, 3] + .into_iter() + .map(new_shard_id_tmp) + .collect(), archival: true, }, }, diff --git a/chain/client/src/test_utils/test_env.rs b/chain/client/src/test_utils/test_env.rs index 4d612566f48..770d8f65c15 100644 --- a/chain/client/src/test_utils/test_env.rs +++ b/chain/client/src/test_utils/test_env.rs @@ -524,8 +524,10 @@ impl TestEnv { let last_block = client.chain.get_block(&head.last_block_hash).unwrap(); let shard_id = client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap(); + let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap(); - let last_chunk_header = &last_block.chunks()[shard_id as usize]; + let last_chunk_header = &last_block.chunks()[shard_index]; for i in 0..self.clients.len() { let tracks_shard = self.clients[i] @@ -582,7 +584,9 @@ impl TestEnv { let shard_id = client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap(); let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap(); - let last_chunk_header = &last_block.chunks()[shard_id as usize]; + let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); + let last_chunk_header = &last_block.chunks()[shard_index]; let response = client .runtime_adapter .query( diff --git a/chain/client/src/test_utils/test_loop.rs b/chain/client/src/test_utils/test_loop.rs index 0dd8b8c7767..cb0c461a4da 100644 --- a/chain/client/src/test_utils/test_loop.rs +++ b/chain/client/src/test_utils/test_loop.rs @@ -58,7 +58,9 @@ where let shard_id = client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap(); let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap(); - let last_chunk_header = &last_block.chunks()[shard_id as usize]; + let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); + let last_chunk_header = &last_block.chunks()[shard_index]; client .runtime_adapter diff --git a/chain/client/src/tests/bug_repros.rs b/chain/client/src/tests/bug_repros.rs index b239d69c965..5f14ff54e87 100644 --- a/chain/client/src/tests/bug_repros.rs +++ b/chain/client/src/tests/bug_repros.rs @@ -109,28 +109,30 @@ fn repro_1183() { for from in ["test1", "test2", "test3", "test4"].iter() { for to in ["test1", "test2", "test3", "test4"].iter() { let (from, to) = (from.parse().unwrap(), to.parse().unwrap()); - connectors1.write().unwrap()[account_id_to_shard_id(&from, 4) as usize] - .client_actor - .do_send( - ProcessTxRequest { - transaction: SignedTransaction::send_money( - block.header().height() * 16 + nonce_delta, + // This test uses the V0 shard layout so it's ok to + // cast ShardId to ShardIndex. + let shard_id = account_id_to_shard_id(&from, 4); + let shard_index = shard_id as usize; + connectors1.write().unwrap()[shard_index].client_actor.do_send( + ProcessTxRequest { + transaction: SignedTransaction::send_money( + block.header().height() * 16 + nonce_delta, + from.clone(), + to, + &InMemorySigner::from_seed( from.clone(), - to, - &InMemorySigner::from_seed( - from.clone(), - KeyType::ED25519, - from.as_ref(), - ) - .into(), - 1, - *block.header().prev_hash(), - ), - is_forwarded: false, - check_only: false, - } - .with_span_context(), - ); + KeyType::ED25519, + from.as_ref(), + ) + .into(), + 1, + *block.header().prev_hash(), + ), + is_forwarded: false, + check_only: false, + } + .with_span_context(), + ); nonce_delta += 1 } } diff --git a/chain/client/src/tests/catching_up.rs b/chain/client/src/tests/catching_up.rs index f49db6989ec..16f78a624d8 100644 --- a/chain/client/src/tests/catching_up.rs +++ b/chain/client/src/tests/catching_up.rs @@ -23,7 +23,7 @@ use near_primitives::network::PeerId; use near_primitives::receipt::Receipt; use near_primitives::sharding::ChunkHash; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{AccountId, BlockHeight, BlockHeightDelta, BlockReference}; +use near_primitives::types::{AccountId, BlockHeight, BlockHeightDelta, BlockReference, ShardId}; use near_primitives::views::QueryRequest; use near_primitives::views::QueryResponseKind::ViewAccount; @@ -99,7 +99,7 @@ enum ReceiptsSyncPhases { #[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)] pub struct StateRequestStruct { - pub shard_id: u64, + pub shard_id: ShardId, pub sync_hash: CryptoHash, pub sync_prev_prev_hash: Option, pub part_id: Option, @@ -714,7 +714,8 @@ fn test_all_chunks_accepted_common( let verbose = false; - let seen_chunk_same_sender = Arc::new(RwLock::new(HashSet::<(AccountId, u64, u64)>::new())); + let seen_chunk_same_sender = + Arc::new(RwLock::new(HashSet::<(AccountId, u64, ShardId)>::new())); let requested = Arc::new(RwLock::new(HashSet::<(AccountId, Vec, ChunkHash)>::new())); let responded = Arc::new(RwLock::new(HashSet::<(CryptoHash, Vec, ChunkHash)>::new())); diff --git a/chain/client/src/tests/cross_shard_tx.rs b/chain/client/src/tests/cross_shard_tx.rs index 7467758d81f..8a26ad41f54 100644 --- a/chain/client/src/tests/cross_shard_tx.rs +++ b/chain/client/src/tests/cross_shard_tx.rs @@ -189,8 +189,11 @@ fn test_cross_shard_tx_callback( let balances1 = balances; let observed_balances1 = observed_balances; let presumable_epoch1 = presumable_epoch.clone(); - let actor = &connectors_[account_id_to_shard_id(&account_id, 8) as usize - + (*presumable_epoch.read().unwrap() * 8) % 24] + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&account_id, 8); + let shard_index = shard_id as usize; + let actor = &connectors_[shard_index + (*presumable_epoch.read().unwrap() * 8) % 24] .view_client_actor; let actor = actor.send( Query::new( @@ -254,10 +257,15 @@ fn test_cross_shard_tx_callback( let amount = (5 + iteration_local) as u128; let next_nonce = nonce.fetch_add(1, Ordering::Relaxed); + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&validators[from], 8); + let shard_index = shard_id as usize; + send_tx( validators.len(), connectors.clone(), - account_id_to_shard_id(&validators[from], 8) as usize, + shard_index, validators[from].clone(), validators[to].clone(), amount, @@ -287,8 +295,14 @@ fn test_cross_shard_tx_callback( let presumable_epoch1 = presumable_epoch.clone(); let account_id1 = validators[i].clone(); let block_stats1 = block_stats.clone(); - let actor = &connectors_[account_id_to_shard_id(&validators[i], 8) as usize - + (*presumable_epoch.read().unwrap() * 8) % 24] + + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&validators[i], 8); + let shard_index = shard_id as usize; + + let actor = &connectors_ + [shard_index + (*presumable_epoch.read().unwrap() * 8) % 24] .view_client_actor; let actor = actor.send( Query::new( @@ -341,8 +355,13 @@ fn test_cross_shard_tx_callback( let connectors_ = connectors.write().unwrap(); let connectors1 = connectors.clone(); let presumable_epoch1 = presumable_epoch.clone(); - let actor = &connectors_[account_id_to_shard_id(&account_id, 8) as usize - + (*presumable_epoch.read().unwrap() * 8) % 24] + + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&account_id, 8); + let shard_index = shard_id as usize; + + let actor = &connectors_[shard_index + (*presumable_epoch.read().unwrap() * 8) % 24] .view_client_actor; let actor = actor.send( Query::new( @@ -498,9 +517,14 @@ fn test_cross_shard_tx_common( let presumable_epoch1 = presumable_epoch.clone(); let account_id1 = validators[i].clone(); let block_stats1 = block_stats.clone(); - let actor = &connectors_[account_id_to_shard_id(&validators[i], 8) as usize - + *presumable_epoch.read().unwrap() * 8] - .view_client_actor; + + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&validators[i], 8); + let shard_index = shard_id as usize; + + let actor = + &connectors_[shard_index + *presumable_epoch.read().unwrap() * 8].view_client_actor; let actor = actor.send( Query::new( BlockReference::latest(), diff --git a/chain/client/src/tests/process_blocks.rs b/chain/client/src/tests/process_blocks.rs index 2419c0fa7d2..481e244206f 100644 --- a/chain/client/src/tests/process_blocks.rs +++ b/chain/client/src/tests/process_blocks.rs @@ -11,6 +11,7 @@ use near_primitives::network::PeerId; use near_primitives::sharding::ShardChunkHeader; use near_primitives::sharding::ShardChunkHeaderV3; use near_primitives::test_utils::create_test_signer; +use near_primitives::types::new_shard_id_tmp; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::utils::MaybeValidated; use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION}; @@ -78,7 +79,7 @@ fn test_bad_shard_id() { chunk.encoded_merkle_root(), chunk.encoded_length(), 2, - 1, + new_shard_id_tmp(1), chunk.prev_gas_used(), chunk.gas_limit(), chunk.prev_balance_burnt(), @@ -102,7 +103,11 @@ fn test_bad_shard_id() { let err = env.clients[0] .process_block_test(MaybeValidated::from(block), Provenance::NONE) .unwrap_err(); - assert_matches!(err, near_chain::Error::InvalidShardId(1)); + if let near_chain::Error::InvalidShardId(shard_id) = err { + assert!(shard_id == new_shard_id_tmp(1)); + } else { + panic!("Expected InvalidShardId error, got {:?}", err); + } } /// Test that if a block's content (vrf_value) is corrupted, the invalid block will not affect the node's block processing diff --git a/chain/client/src/tests/query_client.rs b/chain/client/src/tests/query_client.rs index ac4a5bbbfb2..bc2214e202b 100644 --- a/chain/client/src/tests/query_client.rs +++ b/chain/client/src/tests/query_client.rs @@ -20,7 +20,7 @@ use near_primitives::block::{Block, BlockHeader}; use near_primitives::merkle::PartialMerkleTree; use near_primitives::test_utils::create_test_signer; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{BlockId, BlockReference, EpochId}; +use near_primitives::types::{new_shard_id_tmp, BlockId, BlockReference, EpochId}; use near_primitives::version::PROTOCOL_VERSION; use near_primitives::views::{QueryRequest, QueryResponseKind}; use num_rational::Ratio; @@ -210,7 +210,7 @@ fn test_execution_outcome_for_chunk() { .unwrap() .unwrap(); assert_eq!(execution_outcomes_in_block.len(), 1); - let outcomes = execution_outcomes_in_block.remove(&0).unwrap(); + let outcomes = execution_outcomes_in_block.remove(&new_shard_id_tmp(0)).unwrap(); assert_eq!(outcomes[0].id, tx_hash); System::current().stop(); }); @@ -249,7 +249,7 @@ fn test_state_request() { for _ in 0..30 { let res = view_client .send( - StateRequestHeader { shard_id: 0, sync_hash: block_hash } + StateRequestHeader { shard_id: new_shard_id_tmp(0), sync_hash: block_hash } .with_span_context(), ) .await @@ -258,14 +258,15 @@ fn test_state_request() { } // immediately query again, should be rejected + let shard_id = new_shard_id_tmp(0); let res = view_client - .send(StateRequestHeader { shard_id: 0, sync_hash: block_hash }.with_span_context()) + .send(StateRequestHeader { shard_id, sync_hash: block_hash }.with_span_context()) .await .unwrap(); assert!(res.is_none()); actix::clock::sleep(std::time::Duration::from_secs(40)).await; let res = view_client - .send(StateRequestHeader { shard_id: 0, sync_hash: block_hash }.with_span_context()) + .send(StateRequestHeader { shard_id, sync_hash: block_hash }.with_span_context()) .await .unwrap(); assert!(res.is_some()); diff --git a/chain/client/src/view_client_actor.rs b/chain/client/src/view_client_actor.rs index b5f12b26b9a..33f66d08d72 100644 --- a/chain/client/src/view_client_actor.rs +++ b/chain/client/src/view_client_actor.rs @@ -299,6 +299,7 @@ impl ViewClientActorInner { let head = self.chain.head()?; let epoch_id = self.epoch_manager.get_epoch_id(&head.last_block_hash)?; let epoch_info: Arc = self.epoch_manager.get_epoch_info(&epoch_id)?; + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; let shard_ids = self.epoch_manager.shard_ids(&epoch_id)?; let cur_block_info = self.epoch_manager.get_block_info(&head.last_block_hash)?; let next_epoch_start_height = @@ -309,13 +310,17 @@ impl ViewClientActorInner { let mut start_block_of_window: Option = None; let last_block_of_epoch = next_epoch_start_height - 1; + // This loop does not go beyond the current epoch so it is valid to use + // the EpochInfo and ShardLayout from the current epoch. for block_height in head.height..next_epoch_start_height { let bp = epoch_info.sample_block_producer(block_height); let bp = epoch_info.get_validator(bp).account_id().clone(); let cps: Vec = shard_ids .iter() .map(|&shard_id| { - let cp = epoch_info.sample_chunk_producer(block_height, shard_id).unwrap(); + let cp = epoch_info + .sample_chunk_producer(&shard_layout, shard_id, block_height) + .unwrap(); let cp = epoch_info.get_validator(cp).account_id().clone(); cp }) @@ -751,9 +756,12 @@ fn get_chunk_from_block( shard_id: ShardId, chain: &Chain, ) -> Result { + let epoch_id = block.header().epoch_id(); + let shard_layout = chain.epoch_manager.get_shard_layout(epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); let chunk_header = block .chunks() - .get(shard_id as usize) + .get(shard_index) .ok_or_else(|| near_chain::Error::InvalidShardId(shard_id))? .clone(); let chunk_hash = chunk_header.chunk_hash(); @@ -1076,10 +1084,13 @@ impl Handler for ViewClientActorInner { let mut outcome_proof = outcome; let epoch_id = *self.chain.get_block(&outcome_proof.block_hash)?.header().epoch_id(); + let shard_layout = + self.epoch_manager.get_shard_layout(&epoch_id).into_chain_error()?; let target_shard_id = self .epoch_manager .account_id_to_shard_id(&account_id, &epoch_id) .into_chain_error()?; + let target_shard_index = shard_layout.get_shard_index(target_shard_id); let res = self.chain.get_next_block_hash_with_new_chunk( &outcome_proof.block_hash, target_shard_id, @@ -1095,7 +1106,7 @@ impl Handler for ViewClientActorInner { .iter() .map(|header| header.prev_outcome_root()) .collect::>(); - if target_shard_id >= (outcome_roots.len() as u64) { + if target_shard_index >= outcome_roots.len() { return Err(GetExecutionOutcomeError::InconsistentState { number_or_shards: outcome_roots.len(), execution_outcome_shard_id: target_shard_id, @@ -1103,8 +1114,7 @@ impl Handler for ViewClientActorInner { } Ok(GetExecutionOutcomeResponse { outcome_proof: outcome_proof.into(), - outcome_root_proof: merklize(&outcome_roots).1[target_shard_id as usize] - .clone(), + outcome_root_proof: merklize(&outcome_roots).1[target_shard_index].clone(), }) } else { Err(GetExecutionOutcomeError::NotConfirmed { transaction_or_receipt_id: id }) @@ -1361,7 +1371,7 @@ impl Handler for ViewClientActorInner { let header = match header { ShardStateSyncResponseHeader::V2(inner) => inner, _ => { - tracing::error!(target: "sync", ?sync_hash, shard_id, "Invalid state sync header format"); + tracing::error!(target: "sync", ?sync_hash, ?shard_id, "Invalid state sync header format"); return None; } }; @@ -1409,16 +1419,16 @@ impl Handler for ViewClientActorInner { let part = match self.chain.get_state_response_part(shard_id, part_id, sync_hash) { Ok(part) => Some((part_id, part)), Err(err) => { - error!(target: "sync", ?err, ?sync_hash, shard_id, part_id, "Cannot build state part"); + error!(target: "sync", ?err, ?sync_hash, ?shard_id, part_id, "Cannot build state part"); None } }; - tracing::trace!(target: "sync", ?sync_hash, shard_id, part_id, "Finished computation for state request part"); + tracing::trace!(target: "sync", ?sync_hash, ?shard_id, part_id, "Finished computation for state request part"); part } Ok(false) => { - warn!(target: "sync", ?sync_hash, shard_id, "sync_hash didn't pass validation, possible malicious behavior"); + warn!(target: "sync", ?sync_hash, ?shard_id, "sync_hash didn't pass validation, possible malicious behavior"); // Do not respond, possible malicious behavior. return None; } diff --git a/chain/epoch-manager/src/adapter.rs b/chain/epoch-manager/src/adapter.rs index 1dbe8de16a7..12cf5ac2860 100644 --- a/chain/epoch-manager/src/adapter.rs +++ b/chain/epoch-manager/src/adapter.rs @@ -18,7 +18,7 @@ use near_primitives::stateless_validation::validator_assignment::ChunkValidatorA use near_primitives::stateless_validation::ChunkProductionKey; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::types::{ - AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, ShardId, + AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, ShardId, ShardIndex, ValidatorInfoIdentifier, }; use near_primitives::version::ProtocolVersion; @@ -118,11 +118,13 @@ pub trait EpochManagerAdapter: Send + Sync { /// resharding happened and some shards were split. /// If there was no resharding, it just returns `shard_ids` as is, without any validation. /// The resulting Vec will always be of the same length as the `shard_ids` argument. + /// + /// TODO(wacban) - rename to reflect the new return type fn get_prev_shard_ids( &self, prev_hash: &CryptoHash, shard_ids: Vec, - ) -> Result, Error>; + ) -> Result, Error>; /// For a `ShardId` in the current block, returns its parent `ShardId` /// from previous block. @@ -130,11 +132,13 @@ pub trait EpochManagerAdapter: Send + Sync { /// Most of the times parent of the shard is the shard itself, unless a /// resharding happened and some shards were split. /// If there was no resharding, it just returns the `shard_id` as is, without any validation. + /// + /// TODO(wacban) - rename to reflect the new return type fn get_prev_shard_id( &self, prev_hash: &CryptoHash, shard_id: ShardId, - ) -> Result; + ) -> Result<(ShardId, ShardIndex), Error>; /// Get shard layout given hash of previous block. fn get_shard_layout_from_prev_block( @@ -596,9 +600,9 @@ impl EpochManagerAdapter for EpochManagerHandle { &self, prev_hash: &CryptoHash, shard_ids: Vec, - ) -> Result, Error> { + ) -> Result, Error> { + let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; if self.is_next_block_epoch_start(prev_hash)? { - let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; let prev_shard_layout = self.get_shard_layout(&self.get_epoch_id(prev_hash)?)?; if prev_shard_layout != shard_layout { return Ok(shard_ids @@ -611,22 +615,27 @@ impl EpochManagerAdapter for EpochManagerHandle { shard_layout, parent_shard_id ); - parent_shard_id + let parent_shard_index = prev_shard_layout.get_shard_index(parent_shard_id); + (parent_shard_id, parent_shard_index) }) }) .collect::>()?); } } - Ok(shard_ids) + + Ok(shard_ids + .iter() + .map(|&shard_id| (shard_id, shard_layout.get_shard_index(shard_id))) + .collect()) } fn get_prev_shard_id( &self, prev_hash: &CryptoHash, shard_id: ShardId, - ) -> Result { + ) -> Result<(ShardId, ShardIndex), Error> { + let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; if self.is_next_block_epoch_start(prev_hash)? { - let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; let prev_shard_layout = self.get_shard_layout(&self.get_epoch_id(prev_hash)?)?; if prev_shard_layout != shard_layout { let parent_shard_id = shard_layout.get_parent_shard_id(shard_id)?; @@ -636,10 +645,11 @@ impl EpochManagerAdapter for EpochManagerHandle { shard_layout, parent_shard_id ); - return Ok(parent_shard_id); + let parent_shard_index = prev_shard_layout.get_shard_index(parent_shard_id); + return Ok((parent_shard_id, parent_shard_index)); } } - Ok(shard_id) + Ok((shard_id, shard_layout.get_shard_index(shard_id))) } fn get_shard_layout_from_prev_block( diff --git a/chain/epoch-manager/src/lib.rs b/chain/epoch-manager/src/lib.rs index c0c5036247d..92cdbe3b1e4 100644 --- a/chain/epoch-manager/src/lib.rs +++ b/chain/epoch-manager/src/lib.rs @@ -1087,15 +1087,17 @@ impl EpochManager { } let epoch_info = self.get_epoch_info(epoch_id)?; + let shard_layout = self.get_shard_layout(epoch_id)?; let chunk_validators_per_shard = epoch_info.sample_chunk_validators(height); - for (shard_id, chunk_validators) in chunk_validators_per_shard.into_iter().enumerate() { + for (shard_index, chunk_validators) in chunk_validators_per_shard.into_iter().enumerate() { let chunk_validators = chunk_validators .into_iter() .map(|(validator_id, assignment_weight)| { (epoch_info.get_validator(validator_id).take_account_id(), assignment_weight) }) .collect(); - let cache_key = (*epoch_id, shard_id as ShardId, height); + let shard_id = shard_layout.get_shard_id(shard_index); + let cache_key = (*epoch_id, shard_id, height); self.chunk_validators_cache .put(cache_key, Arc::new(ChunkValidatorAssignments::new(chunk_validators))); } @@ -1180,7 +1182,9 @@ impl EpochManager { shard_id: ShardId, ) -> Result { let epoch_info = self.get_epoch_info(epoch_id)?; - let validator_id = Self::chunk_producer_from_info(&epoch_info, height, shard_id)?; + let shard_layout = self.get_shard_layout(epoch_id)?; + let validator_id = + Self::chunk_producer_from_info(&epoch_info, &shard_layout, shard_id, height)?; Ok(epoch_info.get_validator(validator_id)) } @@ -1239,9 +1243,13 @@ impl EpochManager { shard_id: ShardId, ) -> Result { let epoch_info = self.get_epoch_info(&epoch_id)?; + + let shard_layout = self.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers_settlement = epoch_info.chunk_producers_settlement(); let chunk_producers = chunk_producers_settlement - .get(shard_id as usize) + .get(shard_index) .ok_or_else(|| EpochError::ShardingError(format!("invalid shard id {shard_id}")))?; for validator_id in chunk_producers.iter() { if epoch_info.validator_account_id(*validator_id) == account_id { @@ -1458,16 +1466,18 @@ impl EpochManager { ValidatorInfoIdentifier::BlockHash(ref b) => self.get_epoch_id(b)?, }; let cur_epoch_info = self.get_epoch_info(&epoch_id)?; + let cur_shard_layout = self.get_shard_layout(&epoch_id)?; let epoch_height = cur_epoch_info.epoch_height(); let epoch_start_height = self.get_epoch_start_from_epoch_id(&epoch_id)?; let mut validator_to_shard = (0..cur_epoch_info.validators_len()) .map(|_| HashSet::default()) .collect::>>(); - for (shard_id, validators) in + for (shard_index, validators) in cur_epoch_info.chunk_producers_settlement().into_iter().enumerate() { + let shard_id = cur_shard_layout.get_shard_id(shard_index); for validator_id in validators { - validator_to_shard[*validator_id as usize].insert(shard_id as ShardId); + validator_to_shard[*validator_id as usize].insert(shard_id); } } @@ -1630,14 +1640,16 @@ impl EpochManager { }; let next_epoch_info = self.get_epoch_info(&next_epoch_id)?; + let next_shard_layout = self.get_shard_layout(&next_epoch_id)?; let mut next_validator_to_shard = (0..next_epoch_info.validators_len()) .map(|_| HashSet::default()) .collect::>>(); - for (shard_id, validators) in + for (shard_index, validators) in next_epoch_info.chunk_producers_settlement().iter().enumerate() { + let shard_id = next_shard_layout.get_shard_id(shard_index); for validator_id in validators { - next_validator_to_shard[*validator_id as usize].insert(shard_id as u64); + next_validator_to_shard[*validator_id as usize].insert(shard_id); } } let next_validators = next_epoch_info @@ -1742,10 +1754,11 @@ impl EpochManager { #[inline] pub(crate) fn chunk_producer_from_info( epoch_info: &EpochInfo, - height: BlockHeight, + shard_layout: &ShardLayout, shard_id: ShardId, + height: BlockHeight, ) -> Result { - epoch_info.sample_chunk_producer(height, shard_id).ok_or_else(|| { + epoch_info.sample_chunk_producer(shard_layout, shard_id, height).ok_or_else(|| { EpochError::ChunkProducerSelectionError(format!( "Invalid shard {shard_id} for height {height}" )) @@ -2034,6 +2047,7 @@ impl EpochManager { let epoch_id = *self.get_block_info(block_hash)?.epoch_id(); let epoch_info = self.get_epoch_info(&epoch_id)?; + let shard_layout = self.get_shard_layout(&epoch_id)?; let mut aggregator = EpochInfoAggregator::new(epoch_id, *block_hash); let mut cur_hash = *block_hash; @@ -2089,7 +2103,7 @@ impl EpochManager { }; let block_info = self.get_block_info(&cur_hash)?; - aggregator.update_tail(&block_info, &epoch_info, prev_height); + aggregator.update_tail(&block_info, &epoch_info, &shard_layout, prev_height); if prev_hash == self.epoch_info_aggregator.last_block_hash { // We’ve reached sync point of the old aggregator. If old diff --git a/chain/epoch-manager/src/shard_assignment.rs b/chain/epoch-manager/src/shard_assignment.rs index 39e88f64d21..53e421571a4 100644 --- a/chain/epoch-manager/src/shard_assignment.rs +++ b/chain/epoch-manager/src/shard_assignment.rs @@ -1,7 +1,8 @@ use crate::EpochInfo; use crate::RngSeed; use near_primitives::types::validator_stake::ValidatorStake; -use near_primitives::types::{Balance, NumShards, ShardId}; +use near_primitives::types::ShardIndex; +use near_primitives::types::{Balance, NumShards}; use near_primitives::utils::min_heap::{MinHeap, PeekMut}; use rand::Rng; use std::collections::{BTreeSet, HashMap, HashSet}; @@ -22,21 +23,48 @@ impl HasStake for ValidatorStake { } } +/// A helper struct to maintain the shard assignment sorted by the number of +/// validators assigned to each shard. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +struct ValidatorsFirstShardAssignmentItem { + validators: usize, + stake: Balance, + shard_index: ShardIndex, +} + +type ValidatorsFirstShardAssignment = MinHeap; + +/// A helper struct to maintain the shard assignment sorted by the stake +/// assigned to each shard. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +struct StakeFirstShardAssignmentItem { + stake: Balance, + validators: usize, + shard_index: ShardIndex, +} + +type StakeFirstShardAssignment = MinHeap; + +impl From for StakeFirstShardAssignmentItem { + fn from(v: ValidatorsFirstShardAssignmentItem) -> Self { + Self { validators: v.validators, stake: v.stake, shard_index: v.shard_index } + } +} + fn assign_to_satisfy_shards_inner>( - shard_index: &mut MinHeap<(usize, Balance, ShardId)>, + shard_assignment: &mut ValidatorsFirstShardAssignment, result: &mut Vec>, cp_iter: &mut I, min_validators_per_shard: usize, ) { - let mut buffer = Vec::with_capacity(shard_index.len()); - // Stores (shard_id, cp_index) meaning that cp at cp_index has already been - // added to shard shard_id. Used to make sure we don’t add a cp to the same + let mut buffer = Vec::with_capacity(shard_assignment.len()); + // Stores (shard_index, cp_index) meaning that cp at cp_index has already been + // added to shard shard_index. Used to make sure we don’t add a cp to the same // shard multiple times. - let mut seen = std::collections::HashSet::<(ShardId, usize)>::with_capacity( - result.len() * min_validators_per_shard, - ); + let seen_capacity = result.len() * min_validators_per_shard; + let mut seen = HashSet::<(ShardIndex, usize)>::with_capacity(seen_capacity); - while shard_index.peek().unwrap().0 < min_validators_per_shard { + while shard_assignment.peek().unwrap().validators < min_validators_per_shard { // cp_iter is an infinite cycle iterator so getting next value can never // fail. cp_index is index of each element in the iterator but the // indexing is done before cycling thus the same cp always gets the same @@ -45,26 +73,26 @@ fn assign_to_satisfy_shards_inner { // No shards left which don’t already contain this chunk // producer. Skip it and move to another producer. break; } - Some(top) if top.0 >= min_validators_per_shard => { - // `shard_index` is sorted by number of chunk producers, + Some(top) if top.validators >= min_validators_per_shard => { + // `shard_assignment` is sorted by number of chunk producers, // thus all remaining shards have min_validators_per_shard // producers already assigned to them. Don’t assign current // one to any shard and move to next cp. break; } - Some(mut top) if seen.insert((top.2, cp_index)) => { + Some(mut top) if seen.insert((top.shard_index, cp_index)) => { // Chunk producer is not yet assigned to the shard and the // shard still needs more producers. Assign `cp` to it and // move to next one. - top.0 += 1; - top.1 += cp.get_stake(); - result[usize::try_from(top.2).unwrap()].push(cp); + top.validators += 1; + top.stake += cp.get_stake(); + result[top.shard_index].push(cp); break; } Some(top) => { @@ -78,7 +106,7 @@ fn assign_to_satisfy_shards_inner( let mut result: Vec> = (0..num_shards).map(|_| Vec::new()).collect(); // Initially, sort by number of validators first so we fill shards up. - let mut shard_index: MinHeap<(usize, Balance, ShardId)> = - (0..num_shards).map(|s| (0, 0, s)).collect(); + let mut shard_assignment: ValidatorsFirstShardAssignment = (0..num_shards) + .map(|shard_index| shard_index as usize) + .map(|shard_index| ValidatorsFirstShardAssignmentItem { + validators: 0, + stake: 0, + shard_index, + }) + .collect(); // Distribute chunk producers until all shards have at least the // minimum requested number. If there are not enough validators to satisfy // that requirement, assign some of the validators to multiple shards. let mut chunk_producers = chunk_producers.into_iter().enumerate().cycle(); assign_to_satisfy_shards_inner( - &mut shard_index, + &mut shard_assignment, &mut result, &mut chunk_producers, min_validators_per_shard, @@ -153,7 +187,11 @@ struct ShardSetItem { /// /// Caller must guarantee that `min_validators_per_shard` is achievable and /// `prev_chunk_producers_assignment` corresponds to the same number of shards. +/// /// TODO(resharding) - implement shard assignment +/// The current shard assignment works fully based on the ShardIndex. During +/// resharding those indices will change and the assignment will move many +/// validators to different shards. This should be avoided. fn assign_to_balance_shards( chunk_producers: Vec, num_shards: NumShards, @@ -304,8 +342,12 @@ pub(crate) fn assign_chunk_producers_to_shards( pub(crate) mod old_validator_selection { use crate::shard_assignment::{assign_to_satisfy_shards_inner, HasStake, NotEnoughValidators}; - use near_primitives::types::{Balance, NumShards, ShardId}; - use near_primitives::utils::min_heap::MinHeap; + use near_primitives::types::NumShards; + + use super::{ + StakeFirstShardAssignment, StakeFirstShardAssignmentItem, ValidatorsFirstShardAssignment, + ValidatorsFirstShardAssignmentItem, + }; /// Assign chunk producers (a.k.a. validators) to shards. The i-th element /// of the output corresponds to the validators assigned to the i-th shard. @@ -344,15 +386,21 @@ pub(crate) mod old_validator_selection { let mut result: Vec> = (0..num_shards).map(|_| Vec::new()).collect(); // Initially, sort by number of validators first so we fill shards up. - let mut shard_index: MinHeap<(usize, Balance, ShardId)> = - (0..num_shards).map(|s| (0, 0, s)).collect(); + let mut shard_assignment: ValidatorsFirstShardAssignment = (0..num_shards) + .map(|shard_index| shard_index as usize) + .map(|shard_index| ValidatorsFirstShardAssignmentItem { + validators: 0, + stake: 0, + shard_index, + }) + .collect(); // First, distribute chunk producers until all shards have at least the // minimum requested number. If there are not enough validators to satisfy // that requirement, assign some of the validators to multiple shards. let mut chunk_producers = chunk_producers.into_iter().enumerate().cycle(); assign_to_satisfy_shards_inner( - &mut shard_index, + &mut shard_assignment, &mut result, &mut chunk_producers, min_validators_per_shard, @@ -364,20 +412,21 @@ pub(crate) mod old_validator_selection { num_chunk_producers.saturating_sub(num_shards as usize * min_validators_per_shard); if remaining_producers > 0 { // Re-index shards to favour lowest stake first. - let mut shard_index: MinHeap<(Balance, usize, ShardId)> = shard_index - .into_iter() - .map(|(count, stake, shard_id)| (stake, count, shard_id)) - .collect(); + let mut shard_assignment: StakeFirstShardAssignment = + shard_assignment.into_iter().map(Into::into).collect(); for (_, cp) in chunk_producers.take(remaining_producers) { - let (least_stake, least_validator_count, shard_id) = - shard_index.pop().expect("shard_index should never be empty"); - shard_index.push(( - least_stake + cp.get_stake(), - least_validator_count + 1, - shard_id, - )); - result[usize::try_from(shard_id).unwrap()].push(cp); + let StakeFirstShardAssignmentItem { + stake: least_stake, + validators: least_validator_count, + shard_index, + } = shard_assignment.pop().expect("shard_assignment should never be empty"); + shard_assignment.push(StakeFirstShardAssignmentItem { + stake: least_stake + cp.get_stake(), + validators: least_validator_count + 1, + shard_index, + }); + result[shard_index].push(cp); } } @@ -390,7 +439,7 @@ mod tests { use crate::shard_assignment::{assign_chunk_producers_to_shards, NotEnoughValidators}; use crate::RngSeed; use near_primitives::types::validator_stake::ValidatorStake; - use near_primitives::types::{AccountId, Balance, NumShards}; + use near_primitives::types::{AccountId, Balance, NumShards, ShardIndex}; use std::collections::{HashMap, HashSet}; const EXPONENTIAL_STAKES: [Balance; 12] = [100, 90, 81, 73, 66, 59, 53, 48, 43, 39, 35, 31]; @@ -486,12 +535,12 @@ mod tests { let mut assignments = assignments .into_iter() .enumerate() - .map(|(shard_id, cps)| { + .map(|(shard_index, cps)| { // All shards must have at least min_validators_per_shard validators. assert!( cps.len() >= min_validators_per_shard, "Shard {} has only {} chunk producers; expected at least {}", - shard_id, + shard_index, cps.len(), min_validators_per_shard ); @@ -500,7 +549,7 @@ mod tests { cps.len(), cps.iter().map(|cp| cp.0).collect::>().len(), "Shard {} contains duplicate chunk producers: {:?}", - shard_id, + shard_index, cps ); // If all is good, aggregate as (cps_count, total_stake) pair. @@ -520,12 +569,12 @@ mod tests { / (stakes.len() as Balance); let assignment = assign_shards(stakes, num_shards, min_validators_per_shard) .expect("There should have been enough validators"); - for (shard_id, &cps) in assignment.iter().enumerate() { + for (shard_index, &cps) in assignment.iter().enumerate() { // Validator distribution should be even. assert_eq!( validators_per_shard, cps.0, "Shard {} has {} validators, expected {}", - shard_id, cps.0, validators_per_shard + shard_index, cps.0, validators_per_shard ); // Stake distribution should be even @@ -533,7 +582,7 @@ mod tests { assert!( diff.abs() < diff_tolerance, "Shard {}'s stake {} is {} away from average; expected less than {} away", - shard_id, + shard_index, cps.1, diff.abs(), diff_tolerance @@ -724,12 +773,12 @@ mod tests { assert_eq!(assignment, target_assignment); } - fn validator_to_shard(assignment: &[Vec]) -> HashMap { + fn validator_to_shard(assignment: &[Vec]) -> HashMap { assignment .iter() .enumerate() - .flat_map(|(shard_id, cps)| { - cps.iter().map(move |cp| (cp.account_id().clone(), shard_id)) + .flat_map(|(shard_index, cps)| { + cps.iter().map(move |cp| (cp.account_id().clone(), shard_index)) }) .collect() } diff --git a/chain/epoch-manager/src/shard_tracker.rs b/chain/epoch-manager/src/shard_tracker.rs index 49cc1abf95c..59b2a8f593a 100644 --- a/chain/epoch-manager/src/shard_tracker.rs +++ b/chain/epoch-manager/src/shard_tracker.rs @@ -81,11 +81,13 @@ impl ShardTracker { shard_layout.shard_ids().map(|_| false).collect(); for account_id in tracked_accounts { let shard_id = account_id_to_shard_id(account_id, &shard_layout); - tracking_mask[shard_id as usize] = true; + let shard_index = shard_layout.get_shard_index(shard_id); + tracking_mask[shard_index] = true; } tracking_mask }); - Ok(tracking_mask.get(shard_id as usize).copied().unwrap_or(false)) + let shard_index = shard_layout.get_shard_index(shard_id); + Ok(tracking_mask.get(shard_index).copied().unwrap_or(false)) } TrackedConfig::AllShards => Ok(true), TrackedConfig::Schedule(schedule) => { @@ -209,13 +211,16 @@ mod tests { use crate::shard_tracker::TrackedConfig; use crate::test_utils::hash_range; use crate::{EpochManager, EpochManagerAdapter, EpochManagerHandle, RewardCalculator}; + use itertools::Itertools; use near_crypto::{KeyType, PublicKey}; use near_primitives::epoch_block_info::BlockInfo; use near_primitives::epoch_manager::{AllEpochConfig, EpochConfig}; use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::ShardLayout; use near_primitives::types::validator_stake::ValidatorStake; - use near_primitives::types::{BlockHeight, EpochId, NumShards, ProtocolVersion, ShardId}; + use near_primitives::types::{ + new_shard_id_tmp, BlockHeight, EpochId, NumShards, ProtocolVersion, ShardId, + }; use near_primitives::version::ProtocolFeature::SimpleNightshade; use near_primitives::version::PROTOCOL_VERSION; use near_store::test_utils::create_test_store; @@ -334,7 +339,7 @@ mod tests { #[test] fn test_track_accounts() { - let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids = (0..4).map(new_shard_id_tmp).collect_vec(); let epoch_manager = get_epoch_manager(PROTOCOL_VERSION, shard_ids.len() as NumShards, false); let shard_layout = epoch_manager.read().get_shard_layout(&EpochId::default()).unwrap(); @@ -359,7 +364,7 @@ mod tests { #[test] fn test_track_all_shards() { - let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids = (0..4).map(new_shard_id_tmp).collect_vec(); let epoch_manager = get_epoch_manager(PROTOCOL_VERSION, shard_ids.len() as NumShards, false); let tracker = ShardTracker::new(TrackedConfig::AllShards, Arc::new(epoch_manager)); @@ -378,17 +383,21 @@ mod tests { #[test] fn test_track_schedule() { // Creates a ShardTracker that changes every epoch tracked shards. - let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids = (0..4).map(new_shard_id_tmp).collect_vec(); + let epoch_manager = Arc::new(get_epoch_manager(PROTOCOL_VERSION, shard_ids.len() as NumShards, false)); - let subset1 = HashSet::from([0, 1]); - let subset2 = HashSet::from([1, 2]); - let subset3 = HashSet::from([2, 3]); + let subset1: HashSet = + HashSet::from([0, 1]).into_iter().map(new_shard_id_tmp).collect(); + let subset2: HashSet = + HashSet::from([1, 2]).into_iter().map(new_shard_id_tmp).collect(); + let subset3: HashSet = + HashSet::from([2, 3]).into_iter().map(new_shard_id_tmp).collect(); let tracker = ShardTracker::new( TrackedConfig::Schedule(vec![ subset1.clone().into_iter().collect(), - subset2.clone().into_iter().collect(), - subset3.clone().into_iter().collect(), + subset2.clone().into_iter().map(Into::into).collect(), + subset3.clone().into_iter().map(Into::into).collect(), ]), epoch_manager.clone(), ); diff --git a/chain/epoch-manager/src/tests/mod.rs b/chain/epoch-manager/src/tests/mod.rs index 6bc992de46c..80542627ec5 100644 --- a/chain/epoch-manager/src/tests/mod.rs +++ b/chain/epoch-manager/src/tests/mod.rs @@ -26,6 +26,7 @@ use near_primitives::stateless_validation::partial_witness::PartialEncodedStateW use near_primitives::types::ValidatorKickoutReason::{ NotEnoughBlocks, NotEnoughChunkEndorsements, NotEnoughChunks, }; +use near_primitives::types::{new_shard_id_tmp, ShardIndex}; use near_primitives::validator_signer::ValidatorSigner; use near_primitives::version::ProtocolFeature::{self, SimpleNightshade}; use near_primitives::version::PROTOCOL_VERSION; @@ -882,12 +883,13 @@ fn test_reward_multiple_shards() { for height in 1..(2 * epoch_length) { let i = height as usize; let epoch_id = epoch_manager.get_epoch_id_from_prev_block(&h[i - 1]).unwrap(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id).unwrap(); // test1 skips its chunks in the first epoch let chunk_mask = (0..num_shards) .map(|shard_index| { - let expected_chunk_producer = epoch_manager - .get_chunk_producer_info(&epoch_id, height, shard_index as u64) - .unwrap(); + let shard_id = shard_layout.get_shard_id(shard_index as ShardIndex); + let expected_chunk_producer = + epoch_manager.get_chunk_producer_info(&epoch_id, height, shard_id).unwrap(); if expected_chunk_producer.account_id() == "test1" && epoch_id == init_epoch_id { expected_chunks += 1; false @@ -1092,11 +1094,17 @@ fn test_expected_chunks_prev_block_not_produced() { let height = i as u64; let epoch_id = epoch_manager.get_epoch_id_from_prev_block(&prev_block).unwrap(); let epoch_info = epoch_manager.get_epoch_info(&epoch_id).unwrap().clone(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id).unwrap(); let block_producer = EpochManager::block_producer_from_info(&epoch_info, height); let prev_block_info = epoch_manager.get_block_info(&prev_block).unwrap(); let prev_height = prev_block_info.height(); - let expected_chunk_producer = - EpochManager::chunk_producer_from_info(&epoch_info, prev_height + 1, 0).unwrap(); + let expected_chunk_producer = EpochManager::chunk_producer_from_info( + &epoch_info, + &shard_layout, + new_shard_id_tmp(0), + prev_height + 1, + ) + .unwrap(); // test1 does not produce blocks during first epoch if block_producer == 0 && epoch_id == initial_epoch_id { expected += 1; @@ -1491,15 +1499,20 @@ fn test_chunk_producer_kickout() { let height = height as u64; let epoch_id = em.get_epoch_id_from_prev_block(prev_block).unwrap(); let epoch_info = em.get_epoch_info(&epoch_id).unwrap().clone(); + let shard_layout = em.get_shard_layout(&epoch_id).unwrap(); let chunk_mask = (0..4) - .map(|shard_id| { + .map(|shard_index| { if height >= epoch_length { return true; } - - let chunk_producer = - EpochManager::chunk_producer_from_info(&epoch_info, height, shard_id as u64) - .unwrap(); + let shard_id = shard_layout.get_shard_id(shard_index); + let chunk_producer = EpochManager::chunk_producer_from_info( + &epoch_info, + &shard_layout, + shard_id, + height, + ) + .unwrap(); // test1 skips chunks if chunk_producer == 0 { expected += 1; @@ -1636,20 +1649,24 @@ fn test_chunk_validator_kickout_using_endorsement_stats() { for (prev_block, (height, curr_block)) in hashes.iter().zip(hashes.iter().enumerate().skip(1)) { let height = height as u64; let epoch_id = em.get_epoch_id_from_prev_block(prev_block).unwrap(); + let shard_layout = em.get_shard_layout(&epoch_id).unwrap(); // All chunks are produced. let chunk_mask = vec![true; num_shards as usize]; // Prepare the chunk endorsements so that "test2" misses some of the endorsements. let mut bitmap = ChunkEndorsementsBitmap::new(num_shards as usize); - for shard_id in 0..num_shards { + for shard_id in shard_layout.shard_ids() { let chunk_validators = em .get_chunk_validator_assignments(&epoch_id, shard_id, height) .unwrap() .ordered_chunk_validators(); + let shard_index = shard_layout.get_shard_index(shard_id); bitmap.add_endorsements( - shard_id, + shard_index, chunk_validators .iter() - .map(|account| account.as_str() != "test2" || (height + shard_id) % 2 == 0) + .map(|account| { + account.as_str() != "test2" || (height + shard_index as u64) % 2 == 0 + }) .collect(), ) } @@ -2584,13 +2601,13 @@ fn test_validator_kickout_determinism() { (4, ChunkStats::new_with_endorsement(89, 100)), ]); let chunk_stats_tracker1 = HashMap::from([ - (0, chunk_stats0.clone().into_iter().collect()), - (1, chunk_stats1.clone().into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.clone().into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.clone().into_iter().collect()), ]); let chunk_stats0: Vec<_> = chunk_stats0.into_iter().rev().collect(); let chunk_stats_tracker2 = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts1) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2653,8 +2670,8 @@ fn test_chunk_validators_with_different_endorsement_ratio() { (3, ChunkStats::new_with_endorsement(60, 100)), ]); let chunk_stats_tracker = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2715,8 +2732,8 @@ fn test_chunk_validators_with_same_endorsement_ratio_and_different_stake() { (3, ChunkStats::new_with_endorsement(65, 100)), ]); let chunk_stats_tracker = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2777,8 +2794,8 @@ fn test_chunk_validators_with_same_endorsement_ratio_and_stake() { (3, ChunkStats::new_with_endorsement(65, 100)), ]); let chunk_stats_tracker = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2826,7 +2843,7 @@ fn test_validator_kickout_sanity() { ]); let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new_with_production(100, 100)), ( @@ -2844,7 +2861,7 @@ fn test_validator_kickout_sanity() { ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new_with_production(70, 100)), ( @@ -2964,7 +2981,7 @@ fn test_chunk_endorsement_stats() { ]), &HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new(100, 100, 100, 100)), (1, ChunkStats::new(90, 100, 100, 100)), @@ -2973,7 +2990,7 @@ fn test_chunk_endorsement_stats() { ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new(95, 100, 100, 100)), (1, ChunkStats::new(95, 100, 90, 100)), @@ -3043,16 +3060,16 @@ fn test_max_kickout_stake_ratio() { // validator 3 doesn't need to produce any block or chunk (3, ValidatorStats { produced: 0, expected: 0 }), ]); - let chunk_stats = HashMap::from([ + let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new_with_production(0, 100)), (1, ChunkStats::new_with_production(0, 100)), ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (2, ChunkStats::new_with_production(100, 100)), (4, ChunkStats::new_with_production(50, 100)), @@ -3065,7 +3082,7 @@ fn test_max_kickout_stake_ratio() { &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &prev_validator_kickout, ); @@ -3125,7 +3142,7 @@ fn test_max_kickout_stake_ratio() { &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &prev_validator_kickout, ); @@ -3173,9 +3190,9 @@ fn test_chunk_validator_kickout( (2, ValidatorStats { produced: 90, expected: 100 }), (3, ValidatorStats { produced: 0, expected: 0 }), ]); - let chunk_stats = HashMap::from([ + let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new_with_production(90, 100)), (1, ChunkStats::new_with_production(90, 100)), @@ -3185,7 +3202,7 @@ fn test_chunk_validator_kickout( ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new_with_production(90, 100)), (2, ChunkStats::new_with_production(90, 100)), @@ -3204,7 +3221,7 @@ fn test_chunk_validator_kickout( &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &prev_validator_kickout, ); @@ -3251,9 +3268,9 @@ fn test_block_and_chunk_producer_not_kicked_out_for_low_endorsements() { (1, ValidatorStats { produced: 90, expected: 100 }), (2, ValidatorStats { produced: 90, expected: 100 }), ]); - let chunk_stats = HashMap::from([ + let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new(90, 100, 10, 100)), (1, ChunkStats::new(90, 100, 10, 100)), @@ -3261,7 +3278,7 @@ fn test_block_and_chunk_producer_not_kicked_out_for_low_endorsements() { ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new(90, 100, 10, 100)), (1, ChunkStats::new(90, 100, 10, 100)), @@ -3276,7 +3293,7 @@ fn test_block_and_chunk_producer_not_kicked_out_for_low_endorsements() { &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &HashMap::new(), ); @@ -3295,7 +3312,7 @@ fn test_chunk_header(h: &[CryptoHash], signer: &ValidatorSigner) -> ShardChunkHe h[2], 0, 1, - 0, + new_shard_id_tmp(0), 0, 0, 0, @@ -3329,7 +3346,7 @@ fn test_verify_chunk_endorsements() { // verify if we have one chunk validator let chunk_validator_assignments = - &epoch_manager.get_chunk_validator_assignments(&epoch_id, 0, 1).unwrap(); + &epoch_manager.get_chunk_validator_assignments(&epoch_id, new_shard_id_tmp(0), 1).unwrap(); assert_eq!(chunk_validator_assignments.ordered_chunk_validators().len(), 1); assert!(chunk_validator_assignments.contains(&account_id)); diff --git a/chain/epoch-manager/src/tests/random_epochs.rs b/chain/epoch-manager/src/tests/random_epochs.rs index 8d8bdc4f329..ac0a46b0687 100644 --- a/chain/epoch-manager/src/tests/random_epochs.rs +++ b/chain/epoch-manager/src/tests/random_epochs.rs @@ -325,7 +325,8 @@ fn verify_block_stats( { let aggregator = epoch_manager.get_epoch_info_aggregator_upto_last(&block_hashes[i]).unwrap(); - let epoch_info = epoch_manager.get_epoch_info(block_infos[i].epoch_id()).unwrap(); + let epoch_id = block_infos[i].epoch_id(); + let epoch_info = epoch_manager.get_epoch_info(epoch_id).unwrap(); for key in aggregator.block_tracker.keys().copied() { assert!(key < epoch_info.validators_iter().len() as u64); } @@ -340,7 +341,10 @@ fn verify_block_stats( aggregator.block_tracker.values().map(|value| value.expected).sum::(); assert_eq!(sum_produced, blocks_in_epoch); assert_eq!(sum_expected, blocks_in_epoch_expected); - for shard_id in 0..(aggregator.shard_tracker.len() as u64) { + // TODO: The following sophisticated check doesn't do anything. The + // shard tracker is empty because the chunk mask in all block infos + // is empty. + for &shard_id in aggregator.shard_tracker.keys() { let sum_produced = aggregator .shard_tracker .get(&shard_id) diff --git a/chain/epoch-manager/src/types.rs b/chain/epoch-manager/src/types.rs index 22cda704587..aebac696fb7 100644 --- a/chain/epoch-manager/src/types.rs +++ b/chain/epoch-manager/src/types.rs @@ -3,6 +3,7 @@ use itertools::Itertools; use near_primitives::epoch_block_info::BlockInfo; use near_primitives::epoch_info::EpochInfo; use near_primitives::hash::CryptoHash; +use near_primitives::shard_layout::ShardLayout; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::types::{ AccountId, BlockHeight, ChunkStats, EpochId, ShardId, ValidatorId, ValidatorStats, @@ -69,6 +70,7 @@ impl EpochInfoAggregator { &mut self, block_info: &BlockInfo, epoch_info: &EpochInfo, + shard_layout: &ShardLayout, prev_block_height: BlockHeight, ) { let _span = @@ -105,12 +107,13 @@ impl EpochInfoAggregator { // TODO(#11900): Call EpochManager::get_chunk_validator_assignments to access the cached validator assignments. let chunk_validator_assignment = epoch_info.sample_chunk_validators(prev_block_height + 1); - for (i, mask) in block_info.chunk_mask().iter().enumerate() { - let shard_id: ShardId = i as ShardId; + for (shard_index, mask) in block_info.chunk_mask().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let chunk_producer_id = EpochManager::chunk_producer_from_info( epoch_info, + shard_layout, + shard_id, prev_block_height + 1, - i as ShardId, ) .unwrap(); let tracker = self.shard_tracker.entry(shard_id).or_insert_with(HashMap::new); @@ -123,7 +126,7 @@ impl EpochInfoAggregator { debug!( target: "epoch_tracker", chunk_validator = ?epoch_info.validator_account_id(chunk_producer_id), - shard_id = i, + ?shard_id, block_height = prev_block_height + 1, "Missed chunk"); } @@ -132,7 +135,7 @@ impl EpochInfoAggregator { .or_insert_with(|| ChunkStats::new_with_production(u64::from(*mask), 1)); let chunk_validators = chunk_validator_assignment - .get(i) + .get(shard_index) .map_or::<&[(u64, u128)], _>(&[], Vec::as_slice) .iter() .map(|(id, _)| *id) @@ -148,14 +151,14 @@ impl EpochInfoAggregator { // For old chunks, we optimize the block and its header by not including the chunk endorsements and // corresponding bitmaps. Thus, we expect that the bitmap is non-empty for new chunks only. if *mask { - debug_assert!(chunk_endorsements.len(shard_id).unwrap() == chunk_validators.len().div_ceil(8) * 8, - "Chunk endorsement bitmap length is inconsistent with number of chunk validators. Bitmap length={}, num validators={}, shard_id={}", - chunk_endorsements.len(shard_id).unwrap(), chunk_validators.len(), shard_id); - chunk_endorsements.iter(shard_id) + debug_assert!(chunk_endorsements.len(shard_index).unwrap() == chunk_validators.len().div_ceil(8) * 8, + "Chunk endorsement bitmap length is inconsistent with number of chunk validators. Bitmap length={}, num validators={}, shard_index={}", + chunk_endorsements.len(shard_index).unwrap(), chunk_validators.len(), shard_index); + chunk_endorsements.iter(shard_index) } else { - debug_assert_eq!(chunk_endorsements.len(shard_id).unwrap(), 0, - "Chunk endorsement bitmap must be empty for missing chunk. Bitmap length={}, shard_id={}", - chunk_endorsements.len(shard_id).unwrap(), shard_id); + debug_assert_eq!(chunk_endorsements.len(shard_index).unwrap(), 0, + "Chunk endorsement bitmap must be empty for missing chunk. Bitmap length={}, shard_index={}", + chunk_endorsements.len(shard_index).unwrap(), shard_index); Box::new(std::iter::repeat(false).take(chunk_validators.len())) } } else { diff --git a/chain/epoch-manager/src/validator_selection.rs b/chain/epoch-manager/src/validator_selection.rs index dc9716ff71b..0e58c0221b9 100644 --- a/chain/epoch-manager/src/validator_selection.rs +++ b/chain/epoch-manager/src/validator_selection.rs @@ -591,12 +591,12 @@ mod old_validator_selection { all_validators.push(bp.clone()); } - let shard_ids: Vec<_> = epoch_config.shard_layout.shard_ids().collect(); + let num_shards = epoch_config.shard_layout.shard_ids().count(); if chunk_producers.is_empty() { // All validators tried to unstake? return Err(EpochError::NotEnoughValidators { num_validators: 0u64, - num_shards: shard_ids.len() as NumShards, + num_shards: num_shards as u64, }); } @@ -605,11 +605,9 @@ mod old_validator_selection { // each validator as even as possible). Note that in prod configuration number of seats // per shard is the same as maximal number of block producers, so normally all // validators would be assigned to all chunks - let chunk_producers_settlement = shard_ids - .iter() - .map(|&shard_id| shard_id as usize) - .map(|shard_id| { - (0..epoch_config.num_block_producer_seats_per_shard[shard_id] + let chunk_producers_settlement = (0..num_shards) + .map(|shard_index| { + (0..epoch_config.num_block_producer_seats_per_shard[shard_index] .min(block_producers_settlement.len() as u64)) .map(|_| { let res = block_producers_settlement[id]; @@ -637,6 +635,7 @@ mod tests { use near_primitives::epoch_manager::ValidatorSelectionConfig; use near_primitives::shard_layout::ShardLayout; use near_primitives::types::validator_stake::ValidatorStake; + use near_primitives::types::ShardIndex; use near_primitives::version::PROTOCOL_VERSION; use num_rational::Ratio; @@ -964,12 +963,15 @@ mod tests { ) .unwrap(); - for shard_id in 0..num_shards { + let shard_layout = &epoch_config.shard_layout; + for shard_index in 0..num_shards { + let shard_index = shard_index as ShardIndex; + let shard_id = shard_layout.get_shard_id(shard_index); for h in 0..100_000 { - let cp = epoch_info.sample_chunk_producer(h, shard_id); + let cp = epoch_info.sample_chunk_producer(shard_layout, shard_id, h); // Don't read too much into this. The reason the ValidatorId always // equals the ShardId is because the validators are assigned to shards in order. - assert_eq!(cp, Some(shard_id)) + assert_eq!(cp, Some(shard_index as u64)) } } @@ -992,10 +994,10 @@ mod tests { ) .unwrap(); - for shard_id in 0..num_shards { + for shard_id in shard_layout.shard_ids() { let mut counts: [i32; 2] = [0, 0]; for h in 0..100_000 { - let cp = epoch_info.sample_chunk_producer(h, shard_id).unwrap(); + let cp = epoch_info.sample_chunk_producer(shard_layout, shard_id, h).unwrap(); // if ValidatorId is in the second half then it is the lower // stake validator (because they are sorted by decreasing stake). let index = if cp >= num_shards { 1 } else { 0 }; diff --git a/chain/indexer/src/streamer/mod.rs b/chain/indexer/src/streamer/mod.rs index 3e2c2dbb0cc..45b01ccd042 100644 --- a/chain/indexer/src/streamer/mod.rs +++ b/chain/indexer/src/streamer/mod.rs @@ -79,8 +79,7 @@ pub async fn build_streamer_message( let chunks = fetch_block_chunks(&client, &block).await?; let protocol_config_view = fetch_protocol_config(&client, block.header.hash).await?; - let num_shards = protocol_config_view.num_block_producer_seats_per_shard.len() - as near_primitives::types::NumShards; + let shard_ids = protocol_config_view.shard_layout.shard_ids(); let runtime_config_store = near_parameters::RuntimeConfigStore::new(None); let runtime_config = runtime_config_store.get_config(protocol_config_view.protocol_version); @@ -92,7 +91,7 @@ pub async fn build_streamer_message( near_primitives::types::EpochId(block.header.epoch_id), ) .await?; - let mut indexer_shards = (0..num_shards) + let mut indexer_shards = shard_ids .map(|shard_id| IndexerShard { shard_id, chunk: None, @@ -101,12 +100,10 @@ pub async fn build_streamer_message( }) .collect::>(); - for chunk in chunks { + for (shard_index, chunk) in chunks.into_iter().enumerate() { let views::ChunkView { transactions, author, header, receipts: chunk_non_local_receipts } = chunk; - let shard_id = header.shard_id as usize; - let mut outcomes = shards_outcomes .remove(&header.shard_id) .expect("Execution outcomes for given shard should be present"); @@ -236,9 +233,9 @@ pub async fn build_streamer_message( chunk_receipts.extend(chunk_non_local_receipts); - indexer_shards[shard_id].receipt_execution_outcomes = receipt_execution_outcomes; + indexer_shards[shard_index].receipt_execution_outcomes = receipt_execution_outcomes; // Put the chunk into corresponding indexer shard - indexer_shards[shard_id].chunk = Some(IndexerChunkView { + indexer_shards[shard_index].chunk = Some(IndexerChunkView { author, header, transactions: indexer_transactions, @@ -250,12 +247,13 @@ pub async fn build_streamer_message( // chunks and we end up with non-empty `shards_outcomes` we want to be sure we put them into IndexerShard // That might happen before the fix https://github.com/near/nearcore/pull/4228 for (shard_id, outcomes) in shards_outcomes { - indexer_shards[shard_id as usize].receipt_execution_outcomes.extend( - outcomes.into_iter().map(|outcome| IndexerExecutionOutcomeWithReceipt { + let shard_index = protocol_config_view.shard_layout.get_shard_index(shard_id); + indexer_shards[shard_index].receipt_execution_outcomes.extend(outcomes.into_iter().map( + |outcome| IndexerExecutionOutcomeWithReceipt { execution_outcome: outcome.execution_outcome, receipt: outcome.receipt.expect("`receipt` must be present at this moment"), - }), - ) + }, + )) } Ok(StreamerMessage { block, shards: indexer_shards }) diff --git a/chain/jsonrpc-primitives/src/types/chunks.rs b/chain/jsonrpc-primitives/src/types/chunks.rs index d571e5ea32f..de5ebcff138 100644 --- a/chain/jsonrpc-primitives/src/types/chunks.rs +++ b/chain/jsonrpc-primitives/src/types/chunks.rs @@ -34,6 +34,7 @@ pub enum RpcChunkError { #[serde(skip_serializing)] error_message: String, }, + // TODO Should use ShardId instead of u64 #[error("Shard id {shard_id} does not exist")] InvalidShardId { shard_id: u64 }, #[error("Chunk with hash {chunk_hash:?} has never been observed on this node")] diff --git a/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs b/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs index dc34436601a..727febd5987 100644 --- a/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs +++ b/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs @@ -15,7 +15,7 @@ use near_network::test_utils::wait_or_timeout; use near_o11y::testonly::init_test_logger; use near_primitives::account::{AccessKey, AccessKeyPermission}; use near_primitives::hash::CryptoHash; -use near_primitives::types::{BlockId, BlockReference, EpochId, SyncCheckpoint}; +use near_primitives::types::{new_shard_id_tmp, BlockId, BlockReference, EpochId, SyncCheckpoint}; use near_primitives::views::QueryRequest; use near_time::Clock; @@ -90,7 +90,10 @@ fn test_block_query() { #[test] fn test_chunk_by_hash() { test_with_client!(test_utils::NodeType::NonValidator, client, async move { - let chunk = client.chunk(ChunkId::BlockShardId(BlockId::Height(0), 0u64)).await.unwrap(); + let chunk = client + .chunk(ChunkId::BlockShardId(BlockId::Height(0), new_shard_id_tmp(0))) + .await + .unwrap(); assert_eq!(chunk.author, "test1"); assert_eq!(chunk.header.balance_burnt, 0); assert_eq!(chunk.header.chunk_hash.as_ref().len(), 32); @@ -104,7 +107,7 @@ fn test_chunk_by_hash() { assert_eq!(chunk.header.prev_block_hash.as_ref().len(), 32); assert_eq!(chunk.header.prev_state_root.as_ref().len(), 32); assert_eq!(chunk.header.rent_paid, 0); - assert_eq!(chunk.header.shard_id, 0); + assert_eq!(chunk.header.shard_id, new_shard_id_tmp(0)); assert!(if let Signature::ED25519(_) = chunk.header.signature { true } else { false }); assert_eq!(chunk.header.tx_root.as_ref(), &[0; 32]); assert_eq!(chunk.header.validator_proposals, vec![]); @@ -118,7 +121,8 @@ fn test_chunk_by_hash() { #[test] fn test_chunk_invalid_shard_id() { test_with_client!(test_utils::NodeType::NonValidator, client, async move { - let chunk = client.chunk(ChunkId::BlockShardId(BlockId::Height(0), 100)).await; + let chunk = + client.chunk(ChunkId::BlockShardId(BlockId::Height(0), new_shard_id_tmp(100))).await; match chunk { Ok(_) => panic!("should result in an error"), Err(e) => { @@ -649,7 +653,7 @@ fn test_get_chunk_with_object_in_params() { assert_eq!(chunk.header.prev_block_hash.as_ref().len(), 32); assert_eq!(chunk.header.prev_state_root.as_ref().len(), 32); assert_eq!(chunk.header.rent_paid, 0); - assert_eq!(chunk.header.shard_id, 0); + assert_eq!(chunk.header.shard_id, new_shard_id_tmp(0)); assert!(if let Signature::ED25519(_) = chunk.header.signature { true } else { false }); assert_eq!(chunk.header.tx_root.as_ref(), &[0; 32]); assert_eq!(chunk.header.validator_proposals, vec![]); diff --git a/chain/jsonrpc/src/api/chunks.rs b/chain/jsonrpc/src/api/chunks.rs index badf0bf4654..b4034c01ba1 100644 --- a/chain/jsonrpc/src/api/chunks.rs +++ b/chain/jsonrpc/src/api/chunks.rs @@ -57,7 +57,9 @@ impl RpcFrom for RpcChunkError { match error { GetChunkError::IOError { error_message } => Self::InternalError { error_message }, GetChunkError::UnknownBlock { error_message } => Self::UnknownBlock { error_message }, - GetChunkError::InvalidShardId { shard_id } => Self::InvalidShardId { shard_id }, + GetChunkError::InvalidShardId { shard_id } => { + Self::InvalidShardId { shard_id: shard_id.into() } + } GetChunkError::UnknownChunk { chunk_hash } => Self::UnknownChunk { chunk_hash }, GetChunkError::Unreachable { ref error_message } => { tracing::warn!(target: "jsonrpc", "Unreachable error occurred: {}", error_message); diff --git a/chain/network/src/network_protocol/proto_conv/handshake.rs b/chain/network/src/network_protocol/proto_conv/handshake.rs index 11003597f14..7dc67077de8 100644 --- a/chain/network/src/network_protocol/proto_conv/handshake.rs +++ b/chain/network/src/network_protocol/proto_conv/handshake.rs @@ -42,7 +42,7 @@ impl From<&PeerChainInfoV2> for proto::PeerChainInfo { Self { genesis_id: MF::some((&x.genesis_id).into()), height: x.height, - tracked_shards: x.tracked_shards.clone(), + tracked_shards: x.tracked_shards.clone().into_iter().map(Into::into).collect(), archival: x.archival, ..Self::default() } @@ -55,7 +55,7 @@ impl TryFrom<&proto::PeerChainInfo> for PeerChainInfoV2 { Ok(Self { genesis_id: try_from_required(&p.genesis_id).map_err(Self::Error::GenesisId)?, height: p.height, - tracked_shards: p.tracked_shards.clone(), + tracked_shards: p.tracked_shards.clone().into_iter().map(Into::into).collect(), archival: p.archival, }) } diff --git a/chain/network/src/network_protocol/proto_conv/peer_message.rs b/chain/network/src/network_protocol/proto_conv/peer_message.rs index b73a66d7966..78e732c4a63 100644 --- a/chain/network/src/network_protocol/proto_conv/peer_message.rs +++ b/chain/network/src/network_protocol/proto_conv/peer_message.rs @@ -179,7 +179,7 @@ impl From<&SnapshotHostInfo> for proto::SnapshotHostInfo { peer_id: MF::some((&x.peer_id).into()), sync_hash: MF::some((&x.sync_hash).into()), epoch_height: x.epoch_height, - shards: x.shards.clone(), + shards: x.shards.clone().into_iter().map(Into::into).collect(), signature: MF::some((&x.signature).into()), ..Default::default() } @@ -193,7 +193,7 @@ impl TryFrom<&proto::SnapshotHostInfo> for SnapshotHostInfo { peer_id: try_from_required(&x.peer_id).map_err(Self::Error::PeerId)?, sync_hash: try_from_required(&x.sync_hash).map_err(Self::Error::SyncHash)?, epoch_height: x.epoch_height, - shards: x.shards.clone(), + shards: x.shards.clone().into_iter().map(Into::into).collect(), signature: try_from_required(&x.signature).map_err(Self::Error::Signature)?, }) } @@ -313,14 +313,14 @@ impl From<&PeerMessage> for proto::PeerMessage { PeerMessage::SyncSnapshotHosts(ssh) => ProtoMT::SyncSnapshotHosts(ssh.into()), PeerMessage::StateRequestHeader(shard_id, sync_hash) => { ProtoMT::StateRequestHeader(proto::StateRequestHeader { - shard_id: *shard_id, + shard_id: (*shard_id).into(), sync_hash: MF::some(sync_hash.into()), ..Default::default() }) } PeerMessage::StateRequestPart(shard_id, sync_hash, part_id) => { ProtoMT::StateRequestPart(proto::StateRequestPart { - shard_id: *shard_id, + shard_id: (*shard_id).into(), sync_hash: MF::some(sync_hash.into()), part_id: *part_id, ..Default::default() @@ -477,11 +477,11 @@ impl TryFrom<&proto::PeerMessage> for PeerMessage { Challenge::try_from_slice(&c.borsh).map_err(Self::Error::Challenge)?, ), ProtoMT::StateRequestHeader(srh) => PeerMessage::StateRequestHeader( - srh.shard_id, + srh.shard_id.into(), try_from_required(&srh.sync_hash).map_err(Self::Error::BlockRequest)?, ), ProtoMT::StateRequestPart(srp) => PeerMessage::StateRequestPart( - srp.shard_id, + srp.shard_id.into(), try_from_required(&srp.sync_hash).map_err(Self::Error::BlockRequest)?, srp.part_id, ), diff --git a/chain/network/src/network_protocol/testonly.rs b/chain/network/src/network_protocol/testonly.rs index 7a2762c61b6..6a34083e3ef 100644 --- a/chain/network/src/network_protocol/testonly.rs +++ b/chain/network/src/network_protocol/testonly.rs @@ -18,7 +18,7 @@ use near_primitives::sharding::{ ChunkHash, EncodedShardChunkBody, PartialEncodedChunkPart, ShardChunk, }; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{AccountId, BlockHeight, EpochId, StateRoot}; +use near_primitives::types::{new_shard_id_tmp, AccountId, BlockHeight, EpochId, StateRoot}; use near_primitives::validator_signer::{InMemoryValidatorSigner, ValidatorSigner}; use near_primitives::version; use rand::distributions::Standard; @@ -211,7 +211,7 @@ impl ChunkSet { Self { chunks: HashMap::default() } } pub fn make(&mut self) -> Vec { - let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids: Vec<_> = (0..4).into_iter().map(new_shard_id_tmp).collect(); // TODO: these are always genesis chunks. // Consider making this more realistic. let chunks = genesis_chunks( diff --git a/chain/network/src/peer_manager/peer_manager_actor.rs b/chain/network/src/peer_manager/peer_manager_actor.rs index 9bfbf775f1c..b4739b85fa1 100644 --- a/chain/network/src/peer_manager/peer_manager_actor.rs +++ b/chain/network/src/peer_manager/peer_manager_actor.rs @@ -1373,7 +1373,7 @@ impl actix::Handler for PeerManagerActor { peer_id: h.peer_id.clone(), sync_hash: h.sync_hash, epoch_height: h.epoch_height, - shards: h.shards.clone(), + shards: h.shards.clone().into_iter().map(Into::into).collect(), }) .collect::>(), }), diff --git a/chain/network/src/peer_manager/tests/snapshot_hosts.rs b/chain/network/src/peer_manager/tests/snapshot_hosts.rs index 1ea2afbfcba..09c56a18816 100644 --- a/chain/network/src/peer_manager/tests/snapshot_hosts.rs +++ b/chain/network/src/peer_manager/tests/snapshot_hosts.rs @@ -10,12 +10,14 @@ use crate::types::NetworkRequests; use crate::types::PeerManagerMessageRequest; use crate::types::PeerMessage; use crate::{network_protocol::testonly as data, peer::testonly::PeerHandle}; +use itertools::Itertools; use near_async::time; use near_crypto::SecretKey; use near_o11y::testonly::init_test_logger; use near_o11y::WithSpanContextExt; use near_primitives::hash::CryptoHash; use near_primitives::network::PeerId; +use near_primitives::types::new_shard_id_tmp; use near_primitives::types::EpochHeight; use near_primitives::types::ShardId; use peer_manager::testonly::FDS_PER_PEER; @@ -32,10 +34,10 @@ fn make_snapshot_host_info( rng: &mut impl Rng, ) -> Arc { let epoch_height: EpochHeight = rng.gen::(); - let max_shard_id: ShardId = 32; + let max_shard_id = 32; let shards_num: usize = rng.gen_range(1..16); - let mut shards: Vec = (0..max_shard_id).choose_multiple(rng, shards_num); - shards.sort(); + let shards = (0..max_shard_id).choose_multiple(rng, shards_num); + let shards = shards.into_iter().sorted().map(new_shard_id_tmp).collect(); let sync_hash = CryptoHash::hash_borsh(epoch_height); Arc::new(SnapshotHostInfo::new(peer_id.clone(), sync_hash, epoch_height, shards, secret_key)) } @@ -251,7 +253,7 @@ async fn too_many_shards_not_broadcast() { tracing::info!(target:"test", "Send an invalid SyncSnapshotHosts message from peer1. One of the host infos has more shard ids than allowed."); let too_many_shards: Vec = - (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).collect(); + (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).map(Into::into).collect(); let invalid_info = Arc::new(SnapshotHostInfo::new( peer1_config.node_id(), CryptoHash::hash_borsh(rng.gen::()), @@ -369,11 +371,12 @@ async fn large_shard_id_in_cache() { let peer1 = pm.start_inbound(chain.clone(), peer1_config.clone()).await.handshake(clock).await; tracing::info!(target:"test", "Send a SnapshotHostInfo message with very large shard ids."); + let max_shard_id: ShardId = ShardId::MAX; let big_shard_info = Arc::new(SnapshotHostInfo::new( peer1_config.node_id(), CryptoHash::hash_borsh(1234_u64), 1234, - vec![0, 1232232, ShardId::MAX - 1, ShardId::MAX], + vec![0, 1232232, max_shard_id - 1, max_shard_id].into_iter().map(Into::into).collect(), &peer1_config.node_key, )); @@ -419,7 +422,7 @@ async fn too_many_shards_truncate() { tracing::info!(target:"test", "Ask peer manager to send out an invalid SyncSnapshotHosts message. The info has more shard ids than allowed."); // Create a list of shards with twice as many shard ids as is allowed let too_many_shards: Vec = - (0..(2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64)).collect(); + (0..(2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64)).map(Into::into).collect(); let sync_hash = CryptoHash::hash_borsh(rng.gen::()); let epoch_height: EpochHeight = rng.gen(); @@ -442,9 +445,9 @@ async fn too_many_shards_truncate() { // The list of shards should contain MAX_SHARDS_PER_SNAPSHOT_HOST_INFO randomly sampled, unique shard ids taken from too_many_shards assert_eq!(info.shards.len(), MAX_SHARDS_PER_SNAPSHOT_HOST_INFO); - for shard_id in &info.shards { + for &shard_id in &info.shards { // Shard ids are taken from the original vector - assert!(*shard_id < 2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64); + assert!(shard_id < 2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64); } // The shard_ids are sorted and unique (no two elements are equal, hence the < condition instead of <=) assert!(info.shards.windows(2).all(|twoelems| twoelems[0] < twoelems[1])); diff --git a/chain/network/src/raw/tests.rs b/chain/network/src/raw/tests.rs index f7601b06e5a..49b5adb00df 100644 --- a/chain/network/src/raw/tests.rs +++ b/chain/network/src/raw/tests.rs @@ -8,6 +8,7 @@ use near_crypto::{KeyType, SecretKey}; use near_o11y::testonly::init_test_logger; use near_primitives::hash::CryptoHash; use near_primitives::network::PeerId; +use near_primitives::types::new_shard_id_tmp; use std::sync::Arc; #[tokio::test] @@ -38,7 +39,7 @@ async fn test_raw_conn_pings() { &genesis_id.chain_id, genesis_id.hash, 0, - vec![0], + vec![new_shard_id_tmp(0)], time::Duration::SECOND, ) .await @@ -99,7 +100,7 @@ async fn test_raw_conn_state_parts() { &genesis_id.chain_id, genesis_id.hash, 0, - vec![0], + vec![new_shard_id_tmp(0)], time::Duration::SECOND, ) .await @@ -110,9 +111,13 @@ async fn test_raw_conn_state_parts() { // But the fake node simply ignores the block hash. let block_hash = CryptoHash::new(); for part_id in 0..num_parts { - conn.send_message(raw::DirectMessage::StateRequestPart(0, block_hash, part_id)) - .await - .unwrap(); + conn.send_message(raw::DirectMessage::StateRequestPart( + new_shard_id_tmp(0), + block_hash, + part_id, + )) + .await + .unwrap(); } let mut part_id_received = -1i64; @@ -174,7 +179,7 @@ async fn test_listener() { &genesis_id.chain_id, genesis_id.hash, 0, - vec![0], + vec![new_shard_id_tmp(0)], false, time::Duration::SECOND, ) diff --git a/chain/network/src/snapshot_hosts/tests.rs b/chain/network/src/snapshot_hosts/tests.rs index 79b010938a5..0ab4fda763d 100644 --- a/chain/network/src/snapshot_hosts/tests.rs +++ b/chain/network/src/snapshot_hosts/tests.rs @@ -5,12 +5,13 @@ use crate::snapshot_hosts::{priority_score, Config, SnapshotHostInfoError, Snaps use crate::testonly::assert_is_superset; use crate::testonly::{make_rng, AsSet as _}; use crate::types::SnapshotHostInfo; +use itertools::Itertools; use near_crypto::SecretKey; use near_o11y::testonly::init_test_logger; use near_primitives::hash::CryptoHash; use near_primitives::network::PeerId; -use near_primitives::types::EpochHeight; use near_primitives::types::ShardId; +use near_primitives::types::{new_shard_id_tmp, EpochHeight}; use rand::Rng; use std::collections::HashSet; use std::sync::Arc; @@ -52,17 +53,19 @@ async fn happy_path() { let cache = SnapshotHostsCache::new(config); assert_eq!(cache.get_hosts().len(), 0); // initially empty + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + // initial insert - let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, vec![0, 1, 2, 3], &key0)); - let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, vec![2], &key1)); + let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, sid_vec(&[0, 1, 2, 3]), &key0)); + let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, sid_vec(&[2]), &key1)); let res = cache.insert(vec![info0.clone(), info1.clone()]).await; assert_eq!([&info0, &info1].as_set(), unwrap(&res).as_set()); assert_eq!([&info0, &info1].as_set(), cache.get_hosts().iter().collect::>()); // second insert with various types of updates - let info0new = Arc::new(make_snapshot_host_info(&peer0, 124, vec![1, 3], &key0)); - let info1old = Arc::new(make_snapshot_host_info(&peer1, 122, vec![0, 1, 2, 3], &key1)); - let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, vec![2], &key2)); + let info0new = Arc::new(make_snapshot_host_info(&peer0, 124, sid_vec(&[1, 3]), &key0)); + let info1old = Arc::new(make_snapshot_host_info(&peer1, 122, sid_vec(&[0, 1, 2, 3]), &key1)); + let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, sid_vec(&[2]), &key2)); let res = cache.insert(vec![info0new.clone(), info1old.clone(), info2.clone()]).await; assert_eq!([&info0new, &info2].as_set(), unwrap(&res).as_set()); assert_eq!( @@ -86,8 +89,11 @@ async fn invalid_signature() { let config = Config { snapshot_hosts_cache_size: 100, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); - let info0_invalid_sig = Arc::new(make_snapshot_host_info(&peer0, 1, vec![0, 1, 2, 3], &key1)); - let info1 = Arc::new(make_snapshot_host_info(&peer1, 1, vec![0, 1, 2, 3], &key1)); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + + let shards = sid_vec(&[0, 1, 2, 3]); + let info0_invalid_sig = Arc::new(make_snapshot_host_info(&peer0, 1, shards.clone(), &key1)); + let info1 = Arc::new(make_snapshot_host_info(&peer1, 1, shards, &key1)); let res = cache.insert(vec![info0_invalid_sig.clone(), info1.clone()]).await; // invalid signature => InvalidSignature assert_eq!( @@ -119,12 +125,14 @@ async fn too_many_shards() { let config = Config { snapshot_hosts_cache_size: 100, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + // info0 is valid - let info0 = Arc::new(make_snapshot_host_info(&peer0, 1, vec![0, 1, 2, 3], &key0)); + let info0 = Arc::new(make_snapshot_host_info(&peer0, 1, sid_vec(&[0, 1, 2, 3]), &key0)); // info1 is invalid - it has more shard ids than MAX_SHARDS_PER_SNAPSHOT_HOST_INFO let too_many_shards: Vec = - (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).collect(); + (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).into_iter().map(Into::into).collect(); let info1 = Arc::new(make_snapshot_host_info(&peer1, 1, too_many_shards, &key1)); // info1.verify() should fail @@ -155,8 +163,10 @@ async fn duplicate_peer_id() { let config = Config { snapshot_hosts_cache_size: 100, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); - let info00 = Arc::new(make_snapshot_host_info(&peer0, 1, vec![0, 1, 2, 3], &key0)); - let info01 = Arc::new(make_snapshot_host_info(&peer0, 2, vec![0, 3], &key0)); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + + let info00 = Arc::new(make_snapshot_host_info(&peer0, 1, sid_vec(&[0, 1, 2, 3]), &key0)); + let info01 = Arc::new(make_snapshot_host_info(&peer0, 2, sid_vec(&[0, 3]), &key0)); let res = cache.insert(vec![info00.clone(), info01.clone()]).await; // duplicate peer ids => DuplicatePeerId assert_eq!(Some(SnapshotHostInfoError::DuplicatePeerId), res.1); @@ -182,19 +192,21 @@ async fn test_lru_eviction() { let config = Config { snapshot_hosts_cache_size: 2, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + // initial inserts to capacity - let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, vec![0, 1, 2, 3], &key0)); + let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, sid_vec(&[0, 1, 2, 3]), &key0)); let res = cache.insert(vec![info0.clone()]).await; assert_eq!([&info0].as_set(), unwrap(&res).as_set()); assert_eq!([&info0].as_set(), cache.get_hosts().iter().collect::>()); - let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, vec![2], &key1)); + let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, sid_vec(&[2]), &key1)); let res = cache.insert(vec![info1.clone()]).await; assert_eq!([&info1].as_set(), unwrap(&res).as_set()); assert_eq!([&info0, &info1].as_set(), cache.get_hosts().iter().collect::>()); // insert past capacity - let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, vec![1, 3], &key2)); + let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, sid_vec(&[1, 3]), &key2)); let res = cache.insert(vec![info2.clone()]).await; // check that the new data is accepted assert_eq!([&info2].as_set(), unwrap(&res).as_set()); @@ -318,7 +330,7 @@ async fn run_select_peer_test( assert!(err.is_none()); } SelectPeerAction::CallSelect(wanted) => { - let peer = cache.select_host_for_part(sync_hash, 0, part_id); + let peer = cache.select_host_for_part(sync_hash, new_shard_id_tmp(0), part_id); let wanted = match wanted { Some(idx) => Some(&peers[*idx].peer_id), None => None, @@ -326,9 +338,10 @@ async fn run_select_peer_test( assert!(peer.as_ref() == wanted, "got: {:?} want: {:?}", &peer, &wanted); } SelectPeerAction::PartReceived => { - assert!(cache.has_selector(0, part_id)); - cache.part_received(0, part_id); - assert!(!cache.has_selector(0, part_id)); + let shard_id = new_shard_id_tmp(0); + assert!(cache.has_selector(shard_id, part_id)); + cache.part_received(shard_id, part_id); + assert!(!cache.has_selector(shard_id, part_id)); } } } @@ -342,11 +355,15 @@ async fn test_select_peer() { let part_id = 0; let num_peers = SELECT_PEER_CASES.iter().map(|t| t.num_peers).max().unwrap(); let mut peers = Vec::with_capacity(num_peers); + + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + for _ in 0..num_peers { let key = data::make_secret_key(&mut rng); let peer_id = PeerId::new(key.public_key()); - let score = priority_score(&peer_id, 0u64, part_id); - let info = Arc::new(SnapshotHostInfo::new(peer_id, sync_hash, 123, vec![0, 1, 2, 3], &key)); + let score = priority_score(&peer_id, new_shard_id_tmp(0), part_id); + let info = + Arc::new(SnapshotHostInfo::new(peer_id, sync_hash, 123, sid_vec(&[0, 1, 2, 3]), &key)); peers.push((info, score)); } peers.sort_by(|(_linfo, lscore), (_rinfo, rscore)| { diff --git a/core/primitives-core/src/types.rs b/core/primitives-core/src/types.rs index 7cc5e279fff..04267a9c491 100644 --- a/core/primitives-core/src/types.rs +++ b/core/primitives-core/src/types.rs @@ -18,8 +18,6 @@ pub type Nonce = u64; pub type BlockHeight = u64; /// Height of the epoch. pub type EpochHeight = u64; -/// Shard index, from 0 to NUM_SHARDS - 1. -pub type ShardId = u64; /// Balance is type for storing amounts of tokens. pub type Balance = u128; /// Gas is a type for storing amount of gas. @@ -45,3 +43,159 @@ pub type ReceiptIndex = usize; pub type PromiseId = Vec; pub type ProtocolVersion = u32; + +/// The shard identifier. The ShardId is currently being migrated to a newtype - +/// please see the new ShardId definition below. +pub type ShardId = u64; + +/// The ShardIndex is the index of the shard in an array of shard data. +/// Historically the ShardId was always in the range 0..NUM_SHARDS and was used +/// as the shard index. This is no longer the case, and the ShardIndex should be +/// used instead. +pub type ShardIndex = usize; + +// TODO(wacban) This is a temporary solution to aid the transition to having +// ShardId as a newtype. It should be replaced / removed / inlined once the +// transition is complete. +pub const fn new_shard_id_tmp(id: u64) -> ShardId { + id +} + +// TODO(wacban) This is a temporary solution to aid the transition to having +// ShardId as a newtype. It should be replaced / removed / inlined once the +// transition is complete. +pub fn new_shard_id_vec_tmp(vec: &[u64]) -> Vec { + vec.iter().copied().map(new_shard_id_tmp).collect() +} + +// TODO(wacban) This is a temporary solution to aid the transition to having +// ShardId as a newtype. It should be replaced / removed / inlined once the +// transition is complete. +pub const fn shard_id_as_u32(id: ShardId) -> u32 { + id as u32 +} + +// TODO(wacban) Complete the transition to ShardId as a newtype. +// /// The shard identifier. It may be a arbitrary number - it does not need to be +// /// a number in the range 0..NUM_SHARDS. The shard ids do not need to be +// /// sequential or contiguous. +// /// +// /// The shard id is wrapped in a newtype to prevent the old pattern of using +// /// indices in range 0..NUM_SHARDS and casting to ShardId. Once the transition +// /// if fully complete it potentially may be simplified to a regular type alias. +// #[derive( +// arbitrary::Arbitrary, +// borsh::BorshSerialize, +// borsh::BorshDeserialize, +// serde::Serialize, +// serde::Deserialize, +// Hash, +// Clone, +// Copy, +// Debug, +// PartialEq, +// Eq, +// PartialOrd, +// Ord, +// )] +// pub struct ShardId(u64); + +// impl ShardId { +// /// Create a new shard id. Please note that this function should not be used +// /// to convert a shard index (a number in 0..num_shards range) to ShardId. +// /// Instead the ShardId should be obtained from the shard_layout. +// /// +// /// ``` +// /// // BAD USAGE: +// /// for shard_index in 1..num_shards { +// /// let shard_id = ShardId::new(shard_index); // Incorrect!!! +// /// } +// /// ``` +// /// ``` +// /// // GOOD USAGE 1: +// /// for shard_index in 1..num_shards { +// /// let shard_id = shard_layout.get_shard_id(shard_index); +// /// } +// /// // GOOD USAGE 2: +// /// for shard_id in shard_layout.shard_ids() { +// /// let shard_id = shard_layout.get_shard_id(shard_index); +// /// } +// /// ``` +// pub const fn new(id: u64) -> Self { +// Self(id) +// } + +// /// Get the numerical value of the shard id. This should not be used as an +// /// index into an array, as the shard id may be any arbitrary number. +// pub fn get(self) -> u64 { +// self.0 +// } + +// pub fn to_le_bytes(self) -> [u8; 8] { +// self.0.to_le_bytes() +// } + +// pub fn from_le_bytes(bytes: [u8; 8]) -> Self { +// Self(u64::from_le_bytes(bytes)) +// } + +// // TODO This is not great, in ShardUId shard_id is u32. +// // Currently used for some metrics so kinda ok. +// pub fn max() -> Self { +// Self(u64::MAX) +// } +// } + +// impl Display for ShardId { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// write!(f, "{}", self.0) +// } +// } + +// impl From for ShardId { +// fn from(id: u64) -> Self { +// Self(id) +// } +// } + +// impl Into for ShardId { +// fn into(self) -> u64 { +// self.0 +// } +// } + +// impl From for ShardId { +// fn from(id: u32) -> Self { +// Self(id as u64) +// } +// } + +// impl Into for ShardId { +// fn into(self) -> u32 { +// self.0 as u32 +// } +// } + +// impl From for ShardId { +// fn from(id: i32) -> Self { +// Self(id as u64) +// } +// } + +// impl From for ShardId { +// fn from(id: usize) -> Self { +// Self(id as u64) +// } +// } + +// impl From for ShardId { +// fn from(id: u16) -> Self { +// Self(id as u64) +// } +// } + +// impl Into for ShardId { +// fn into(self) -> u16 { +// self.0 as u16 +// } +// } diff --git a/core/primitives/src/block.rs b/core/primitives/src/block.rs index 826801ab67e..8816c28c155 100644 --- a/core/primitives/src/block.rs +++ b/core/primitives/src/block.rs @@ -14,6 +14,7 @@ use crate::sharding::{ChunkHashHeight, ShardChunkHeader, ShardChunkHeaderV1}; use crate::types::{Balance, BlockHeight, EpochId, Gas}; use crate::version::{ProtocolVersion, SHARD_CHUNK_HEADER_UPGRADE_VERSION}; use borsh::{BorshDeserialize, BorshSerialize}; +use near_primitives_core::types::{ShardId, ShardIndex}; use near_schema_checker_lib::ProtocolSchema; use near_time::Utc; use primitive_types::U256; @@ -88,6 +89,7 @@ pub enum Block { #[cfg(feature = "solomon")] type ShardChunkReedSolomon = reed_solomon_erasure::galois_8::ReedSolomon; +/// The shard_ids, state_roots and congestion_infos must be in the same order. #[cfg(feature = "solomon")] pub fn genesis_chunks( state_roots: Vec, @@ -110,10 +112,9 @@ pub fn genesis_chunks( let num = shard_ids.len(); assert_eq!(state_roots.len(), num); - for shard_id in 0..num { - let state_root = state_roots[shard_id]; - let congestion_info = congestion_infos[shard_id]; - let shard_id = shard_id as crate::types::ShardId; + for (shard_index, &shard_id) in shard_ids.iter().enumerate() { + let state_root = state_roots[shard_index]; + let congestion_info = congestion_infos[shard_index]; let encoded_chunk = genesis_chunk( &rs, @@ -140,7 +141,7 @@ fn genesis_chunk( genesis_protocol_version: u32, genesis_height: u64, initial_gas_limit: u64, - shard_id: u64, + shard_id: ShardId, state_root: CryptoHash, congestion_info: Option, ) -> crate::sharding::EncodedShardChunk { @@ -781,7 +782,7 @@ impl<'a> ExactSizeIterator for VersionedChunksIter<'a> { } } -impl<'a> Index for ChunksCollection<'a> { +impl<'a> Index for ChunksCollection<'a> { type Output = ShardChunkHeader; /// Deprecated. Please use get instead, it's safer. @@ -808,7 +809,7 @@ impl<'a> ChunksCollection<'a> { } } - pub fn get(&self, index: usize) -> Option<&ShardChunkHeader> { + pub fn get(&self, index: ShardIndex) -> Option<&ShardChunkHeader> { match self { ChunksCollection::V1(chunks) => chunks.get(index), ChunksCollection::V2(chunks) => chunks.get(index), diff --git a/core/primitives/src/congestion_info.rs b/core/primitives/src/congestion_info.rs index 319774fa6f5..5488d5059b4 100644 --- a/core/primitives/src/congestion_info.rs +++ b/core/primitives/src/congestion_info.rs @@ -499,7 +499,9 @@ impl ShardAcceptsTransactions { #[cfg(test)] mod tests { + use itertools::Itertools; use near_parameters::RuntimeConfigStore; + use near_primitives_core::types::new_shard_id_tmp; use near_primitives_core::version::{ProtocolFeature, PROTOCOL_VERSION}; use super::*; @@ -576,7 +578,12 @@ mod tests { assert_eq!(0.0, congestion_control.outgoing_congestion()); assert_eq!(0.0, congestion_control.congestion_level()); - assert!(config.max_outgoing_gas.abs_diff(congestion_control.outgoing_gas_limit(0)) <= 1); + assert!( + config + .max_outgoing_gas + .abs_diff(congestion_control.outgoing_gas_limit(new_shard_id_tmp(0))) + <= 1 + ); assert!(config.max_tx_gas.abs_diff(congestion_control.process_tx_limit()) <= 1); assert!(congestion_control.shard_accepts_transactions().is_yes()); @@ -599,7 +606,7 @@ mod tests { let control = CongestionControl::new(config, info, 0); assert_eq!(1.0, control.congestion_level()); // fully congested, no more forwarding allowed - assert_eq!(0, control.outgoing_gas_limit(1)); + assert_eq!(0, control.outgoing_gas_limit(new_shard_id_tmp(1))); assert!(control.shard_accepts_transactions().is_no()); // processing to other shards is not restricted by memory congestion assert_eq!(config.max_tx_gas, control.process_tx_limit()); @@ -613,7 +620,7 @@ mod tests { assert_eq!( (0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 50%, still no new transactions are allowed assert!(control.shard_accepts_transactions().is_no()); @@ -627,7 +634,7 @@ mod tests { assert_eq!( (0.125 * config.min_outgoing_gas as f64 + 0.875 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 12.5%, new transactions are allowed (threshold is 0.25) assert!(control.shard_accepts_transactions().is_yes()); @@ -651,7 +658,7 @@ mod tests { let control = CongestionControl::new(config, info, 0); assert_eq!(1.0, control.congestion_level()); // fully congested, no more forwarding allowed - assert_eq!(0, control.outgoing_gas_limit(1)); + assert_eq!(0, control.outgoing_gas_limit(new_shard_id_tmp(1))); assert!(control.shard_accepts_transactions().is_no()); // processing to other shards is restricted by own incoming congestion assert_eq!(config.min_tx_gas, control.process_tx_limit()); @@ -665,7 +672,7 @@ mod tests { assert_eq!( (0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 50%, still no new transactions to us are allowed assert!(control.shard_accepts_transactions().is_no()); @@ -684,7 +691,7 @@ mod tests { assert_eq!( (0.125 * config.min_outgoing_gas as f64 + 0.875 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 12.5%, new transactions are allowed (threshold is 0.25) assert!(control.shard_accepts_transactions().is_yes()); @@ -711,7 +718,7 @@ mod tests { let control = CongestionControl::new(config, info, 0); assert_eq!(1.0, control.congestion_level()); // fully congested, no more forwarding allowed - assert_eq!(0, control.outgoing_gas_limit(1)); + assert_eq!(0, control.outgoing_gas_limit(new_shard_id_tmp(1))); assert!(control.shard_accepts_transactions().is_no()); // processing to other shards is not restricted by own outgoing congestion assert_eq!(config.max_tx_gas, control.process_tx_limit()); @@ -722,7 +729,7 @@ mod tests { assert_eq!(0.5, control.congestion_level()); assert_eq!( (0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 50%, still no new transactions to us are allowed assert!(control.shard_accepts_transactions().is_no()); @@ -734,7 +741,7 @@ mod tests { assert_eq!( (0.125 * config.min_outgoing_gas as f64 + 0.875 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 12.5%, new transactions are allowed (threshold is 0.25) assert!(control.shard_accepts_transactions().is_yes()); @@ -802,8 +809,8 @@ mod tests { let mut info = CongestionInfo::default(); info.add_buffered_receipt_gas(config.max_congestion_outgoing_gas / 2).unwrap(); - let shard = 2; - let all_shards = [0, 1, 2, 3, 4]; + let shard = new_shard_id_tmp(2); + let all_shards = [0, 1, 2, 3, 4].into_iter().map(new_shard_id_tmp).collect_vec(); // Test without missed chunks congestion. @@ -813,7 +820,7 @@ mod tests { let expected_outgoing_limit = 0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64; - for shard in all_shards { + for &shard in &all_shards { assert_eq!(control.outgoing_gas_limit(shard), expected_outgoing_limit as u64); } @@ -825,7 +832,7 @@ mod tests { let expected_outgoing_limit = mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.8) as f64; - for shard in all_shards { + for &shard in &all_shards { assert_eq!(control.outgoing_gas_limit(shard), expected_outgoing_limit as u64); } diff --git a/core/primitives/src/epoch_info.rs b/core/primitives/src/epoch_info.rs index bcc14c12c49..47e6277d247 100644 --- a/core/primitives/src/epoch_info.rs +++ b/core/primitives/src/epoch_info.rs @@ -3,6 +3,7 @@ use smart_default::SmartDefault; use std::collections::{BTreeMap, HashMap}; use crate::rand::WeightedIndex; +use crate::shard_layout::ShardLayout; use crate::types::validator_stake::{ValidatorStake, ValidatorStakeIter}; use crate::types::{AccountId, ValidatorKickoutReason, ValidatorStakeV1}; use crate::validator_mandates::ValidatorMandates; @@ -601,35 +602,35 @@ impl EpochInfo { pub fn sample_chunk_producer( &self, - height: BlockHeight, + shard_layout: &ShardLayout, shard_id: ShardId, + height: BlockHeight, ) -> Option { + let shard_index = shard_layout.get_shard_index(shard_id); match &self { Self::V1(v1) => { let cp_settlement = &v1.chunk_producers_settlement; - let shard_cps = cp_settlement.get(shard_id as usize)?; + let shard_cps = cp_settlement.get(shard_index)?; shard_cps.get((height as u64 % (shard_cps.len() as u64)) as usize).copied() } Self::V2(v2) => { let cp_settlement = &v2.chunk_producers_settlement; - let shard_cps = cp_settlement.get(shard_id as usize)?; + let shard_cps = cp_settlement.get(shard_index)?; shard_cps.get((height as u64 % (shard_cps.len() as u64)) as usize).copied() } Self::V3(v3) => { let protocol_version = self.protocol_version(); let seed = Self::chunk_produce_seed(protocol_version, &v3.rng_seed, height, shard_id); - let shard_id = shard_id as usize; - let sample = v3.chunk_producers_sampler.get(shard_id)?.sample(seed); - v3.chunk_producers_settlement.get(shard_id)?.get(sample).copied() + let sample = v3.chunk_producers_sampler.get(shard_index)?.sample(seed); + v3.chunk_producers_settlement.get(shard_index)?.get(sample).copied() } Self::V4(v4) => { let protocol_version = self.protocol_version(); let seed = Self::chunk_produce_seed(protocol_version, &v4.rng_seed, height, shard_id); - let shard_id = shard_id as usize; - let sample = v4.chunk_producers_sampler.get(shard_id)?.sample(seed); - v4.chunk_producers_settlement.get(shard_id)?.get(sample).copied() + let sample = v4.chunk_producers_sampler.get(shard_index)?.sample(seed); + v4.chunk_producers_settlement.get(shard_index)?.get(sample).copied() } } } diff --git a/core/primitives/src/shard_layout.rs b/core/primitives/src/shard_layout.rs index 1a15b88161a..717b462c8db 100644 --- a/core/primitives/src/shard_layout.rs +++ b/core/primitives/src/shard_layout.rs @@ -2,9 +2,9 @@ use crate::hash::CryptoHash; use crate::types::{AccountId, NumShards}; use borsh::{BorshDeserialize, BorshSerialize}; use itertools::Itertools; -use near_primitives_core::types::ShardId; +use near_primitives_core::types::{ShardId, ShardIndex}; use near_schema_checker_lib::ProtocolSchema; -use std::collections::{BTreeMap, HashMap}; +use std::collections::BTreeMap; use std::{fmt, str}; /// This file implements two data structure `ShardLayout` and `ShardUId` @@ -88,6 +88,18 @@ type ShardsSplitMapV2 = BTreeMap>; /// A mapping from the child shard to the parent shard. type ShardsParentMapV2 = BTreeMap; +fn new_shard_ids_vec(shard_ids: Vec) -> Vec { + shard_ids.into_iter().map(Into::into).collect() +} + +fn new_shards_split_map(shards_split_map: Vec>) -> ShardsSplitMap { + shards_split_map.into_iter().map(new_shard_ids_vec).collect() +} + +fn new_shards_split_map_v2(shards_split_map: BTreeMap>) -> ShardsSplitMapV2 { + shards_split_map.into_iter().map(|(k, v)| (k.into(), new_shard_ids_vec(v))).collect() +} + #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)] pub struct ShardLayoutV1 { /// The boundary accounts are the accounts on boundaries between shards. @@ -110,14 +122,14 @@ impl ShardLayoutV1 { // In this shard layout the accounts are divided into ranges, each range is // mapped to a shard. The shards are contiguous and start from 0. fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId { - let mut shard_id: ShardId = 0; + let mut shard_id: u64 = 0; for boundary_account in &self.boundary_accounts { if account_id < boundary_account { break; } shard_id += 1; } - shard_id + shard_id.into() } } @@ -138,9 +150,15 @@ pub struct ShardLayoutV2 { /// /// The shard id at index i corresponds to the shard with account range: /// [boundary_accounts[i -1], boundary_accounts[i]). - /// shard_ids: Vec, + /// The mapping from shard id to shard index. + id_to_index_map: BTreeMap, + + /// The mapping from shard index to shard id. + /// TODO(wacban) this is identical to the shard_ids, remove it. + index_to_id_map: BTreeMap, + /// A mapping from the parent shard to child shards. Maps shards from the /// previous shard layout to shards that they split to in this shard layout. shards_split_map: Option, @@ -192,16 +210,17 @@ impl ShardLayout { version: ShardVersion, ) -> Self { let to_parent_shard_map = if let Some(shards_split_map) = &shards_split_map { - let mut to_parent_shard_map = HashMap::new(); + let mut to_parent_shard_map = BTreeMap::new(); let num_shards = (boundary_accounts.len() + 1) as NumShards; for (parent_shard_id, shard_ids) in shards_split_map.iter().enumerate() { + let parent_shard_id = parent_shard_id as u64; for &shard_id in shard_ids { - let prev = to_parent_shard_map.insert(shard_id, parent_shard_id as ShardId); + let prev = to_parent_shard_map.insert(shard_id, parent_shard_id); assert!(prev.is_none(), "no shard should appear in the map twice"); assert!(shard_id < num_shards, "shard id should be valid"); } } - Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id]).collect()) + Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id.into()]).collect()) } else { None }; @@ -225,10 +244,19 @@ impl ShardLayout { assert_eq!(boundary_accounts.len() + 1, shard_ids.len()); assert_eq!(boundary_accounts, boundary_accounts.iter().sorted().cloned().collect_vec()); + let mut id_to_index_map = BTreeMap::new(); + let mut index_to_id_map = BTreeMap::new(); + for (shard_index, &shard_id) in shard_ids.iter().enumerate() { + id_to_index_map.insert(shard_id, shard_index); + index_to_id_map.insert(shard_index, shard_id); + } + let Some(shards_split_map) = shards_split_map else { return Self::V2(ShardLayoutV2 { boundary_accounts, shard_ids, + id_to_index_map, + index_to_id_map, shards_split_map: None, shards_parent_map: None, version: VERSION, @@ -253,6 +281,8 @@ impl ShardLayout { Self::V2(ShardLayoutV2 { boundary_accounts, shard_ids, + id_to_index_map, + index_to_id_map, shards_split_map, shards_parent_map, version: VERSION, @@ -263,7 +293,7 @@ impl ShardLayout { pub fn v1_test() -> Self { ShardLayout::v1( vec!["abc", "foo", "test0"].into_iter().map(|s| s.parse().unwrap()).collect(), - Some(vec![vec![0, 1, 2, 3]]), + Some(new_shards_split_map(vec![vec![0, 1, 2, 3]])), 1, ) } @@ -275,7 +305,7 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0, 1, 2, 3]]), + Some(new_shards_split_map(vec![vec![0, 1, 2, 3]])), 1, ) } @@ -287,7 +317,7 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0], vec![1], vec![2], vec![3, 4]]), + Some(new_shards_split_map(vec![vec![0], vec![1], vec![2], vec![3, 4]])), 2, ) } @@ -305,7 +335,7 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0], vec![1], vec![2, 3], vec![4], vec![5]]), + Some(new_shards_split_map(vec![vec![0], vec![1], vec![2, 3], vec![4], vec![5]])), 3, ) } @@ -331,6 +361,7 @@ impl ShardLayout { ]; let shard_ids = vec![0, 1, 6, 7, 3, 4, 5]; + let shard_ids = new_shard_ids_vec(shard_ids); let shards_split_map = BTreeMap::from([ (0, vec![0]), @@ -340,6 +371,7 @@ impl ShardLayout { (4, vec![4]), (5, vec![5]), ]); + let shards_split_map = new_shards_split_map_v2(shards_split_map); let shards_split_map = Some(shards_split_map); ShardLayout::v2(boundary_accounts, shard_ids, shards_split_map) @@ -361,7 +393,14 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0], vec![1], vec![2], vec![3], vec![4, 5], vec![6]]), + Some(new_shards_split_map(vec![ + vec![0], + vec![1], + vec![2], + vec![3], + vec![4, 5], + vec![6], + ])), 4, ) } @@ -380,7 +419,10 @@ impl ShardLayout { match self { Self::V0(_) => None, Self::V1(v1) => match &v1.shards_split_map { - Some(shards_split_map) => shards_split_map.get(parent_shard_id as usize).cloned(), + Some(shards_split_map) => { + let parent_shard_index = parent_shard_id as usize; + shards_split_map.get(parent_shard_index).cloned() + } None => None, }, Self::V2(v2) => match &v2.shards_split_map { @@ -403,7 +445,10 @@ impl ShardLayout { Self::V1(v1) => match &v1.to_parent_shard_map { // we can safely unwrap here because the construction of to_parent_shard_map guarantees // that every shard has a parent shard - Some(to_parent_shard_map) => *to_parent_shard_map.get(shard_id as usize).unwrap(), + Some(to_parent_shard_map) => { + let shard_index = self.get_shard_index(shard_id); + *to_parent_shard_map.get(shard_index).unwrap() + } None => panic!("shard_layout has no parent shard"), }, Self::V2(v2) => match &v2.shards_parent_map { @@ -441,8 +486,8 @@ impl ShardLayout { pub fn shard_ids(&self) -> impl Iterator + '_ { match self { - Self::V0(_) => (0..self.num_shards()).collect_vec().into_iter(), - Self::V1(_) => (0..self.num_shards()).collect_vec().into_iter(), + Self::V0(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(), + Self::V1(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(), Self::V2(v2) => v2.shard_ids.clone().into_iter(), } } @@ -452,6 +497,26 @@ impl ShardLayout { pub fn shard_uids(&self) -> impl Iterator + '_ { self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self)) } + + /// Returns the shard index for a given shard id. The shard index should be + /// used when indexing into an array of chunk data. + pub fn get_shard_index(&self, shard_id: ShardId) -> ShardIndex { + match self { + Self::V0(_) => shard_id as ShardIndex, + Self::V1(_) => shard_id as ShardIndex, + Self::V2(v2) => v2.id_to_index_map[&shard_id], + } + } + + /// Get the shard id for a given shard index. The shard id should be used to + /// identify the shard and starting from the ShardLayoutV2 it is unique. + pub fn get_shard_id(&self, shard_index: usize) -> ShardId { + match self { + Self::V0(_) => shard_index as ShardId, + Self::V1(_) => shard_index as ShardId, + Self::V2(v2) => v2.index_to_id_map[&shard_index], + } + } } /// Maps an account to the shard that it belongs to given a shard_layout @@ -464,7 +529,8 @@ pub fn account_id_to_shard_id(account_id: &AccountId, shard_layout: &ShardLayout ShardLayout::V0(ShardLayoutV0 { num_shards, .. }) => { let hash = CryptoHash::hash_bytes(account_id.as_bytes()); let (bytes, _) = stdx::split_array::<32, 8, 24>(hash.as_bytes()); - u64::from_le_bytes(*bytes) % num_shards + let shard_id = u64::from_le_bytes(*bytes) % num_shards; + shard_id.into() } ShardLayout::V1(v1) => v1.account_id_to_shard_id(account_id), ShardLayout::V2(v2) => v2.account_id_to_shard_id(account_id), @@ -536,7 +602,7 @@ impl ShardUId { /// Returns shard id pub fn shard_id(&self) -> ShardId { - ShardId::from(self.shard_id) + self.shard_id.into() } } @@ -679,9 +745,12 @@ impl<'de> serde::de::Visitor<'de> for ShardUIdVisitor { #[cfg(test)] mod tests { use crate::epoch_manager::{AllEpochConfig, EpochConfig, ValidatorSelectionConfig}; - use crate::shard_layout::{account_id_to_shard_id, ShardLayout, ShardLayoutV1, ShardUId}; + use crate::shard_layout::{ + account_id_to_shard_id, new_shard_ids_vec, new_shards_split_map, ShardLayout, + ShardLayoutV1, ShardUId, + }; use itertools::Itertools; - use near_primitives_core::types::ProtocolVersion; + use near_primitives_core::types::{new_shard_id_tmp, ProtocolVersion}; use near_primitives_core::types::{AccountId, ShardId}; use near_primitives_core::version::{ProtocolFeature, PROTOCOL_VERSION}; use rand::distributions::Alphanumeric; @@ -689,7 +758,7 @@ mod tests { use rand::{Rng, SeedableRng}; use std::collections::{BTreeMap, HashMap}; - use super::{ShardVersion, ShardsSplitMap}; + use super::{new_shards_split_map_v2, ShardVersion, ShardsSplitMap}; // The old ShardLayoutV1, before fixed shards were removed. tests only #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)] @@ -750,8 +819,8 @@ mod tests { fn test_shard_layout_v0() { let num_shards = 4; let shard_layout = ShardLayout::v0(num_shards, 0); - let mut shard_id_distribution: HashMap<_, _> = - shard_layout.shard_ids().map(|shard_id| (shard_id, 0)).collect(); + let mut shard_id_distribution: HashMap = + shard_layout.shard_ids().map(|shard_id| (shard_id.into(), 0)).collect(); let mut rng = StdRng::from_seed([0; 32]); for _i in 0..1000 { let s: Vec = (&mut rng).sample_iter(&Alphanumeric).take(10).collect(); @@ -759,7 +828,7 @@ mod tests { let account_id = s.to_lowercase().parse().unwrap(); let shard_id = account_id_to_shard_id(&account_id, &shard_layout); assert!(shard_id < num_shards); - *shard_id_distribution.get_mut(&shard_id).unwrap() += 1; + *shard_id_distribution.get_mut(&shard_id.into()).unwrap() += 1; } let expected_distribution: HashMap<_, _> = [(0, 247), (1, 268), (2, 233), (3, 252)].into_iter().collect(); @@ -770,38 +839,39 @@ mod tests { fn test_shard_layout_v1() { let shard_layout = ShardLayout::v1( parse_account_ids(&["aurora", "bar", "foo", "foo.baz", "paz"]), - Some(vec![vec![0, 1, 2], vec![3, 4, 5]]), + Some(new_shards_split_map(vec![vec![0, 1, 2], vec![3, 4, 5]])), 1, ); assert_eq!( - shard_layout.get_children_shards_uids(0).unwrap(), + shard_layout.get_children_shards_uids(new_shard_id_tmp(0)).unwrap(), (0..3).map(|x| ShardUId { version: 1, shard_id: x }).collect::>() ); assert_eq!( - shard_layout.get_children_shards_uids(1).unwrap(), + shard_layout.get_children_shards_uids(new_shard_id_tmp(1)).unwrap(), (3..6).map(|x| ShardUId { version: 1, shard_id: x }).collect::>() ); for x in 0..3 { - assert_eq!(shard_layout.get_parent_shard_id(x).unwrap(), 0); - assert_eq!(shard_layout.get_parent_shard_id(x + 3).unwrap(), 1); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(x)).unwrap(), 0); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(x + 3)).unwrap(), 1); } - assert_eq!(account_id_to_shard_id(&"aurora".parse().unwrap(), &shard_layout), 1); - assert_eq!(account_id_to_shard_id(&"foo.aurora".parse().unwrap(), &shard_layout), 3); - assert_eq!(account_id_to_shard_id(&"bar.foo.aurora".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"bar".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"bar.bar".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"foo".parse().unwrap(), &shard_layout), 3); - assert_eq!(account_id_to_shard_id(&"baz.foo".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"foo.baz".parse().unwrap(), &shard_layout), 4); - assert_eq!(account_id_to_shard_id(&"a.foo.baz".parse().unwrap(), &shard_layout), 0); - - assert_eq!(account_id_to_shard_id(&"aaa".parse().unwrap(), &shard_layout), 0); - assert_eq!(account_id_to_shard_id(&"abc".parse().unwrap(), &shard_layout), 0); - assert_eq!(account_id_to_shard_id(&"bbb".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"foo.goo".parse().unwrap(), &shard_layout), 4); - assert_eq!(account_id_to_shard_id(&"goo".parse().unwrap(), &shard_layout), 4); - assert_eq!(account_id_to_shard_id(&"zoo".parse().unwrap(), &shard_layout), 5); + let aid = |s: &str| s.parse().unwrap(); + assert_eq!(account_id_to_shard_id(&aid("aurora"), &shard_layout), 1); + assert_eq!(account_id_to_shard_id(&aid("foo.aurora"), &shard_layout), 3); + assert_eq!(account_id_to_shard_id(&aid("bar.foo.aurora"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("bar"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("bar.bar"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("foo"), &shard_layout), 3); + assert_eq!(account_id_to_shard_id(&aid("baz.foo"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("foo.baz"), &shard_layout), 4); + assert_eq!(account_id_to_shard_id(&aid("a.foo.baz"), &shard_layout), 0); + + assert_eq!(account_id_to_shard_id(&aid("aaa"), &shard_layout), 0); + assert_eq!(account_id_to_shard_id(&aid("abc"), &shard_layout), 0); + assert_eq!(account_id_to_shard_id(&aid("bbb"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("foo.goo"), &shard_layout), 4); + assert_eq!(account_id_to_shard_id(&aid("goo"), &shard_layout), 4); + assert_eq!(account_id_to_shard_id(&aid("zoo"), &shard_layout), 5); } // check that after removing the fixed shards from the shard layout v1 @@ -812,8 +882,8 @@ mod tests { let old = OldShardLayoutV1 { fixed_shards: vec![], boundary_accounts: parse_account_ids(&["aaa", "bbb"]), - shards_split_map: Some(vec![vec![0, 1, 2]]), - to_parent_shard_map: Some(vec![0, 0, 0]), + shards_split_map: Some(new_shards_split_map(vec![vec![0, 1, 2]])), + to_parent_shard_map: Some(new_shard_ids_vec(vec![0, 0, 0])), version: 1, }; let json = serde_json::to_string_pretty(&old).unwrap(); @@ -847,7 +917,7 @@ mod tests { assert_eq!(account_id_to_shard_id(&"ppp".parse().unwrap(), &shard_layout), 7); // check shard ids - assert_eq!(shard_layout.shard_ids().collect_vec(), vec![3, 8, 4, 7]); + assert_eq!(shard_layout.shard_ids().collect_vec(), new_shard_ids_vec(vec![3, 8, 4, 7])); // check shard uids let version = 3; @@ -855,15 +925,24 @@ mod tests { assert_eq!(shard_layout.shard_uids().collect_vec(), vec![u(3), u(8), u(4), u(7)]); // check parent - assert_eq!(shard_layout.get_parent_shard_id(3).unwrap(), 3); - assert_eq!(shard_layout.get_parent_shard_id(8).unwrap(), 1); - assert_eq!(shard_layout.get_parent_shard_id(4).unwrap(), 4); - assert_eq!(shard_layout.get_parent_shard_id(7).unwrap(), 1); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(3)).unwrap(), 3); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(8)).unwrap(), 1); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(4)).unwrap(), 4); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(7)).unwrap(), 1); // check child - assert_eq!(shard_layout.get_children_shards_ids(1).unwrap(), vec![7, 8]); - assert_eq!(shard_layout.get_children_shards_ids(3).unwrap(), vec![3]); - assert_eq!(shard_layout.get_children_shards_ids(4).unwrap(), vec![4]); + assert_eq!( + shard_layout.get_children_shards_ids(new_shard_id_tmp(1)).unwrap(), + new_shard_ids_vec(vec![7, 8]) + ); + assert_eq!( + shard_layout.get_children_shards_ids(new_shard_id_tmp(3)).unwrap(), + new_shard_ids_vec(vec![3]) + ); + assert_eq!( + shard_layout.get_children_shards_ids(new_shard_id_tmp(4)).unwrap(), + new_shard_ids_vec(vec![4]) + ); } fn get_test_shard_layout_v2() -> ShardLayout { @@ -873,10 +952,12 @@ mod tests { let boundary_accounts = vec![b0, b1, b2]; let shard_ids = vec![3, 8, 4, 7]; + let shard_ids = new_shard_ids_vec(shard_ids); // the mapping from parent to the child // shard 1 is split into shards 7 & 8 while other shards stay the same let shards_split_map = BTreeMap::from([(1, vec![7, 8]), (3, vec![3]), (4, vec![4])]); + let shards_split_map = new_shards_split_map_v2(shards_split_map); let shards_split_map = Some(shards_split_map); ShardLayout::v2(boundary_accounts, shard_ids, shards_split_map) @@ -1020,6 +1101,24 @@ mod tests { 4, 5 ], + "id_to_index_map": { + "0": 0, + "1": 1, + "3": 4, + "4": 5, + "5": 6, + "6": 2, + "7": 3 + }, + "index_to_id_map": { + "0": 0, + "1": 1, + "2": 6, + "3": 7, + "4": 3, + "5": 4, + "6": 5 + }, "shards_split_map": { "0": [ 0 diff --git a/core/primitives/src/sharding.rs b/core/primitives/src/sharding.rs index a9a1c7fa6e9..4e5e26cd867 100644 --- a/core/primitives/src/sharding.rs +++ b/core/primitives/src/sharding.rs @@ -1213,7 +1213,7 @@ impl EncodedShardChunk { "decode_chunk", data_parts, height_included = self.cloned_header().height_included(), - shard_id = self.cloned_header().shard_id(), + shard_id = ?self.cloned_header().shard_id(), chunk_hash = ?self.chunk_hash()) .entered(); diff --git a/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs b/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs index 26874117ecd..ae28fa9efb5 100644 --- a/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs +++ b/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs @@ -1,6 +1,5 @@ use bitvec::prelude::*; use borsh::{BorshDeserialize, BorshSerialize}; -use near_primitives_core::types::ShardId; use near_schema_checker_lib::ProtocolSchema; /// Represents a collection of bitmaps, one per shard, to store whether the endorsements from the chunk validators has been received. @@ -56,21 +55,21 @@ impl ChunkEndorsementsBitmap { // Creates an endorsement bitmap for all the shards. pub fn from_endorsements(shards_to_endorsements: Vec>) -> Self { let mut bitmap = ChunkEndorsementsBitmap::new(shards_to_endorsements.len()); - for (shard_id, endorsements) in shards_to_endorsements.into_iter().enumerate() { - bitmap.add_endorsements(shard_id as ShardId, endorsements); + for (shard_index, endorsements) in shards_to_endorsements.into_iter().enumerate() { + bitmap.add_endorsements(shard_index, endorsements); } bitmap } /// Adds the provided endorsements to the bitmap for the specified shard. - pub fn add_endorsements(&mut self, shard_id: ShardId, endorsements: Vec) { + pub fn add_endorsements(&mut self, shard_index: usize, endorsements: Vec) { let bitvec: BitVecType = endorsements.iter().collect(); - self.inner[shard_id as usize] = bitvec.into(); + self.inner[shard_index] = bitvec.into(); } /// Returns an iterator over the endorsements (yields true if the endorsement for the respective position was received). - pub fn iter(&self, shard_id: ShardId) -> Box> { - let bitvec = BitVecType::from_vec(self.inner[shard_id as usize].clone()); + pub fn iter(&self, shard_index: usize) -> Box> { + let bitvec = BitVecType::from_vec(self.inner[shard_index].clone()); Box::new(bitvec.into_iter()) } @@ -81,8 +80,8 @@ impl ChunkEndorsementsBitmap { /// Returns the full length of the bitmap for a given shard. /// Note that the size may be greater than the number of validator assignments. - pub fn len(&self, shard_id: ShardId) -> Option { - self.inner.get(shard_id as usize).map(|v| v.len() * 8) + pub fn len(&self, shard_index: usize) -> Option { + self.inner.get(shard_index).map(|v| v.len() * 8) } } @@ -90,7 +89,6 @@ impl ChunkEndorsementsBitmap { mod tests { use super::ChunkEndorsementsBitmap; use itertools::Itertools; - use near_primitives_core::types::ShardId; use rand::Rng; const NUM_SHARDS: usize = 4; @@ -102,9 +100,9 @@ mod tests { expected_endorsements: &Vec>, ) { // Endorsements from the bitmap iterator must match the endorsements given previously. - for (shard_id, endorsements) in expected_endorsements.iter().enumerate() { - let num_bits = bitmap.len(shard_id as ShardId).unwrap(); - let bits = bitmap.iter(shard_id as ShardId).collect_vec(); + for (shard_index, endorsements) in expected_endorsements.iter().enumerate() { + let num_bits = bitmap.len(shard_index).unwrap(); + let bits = bitmap.iter(shard_index).collect_vec(); // Number of bits must be equal to the size of the bit iterator for the corresponding shard. assert_eq!(num_bits, bits.len()); // Bitmap must contain the minimal number of bits to represent the endorsements. @@ -121,13 +119,13 @@ mod tests { let mut rng = rand::thread_rng(); let mut bitmap = ChunkEndorsementsBitmap::new(NUM_SHARDS); let mut expected_endorsements = vec![]; - for shard_id in 0..NUM_SHARDS { + for shard_index in 0..NUM_SHARDS { let mut endorsements = vec![false; num_assignments]; for _ in 0..num_produced { endorsements[rng.gen_range(0..num_assignments)] = true; } expected_endorsements.push(endorsements.clone()); - bitmap.add_endorsements(shard_id as ShardId, endorsements); + bitmap.add_endorsements(shard_index, endorsements); } // Check before serialization. assert_bitmap(&bitmap, num_assignments, &expected_endorsements); diff --git a/core/store/benches/finalize_bench.rs b/core/store/benches/finalize_bench.rs index aedce9232ec..49d7915ada4 100644 --- a/core/store/benches/finalize_bench.rs +++ b/core/store/benches/finalize_bench.rs @@ -30,7 +30,7 @@ use near_primitives::sharding::{ ShardChunkV2, ShardProof, }; use near_primitives::transaction::{Action, FunctionCallAction, SignedTransaction}; -use near_primitives::types::AccountId; +use near_primitives::types::{new_shard_id_tmp, AccountId, ShardId}; use near_primitives::validator_signer::InMemoryValidatorSigner; use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION}; use near_store::DBCol; @@ -115,7 +115,7 @@ fn create_benchmark_receipts() -> Vec { ] } -fn create_chunk_header(height: u64, shard_id: u64) -> ShardChunkHeader { +fn create_chunk_header(height: u64, shard_id: ShardId) -> ShardChunkHeader { let congestion_info = ProtocolFeature::CongestionControl .enabled(PROTOCOL_VERSION) .then_some(CongestionInfo::default()); @@ -177,7 +177,7 @@ fn create_shard_chunk( ) -> ShardChunk { ShardChunk::V2(ShardChunkV2 { chunk_hash: chunk_hash.clone(), - header: create_chunk_header(0, 0), + header: create_chunk_header(0, new_shard_id_tmp(0)), transactions, prev_outgoing_receipts: receipts, }) @@ -198,7 +198,7 @@ fn create_encoded_shard_chunk( Default::default(), Default::default(), Default::default(), - Default::default(), + new_shard_id_tmp(0), Default::default(), Default::default(), Default::default(), @@ -231,8 +231,8 @@ fn encoded_chunk_to_partial_encoded_chunk( let receipt_proofs = proofs .into_iter() .enumerate() - .map(move |(proof_shard_id, proof)| { - let proof_shard_id = proof_shard_id as u64; + .map(move |(proof_shard_index, proof)| { + let proof_shard_id = shard_layout.get_shard_id(proof_shard_index); let receipts = receipts_by_shard.remove(&proof_shard_id).unwrap_or_else(Vec::new); let shard_proof = ShardProof { from_shard_id: shard_id, to_shard_id: proof_shard_id, proof }; diff --git a/core/store/src/flat/storage.rs b/core/store/src/flat/storage.rs index 172928d35eb..79f7001aacf 100644 --- a/core/store/src/flat/storage.rs +++ b/core/store/src/flat/storage.rs @@ -130,7 +130,7 @@ impl FlatStorageInner { if blocks.len() >= Self::HOPS_LIMIT { warn!( target: "chain", - shard_id = self.shard_uid.shard_id(), + shard_id = ?self.shard_uid.shard_id(), flat_head_height = flat_head.height, cached_deltas = self.deltas.len(), num_hops = blocks.len(), @@ -160,7 +160,7 @@ impl FlatStorageInner { if cached_changes_size_bytes >= Self::CACHED_CHANGES_SIZE_LIMIT { warn!( target: "chain", - shard_id = self.shard_uid.shard_id(), + shard_id = ?self.shard_uid.shard_id(), flat_head_height = self.flat_head.height, cached_deltas, %cached_changes_size_bytes, @@ -380,7 +380,7 @@ impl FlatStorage { let shard_uid = guard.shard_uid; let shard_id = shard_uid.shard_id(); - tracing::debug!(target: "store", flat_head = ?guard.flat_head.hash, ?new_head, shard_id, "Moving flat head"); + tracing::debug!(target: "store", flat_head = ?guard.flat_head.hash, ?new_head, ?shard_id, "Moving flat head"); let blocks = guard.get_blocks_to_head(&new_head)?; for block_hash in blocks.into_iter().rev() { diff --git a/core/store/src/genesis/initialization.rs b/core/store/src/genesis/initialization.rs index bc9b1c9a65c..210106e91dd 100644 --- a/core/store/src/genesis/initialization.rs +++ b/core/store/src/genesis/initialization.rs @@ -2,7 +2,11 @@ //! We first check if store has the genesis hash and state_roots, if not, we go ahead with initialization use rayon::prelude::*; -use std::{collections::HashSet, fs, path::Path}; +use std::{ + collections::{BTreeMap, HashSet}, + fs, + path::Path, +}; use borsh::BorshDeserialize; use near_chain_configs::{Genesis, GenesisContents}; @@ -110,17 +114,18 @@ fn genesis_state_from_genesis( let runtime_config_store = RuntimeConfigStore::for_chain_id(&genesis.config.chain_id); let runtime_config = runtime_config_store.get_config(genesis.config.protocol_version); let storage_usage_config = &runtime_config.fees.storage_usage_config; + let shard_ids: Vec<_> = shard_layout.shard_ids().collect(); let shard_uids: Vec<_> = shard_layout.shard_uids().collect(); - // note that here we are depending on the behavior that shard_layout.shard_uids() returns an iterator - // in order by shard id from 0 to num_shards() - let mut shard_account_ids: Vec> = - shard_uids.iter().map(|_| HashSet::new()).collect(); + + let mut shard_account_ids: BTreeMap> = + shard_ids.iter().map(|&shard_id| (shard_id, HashSet::new())).collect(); let mut has_protocol_account = false; info!(target: "store","distributing records to shards"); genesis.for_each_record(|record: &StateRecord| { - shard_account_ids[state_record_to_shard_id(record, &shard_layout) as usize] - .insert(state_record_to_account_id(record).clone()); + let shard_id = state_record_to_shard_id(record, &shard_layout); + let account_id = state_record_to_account_id(record).clone(); + shard_account_ids.get_mut(&shard_id).unwrap().insert(account_id); if let StateRecord::Account { account_id, .. } = record { if account_id == &genesis.config.protocol_treasury_account { has_protocol_account = true; @@ -165,7 +170,7 @@ fn genesis_state_from_genesis( &validators, storage_usage_config, genesis, - shard_account_ids[shard_id as usize].clone(), + shard_account_ids[&shard_id].clone(), ) }) .collect() diff --git a/core/store/src/trie/prefetching_trie_storage.rs b/core/store/src/trie/prefetching_trie_storage.rs index 31e831e331c..bd886ab32b6 100644 --- a/core/store/src/trie/prefetching_trie_storage.rs +++ b/core/store/src/trie/prefetching_trie_storage.rs @@ -593,12 +593,13 @@ mod tests_utils { mod tests { use super::{PrefetchStagingArea, PrefetcherResult}; use near_primitives::hash::CryptoHash; + use near_primitives::types::new_shard_id_tmp; #[test] fn test_prefetch_staging_area_blocking_get_after_update() { let key = CryptoHash::hash_bytes(&[1, 2, 3]); let value: std::sync::Arc<[u8]> = vec![4, 5, 6].into(); - let prefetch_staging_area = PrefetchStagingArea::new(0); + let prefetch_staging_area = PrefetchStagingArea::new(new_shard_id_tmp(0)); assert!(matches!( prefetch_staging_area.get_or_set_fetching(key), PrefetcherResult::SlotReserved diff --git a/core/store/src/trie/shard_tries.rs b/core/store/src/trie/shard_tries.rs index 92977af9469..d804b49dcb4 100644 --- a/core/store/src/trie/shard_tries.rs +++ b/core/store/src/trie/shard_tries.rs @@ -286,7 +286,7 @@ impl ShardTries { level = "trace", target = "store::trie::shard_tries", "ShardTries::apply_insertions", - fields(num_insertions = trie_changes.insertions().len(), shard_id = shard_uid.shard_id()), + fields(num_insertions = trie_changes.insertions().len(), shard_id = ?shard_uid.shard_id()), skip_all, )] pub fn apply_insertions( @@ -309,7 +309,7 @@ impl ShardTries { level = "trace", target = "store::trie::shard_tries", "ShardTries::apply_deletions", - fields(num_deletions = trie_changes.deletions().len(), shard_id = shard_uid.shard_id()), + fields(num_deletions = trie_changes.deletions().len(), shard_id = ?shard_uid.shard_id()), skip_all, )] pub fn apply_deletions( @@ -553,7 +553,7 @@ impl WrappedTrieChanges { level = "debug", target = "store::trie::shard_tries", "ShardTries::state_changes_into", - fields(num_state_changes = self.state_changes.len(), shard_id = self.shard_uid.shard_id()), + fields(num_state_changes = self.state_changes.len(), shard_id = ?self.shard_uid.shard_id()), skip_all, )] pub fn state_changes_into(&mut self, store_update: &mut TrieStoreUpdateAdapter) { diff --git a/core/store/src/trie/state_parts.rs b/core/store/src/trie/state_parts.rs index 47b54740534..1f20c3c5ea0 100644 --- a/core/store/src/trie/state_parts.rs +++ b/core/store/src/trie/state_parts.rs @@ -130,7 +130,7 @@ impl Trie { ) -> Result<(PartialState, Vec, Vec), StorageError> { let shard_id: ShardId = self.flat_storage_chunk_view.as_ref().map_or( ShardId::MAX, // Fake value for metrics. - |chunk_view| chunk_view.shard_uid().shard_id as ShardId, + |chunk_view| chunk_view.shard_uid().shard_id(), ); let _span = tracing::debug_span!( target: "state-parts", @@ -184,7 +184,7 @@ impl Trie { ) -> Result { let shard_id: ShardId = self.flat_storage_chunk_view.as_ref().map_or( ShardId::MAX, // Fake value for metrics. - |chunk_view| chunk_view.shard_uid().shard_id as ShardId, + |chunk_view| chunk_view.shard_uid().shard_id(), ); let _span = tracing::debug_span!( target: "state-parts", diff --git a/core/store/src/trie/trie_storage.rs b/core/store/src/trie/trie_storage.rs index 61ac07fa9b0..51a7b048a4d 100644 --- a/core/store/src/trie/trie_storage.rs +++ b/core/store/src/trie/trie_storage.rs @@ -612,7 +612,7 @@ mod trie_cache_tests { use crate::{StoreConfig, TrieCache, TrieConfig}; use near_primitives::hash::hash; use near_primitives::shard_layout::ShardUId; - use near_primitives::types::ShardId; + use near_primitives::types::{new_shard_id_tmp, shard_id_as_u32, ShardId}; fn put_value(cache: &mut TrieCacheInner, value: &[u8]) { cache.put(hash(value), value.into()); @@ -622,7 +622,8 @@ mod trie_cache_tests { fn test_size_limit() { let value_size_sum = 5; let memory_overhead = 2 * TrieCacheInner::PER_ENTRY_OVERHEAD; - let mut cache = TrieCacheInner::new(100, value_size_sum + memory_overhead, 0, false); + let mut cache = + TrieCacheInner::new(100, value_size_sum + memory_overhead, new_shard_id_tmp(0), false); // Add three values. Before each put, condition on total size should not be triggered. put_value(&mut cache, &[1, 1]); assert_eq!(cache.current_total_size(), 2 + TrieCacheInner::PER_ENTRY_OVERHEAD); @@ -640,7 +641,7 @@ mod trie_cache_tests { #[test] fn test_deletions_queue() { - let mut cache = TrieCacheInner::new(2, 1000, 0, false); + let mut cache = TrieCacheInner::new(2, 1000, new_shard_id_tmp(0), false); // Add two values to the cache. put_value(&mut cache, &[1]); put_value(&mut cache, &[1, 1]); @@ -659,7 +660,7 @@ mod trie_cache_tests { fn test_cache_capacity() { let capacity = 2; let total_size_limit = TrieCacheInner::PER_ENTRY_OVERHEAD * capacity; - let mut cache = TrieCacheInner::new(100, total_size_limit, 0, false); + let mut cache = TrieCacheInner::new(100, total_size_limit, new_shard_id_tmp(0), false); put_value(&mut cache, &[1]); put_value(&mut cache, &[2]); put_value(&mut cache, &[3]); @@ -672,7 +673,7 @@ mod trie_cache_tests { #[test] fn test_small_memory_limit() { let total_size_limit = 1; - let mut cache = TrieCacheInner::new(100, total_size_limit, 0, false); + let mut cache = TrieCacheInner::new(100, total_size_limit, new_shard_id_tmp(0), false); put_value(&mut cache, &[1, 2, 3]); put_value(&mut cache, &[2, 3, 4]); put_value(&mut cache, &[3, 4, 5]); @@ -699,10 +700,10 @@ mod trie_cache_tests { store_config.view_trie_cache.per_shard_max_bytes.insert(s0, S0_VIEW_SIZE); let trie_config = TrieConfig::from_store_config(&store_config); - check_cache_size(&trie_config, 1, false, DEFAULT_SIZE); - check_cache_size(&trie_config, 0, false, S0_SIZE); - check_cache_size(&trie_config, 1, true, DEFAULT_VIEW_SIZE); - check_cache_size(&trie_config, 0, true, S0_VIEW_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(1), false, DEFAULT_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(0), false, S0_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(1), true, DEFAULT_VIEW_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(0), true, S0_VIEW_SIZE); } #[track_caller] @@ -712,7 +713,7 @@ mod trie_cache_tests { is_view: bool, expected_size: bytesize::ByteSize, ) { - let shard_uid = ShardUId { version: 0, shard_id: shard_id as u32 }; + let shard_uid = ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) }; let trie_cache = TrieCache::new(&trie_config, shard_uid, is_view); assert_eq!(expected_size.as_u64(), trie_cache.lock().total_size_limit); assert_eq!(is_view, trie_cache.lock().is_view); diff --git a/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs b/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs index b5b1fd18f65..5f1695cef89 100644 --- a/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs +++ b/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs @@ -5,7 +5,7 @@ use near_chain_configs::{ MIN_GAS_PRICE, NEAR_BASE, NUM_BLOCKS_PER_YEAR, NUM_BLOCK_PRODUCER_SEATS, PROTOCOL_REWARD_RATE, PROTOCOL_UPGRADE_STAKE_THRESHOLD, TRANSACTION_VALIDITY_PERIOD, }; -use near_primitives::types::{Balance, NumShards, ShardId}; +use near_primitives::types::{new_shard_id_tmp, Balance, NumShards, ShardId}; use near_primitives::utils::get_num_seats_per_shard; use near_primitives::version::PROTOCOL_VERSION; use nearcore::config::{Config, CONFIG_FILENAME, NODE_KEY_FILE}; @@ -14,7 +14,16 @@ use std::fs::File; use std::path::Path; const ACCOUNTS_FILE: &str = "accounts.csv"; -const SHARDS: &'static [ShardId] = &[0, 1, 2, 3, 4, 5, 6, 7]; +const SHARDS: &'static [ShardId] = &[ + new_shard_id_tmp(0), + new_shard_id_tmp(1), + new_shard_id_tmp(2), + new_shard_id_tmp(3), + new_shard_id_tmp(4), + new_shard_id_tmp(5), + new_shard_id_tmp(6), + new_shard_id_tmp(7), +]; fn verify_total_supply(total_supply: Balance, chain_id: &str) { if chain_id == near_primitives::chains::MAINNET { diff --git a/genesis-tools/genesis-csv-to-json/src/main.rs b/genesis-tools/genesis-csv-to-json/src/main.rs index 4553970cfe8..7bd9b40f229 100644 --- a/genesis-tools/genesis-csv-to-json/src/main.rs +++ b/genesis-tools/genesis-csv-to-json/src/main.rs @@ -35,7 +35,7 @@ fn main() { if s.is_empty() { HashSet::default() } else { - s.split(',').map(|v| v.parse::().unwrap()).collect() + s.split(',').map(|v| v.parse::().unwrap().into()).collect() } } None => HashSet::default(), diff --git a/genesis-tools/genesis-populate/src/lib.rs b/genesis-tools/genesis-populate/src/lib.rs index 905dc8fa469..ee223645c5b 100644 --- a/genesis-tools/genesis-populate/src/lib.rs +++ b/genesis-tools/genesis-populate/src/lib.rs @@ -18,7 +18,10 @@ use near_primitives::hash::{hash, CryptoHash}; use near_primitives::shard_layout::{account_id_to_shard_id, ShardUId}; use near_primitives::state_record::StateRecord; use near_primitives::types::chunk_extra::ChunkExtra; -use near_primitives::types::{AccountId, Balance, EpochId, ShardId, StateChangeCause, StateRoot}; +use near_primitives::types::{ + new_shard_id_tmp, shard_id_as_u32, AccountId, Balance, EpochId, ShardId, StateChangeCause, + StateRoot, +}; use near_primitives::utils::to_timestamp; use near_primitives::version::ProtocolFeature; use near_store::adapter::StoreUpdateAdapter; @@ -134,15 +137,19 @@ impl GenesisBuilder { let roots = get_genesis_state_roots(self.runtime.store())? .expect("genesis state roots not initialized."); let genesis_shard_version = self.genesis.config.shard_layout.version(); - self.roots = roots.into_iter().enumerate().map(|(k, v)| (k as u64, v)).collect(); + self.roots = + roots.into_iter().enumerate().map(|(k, v)| (new_shard_id_tmp(k as u64), v)).collect(); self.state_updates = self .roots .iter() - .map(|(shard_idx, root)| { + .map(|(&shard_id, root)| { ( - *shard_idx, + shard_id, self.runtime.get_tries().new_trie_update( - ShardUId { version: genesis_shard_version, shard_id: *shard_idx as u32 }, + ShardUId { + version: genesis_shard_version, + shard_id: shard_id_as_u32(shard_id), + }, *root, ), ) @@ -200,7 +207,8 @@ impl GenesisBuilder { state_update.commit(StateChangeCause::InitialState); let (_, trie_changes, state_changes) = state_update.finalize()?; let genesis_shard_version = self.genesis.config.shard_layout.version(); - let shard_uid = ShardUId { version: genesis_shard_version, shard_id: shard_idx as u32 }; + let shard_uid = + ShardUId { version: genesis_shard_version, shard_id: shard_id_as_u32(shard_idx) }; let mut store_update = tries.store_update(); let root = tries.apply_all(&trie_changes, shard_uid, &mut store_update); near_store::flat::FlatStateChanges::from_state_changes(&state_changes) @@ -300,7 +308,7 @@ impl GenesisBuilder { &self, protocol_version: ProtocolVersion, genesis: &Block, - shard_id: u64, + shard_id: ShardId, state_root: CryptoHash, ) -> Result> { if !ProtocolFeature::CongestionControl.enabled(protocol_version) { diff --git a/integration-tests/src/runtime_utils.rs b/integration-tests/src/runtime_utils.rs index 214dc62046a..94bc7c4957b 100644 --- a/integration-tests/src/runtime_utils.rs +++ b/integration-tests/src/runtime_utils.rs @@ -7,8 +7,8 @@ use near_chain_configs::Genesis; use near_parameters::RuntimeConfig; use near_primitives::shard_layout::ShardUId; use near_primitives::state_record::{state_record_to_account_id, StateRecord}; -use near_primitives::types::AccountId; use near_primitives::types::StateRoot; +use near_primitives::types::{new_shard_id_tmp, AccountId}; use near_primitives_core::types::NumShards; use near_store::genesis::GenesisStateApplier; use near_store::test_utils::TestTriesBuilder; @@ -51,7 +51,7 @@ pub fn get_runtime_and_trie_from_genesis(genesis: &Genesis) -> (Runtime, ShardTr let genesis_root = GenesisStateApplier::apply( &writers, tries.clone(), - ShardUId::from_shard_id_and_layout(0, shard_layout), + ShardUId::from_shard_id_and_layout(new_shard_id_tmp(0), shard_layout), &genesis .config .validators diff --git a/integration-tests/src/test_loop/builder.rs b/integration-tests/src/test_loop/builder.rs index 9c9903b1e75..197faf0c31f 100644 --- a/integration-tests/src/test_loop/builder.rs +++ b/integration-tests/src/test_loop/builder.rs @@ -29,7 +29,7 @@ use near_parameters::RuntimeConfigStore; use near_primitives::epoch_manager::EpochConfigStore; use near_primitives::network::PeerId; use near_primitives::test_utils::create_test_signer; -use near_primitives::types::AccountId; +use near_primitives::types::{new_shard_id_tmp, AccountId}; use near_store::adapter::StoreAdapter; use near_store::config::StateSnapshotType; use near_store::genesis::initialize_genesis_state; @@ -285,7 +285,7 @@ impl TestLoopBuilder { if is_validator && !self.track_all_shards { client_config.tracked_shards = Vec::new(); } else { - client_config.tracked_shards = vec![666]; + client_config.tracked_shards = vec![new_shard_id_tmp(666)]; } if let Some(config_modifier) = &self.config_modifier { diff --git a/integration-tests/src/test_loop/tests/in_memory_tries.rs b/integration-tests/src/test_loop/tests/in_memory_tries.rs index 3d63b5a90f0..6f37d2cf06b 100644 --- a/integration-tests/src/test_loop/tests/in_memory_tries.rs +++ b/integration-tests/src/test_loop/tests/in_memory_tries.rs @@ -3,7 +3,7 @@ use near_async::time::Duration; use near_chain_configs::test_genesis::TestGenesisBuilder; use near_client::test_utils::test_loop::ClientQueries; use near_o11y::testonly::init_test_logger; -use near_primitives::types::AccountId; +use near_primitives::types::{new_shard_id_tmp, AccountId}; use near_store::ShardUId; use crate::test_loop::builder::TestLoopBuilder; @@ -77,7 +77,15 @@ fn test_load_memtrie_after_empty_chunks() { current_tracked_shards .iter() .enumerate() - .find_map(|(idx, shards)| if shards.contains(&0) { Some(idx) } else { None }) + .find_map( + |(idx, shards)| { + if shards.contains(&new_shard_id_tmp(0)) { + Some(idx) + } else { + None + } + }, + ) .expect("Not found any client tracking shard 0") }; @@ -87,11 +95,15 @@ fn test_load_memtrie_after_empty_chunks() { clients[idx] .runtime_adapter .get_tries() - .unload_mem_trie(&ShardUId::from_shard_id_and_layout(0, &shard_layout)); + .unload_mem_trie(&ShardUId::from_shard_id_and_layout(new_shard_id_tmp(0), &shard_layout)); clients[idx] .runtime_adapter .get_tries() - .load_mem_trie(&ShardUId::from_shard_id_and_layout(0, &shard_layout), None, true) + .load_mem_trie( + &ShardUId::from_shard_id_and_layout(new_shard_id_tmp(0), &shard_layout), + None, + true, + ) .expect("Couldn't load memtrie"); // Give the test a chance to finish off remaining events in the event loop, which can diff --git a/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs b/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs index fc8531552c9..f2f4a3dc7d1 100644 --- a/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs +++ b/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs @@ -15,8 +15,8 @@ use near_network::client::BlockHeadersRequest; use near_o11y::testonly::init_test_logger; use near_primitives::sharding::ChunkHash; use near_primitives::types::{ - AccountId, BlockHeight, BlockId, BlockReference, EpochId, EpochReference, Finality, - SyncCheckpoint, + new_shard_id_tmp, AccountId, BlockHeight, BlockId, BlockReference, EpochId, EpochReference, + Finality, SyncCheckpoint, }; use near_primitives::version::PROTOCOL_VERSION; use near_primitives::views::{ @@ -223,10 +223,10 @@ impl<'a> ViewClientTester<'a> { chunk }; - let chunk_by_height = GetChunk::Height(5, 0); + let chunk_by_height = GetChunk::Height(5, new_shard_id_tmp(0)); get_and_check_chunk(chunk_by_height); - let chunk_by_block_hash = GetChunk::BlockHash(block.header.hash, 0); + let chunk_by_block_hash = GetChunk::BlockHash(block.header.hash, new_shard_id_tmp(0)); get_and_check_chunk(chunk_by_block_hash); let chunk_by_chunk_hash = GetChunk::ChunkHash(ChunkHash(block.chunks[0].chunk_hash)); @@ -242,10 +242,10 @@ impl<'a> ViewClientTester<'a> { assert_eq!(shard_chunk.take_header().gas_limit(), 1_000_000_000_000_000); }; - let chunk_by_height = GetShardChunk::Height(5, 0); + let chunk_by_height = GetShardChunk::Height(5, new_shard_id_tmp(0)); get_and_check_shard_chunk(chunk_by_height); - let chunk_by_block_hash = GetShardChunk::BlockHash(block.header.hash, 0); + let chunk_by_block_hash = GetShardChunk::BlockHash(block.header.hash, new_shard_id_tmp(0)); get_and_check_shard_chunk(chunk_by_block_hash); let chunk_by_chunk_hash = GetShardChunk::ChunkHash(ChunkHash(block.chunks[0].chunk_hash)); @@ -376,9 +376,9 @@ impl<'a> ViewClientTester<'a> { let request = GetExecutionOutcomesForBlock { block_hash: block.header.hash }; let outcomes = self.send(request, ARCHIVAL_CLIENT).unwrap(); assert_eq!(outcomes.len(), NUM_SHARDS); - assert_eq!(outcomes[&0].len(), 1); + assert_eq!(outcomes[&new_shard_id_tmp(0)].len(), 1); assert!(matches!( - outcomes[&0][0], + outcomes[&new_shard_id_tmp(0)][0], ExecutionOutcomeWithIdView { outcome: ExecutionOutcomeView { status: ExecutionStatusView::SuccessReceiptId(_), @@ -387,9 +387,9 @@ impl<'a> ViewClientTester<'a> { .. } )); - assert_eq!(outcomes[&1].len(), 1); + assert_eq!(outcomes[&new_shard_id_tmp(1)].len(), 1); assert!(matches!( - outcomes[&1][0], + outcomes[&new_shard_id_tmp(1)][0], ExecutionOutcomeWithIdView { outcome: ExecutionOutcomeView { status: ExecutionStatusView::SuccessReceiptId(_), @@ -398,8 +398,8 @@ impl<'a> ViewClientTester<'a> { .. } )); - assert_eq!(outcomes[&2].len(), 0); - assert_eq!(outcomes[&3].len(), 0); + assert_eq!(outcomes[&new_shard_id_tmp(2)].len(), 0); + assert_eq!(outcomes[&new_shard_id_tmp(3)].len(), 0); } /// Generates variations of the [`GetStateChanges`] request and issues them to the view client of the archival node. diff --git a/integration-tests/src/tests/client/block_corruption.rs b/integration-tests/src/tests/client/block_corruption.rs index d45f9bc579e..e5ffaf651b7 100644 --- a/integration-tests/src/tests/client/block_corruption.rs +++ b/integration-tests/src/tests/client/block_corruption.rs @@ -61,16 +61,17 @@ fn change_shard_id_to_invalid() { let mut block = env.clients[0].produce_block(2).unwrap().unwrap(); // 1. Corrupt chunks + let bad_shard_id = 100; let mut new_chunks = vec![]; for chunk in block.chunks().iter() { let mut new_chunk = chunk.clone(); match &mut new_chunk { - ShardChunkHeader::V1(new_chunk) => new_chunk.inner.shard_id = 100, - ShardChunkHeader::V2(new_chunk) => new_chunk.inner.shard_id = 100, + ShardChunkHeader::V1(new_chunk) => new_chunk.inner.shard_id = bad_shard_id, + ShardChunkHeader::V2(new_chunk) => new_chunk.inner.shard_id = bad_shard_id, ShardChunkHeader::V3(new_chunk) => match &mut new_chunk.inner { - ShardChunkHeaderInner::V1(inner) => inner.shard_id = 100, - ShardChunkHeaderInner::V2(inner) => inner.shard_id = 100, - ShardChunkHeaderInner::V3(inner) => inner.shard_id = 100, + ShardChunkHeaderInner::V1(inner) => inner.shard_id = bad_shard_id, + ShardChunkHeaderInner::V2(inner) => inner.shard_id = bad_shard_id, + ShardChunkHeaderInner::V3(inner) => inner.shard_id = bad_shard_id, }, }; new_chunks.push(new_chunk); @@ -88,7 +89,8 @@ fn change_shard_id_to_invalid() { // Try to process corrupt block and expect code to notice invalid shard_id let res = env.clients[0].process_block_test(block.into(), Provenance::NONE); match res { - Err(Error::InvalidShardId(100)) => { + Err(Error::InvalidShardId(shard_id)) => { + assert_eq!(shard_id, bad_shard_id); tracing::debug!("process failed successfully"); } Err(e) => { diff --git a/integration-tests/src/tests/client/challenges.rs b/integration-tests/src/tests/client/challenges.rs index 7338208fd99..9b1072384ee 100644 --- a/integration-tests/src/tests/client/challenges.rs +++ b/integration-tests/src/tests/client/challenges.rs @@ -22,7 +22,7 @@ use near_primitives::stateless_validation::chunk_endorsement::ChunkEndorsementV1 use near_primitives::test_utils::create_test_signer; use near_primitives::transaction::SignedTransaction; use near_primitives::types::chunk_extra::ChunkExtra; -use near_primitives::types::AccountId; +use near_primitives::types::{AccountId, ShardId}; use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION}; use near_store::Trie; use nearcore::test_utils::TestEnvNightshadeSetupExt; @@ -200,7 +200,7 @@ fn test_verify_chunk_invalid_proofs_challenge() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_eq!(challenge_result.unwrap(), (*block.hash(), vec!["test0".parse().unwrap()])); } @@ -215,7 +215,7 @@ fn test_verify_chunk_invalid_proofs_challenge_decoded_chunk() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Decoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Decoded(chunk).into(), &block); assert_eq!(challenge_result.unwrap(), (*block.hash(), vec!["test0".parse().unwrap()])); } @@ -228,7 +228,7 @@ fn test_verify_chunk_proofs_malicious_challenge_no_changes() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_matches!(challenge_result.unwrap_err(), Error::MaliciousChallenge); } @@ -265,7 +265,7 @@ fn test_verify_chunk_proofs_malicious_challenge_valid_order_transactions() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_matches!(challenge_result.unwrap_err(), Error::MaliciousChallenge); } @@ -302,13 +302,13 @@ fn test_verify_chunk_proofs_challenge_transaction_order() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_eq!(challenge_result.unwrap(), (*block.hash(), vec!["test0".parse().unwrap()])); } fn challenge( env: TestEnv, - shard_id: usize, + shard_id: ShardId, chunk: Box, block: &Block, ) -> Result<(CryptoHash, Vec), Error> { @@ -317,7 +317,7 @@ fn challenge( ChallengeBody::ChunkProofs(ChunkProofs { block_header: borsh::to_vec(&block.header()).unwrap(), chunk, - merkle_proof: merkle_paths[shard_id].clone(), + merkle_proof: merkle_paths[shard_id as usize].clone(), }), &*env.clients[0].validator_signer.get().unwrap(), ); diff --git a/nearcore/src/config.rs b/nearcore/src/config.rs index 57a2d8d1fac..30a2607adf1 100644 --- a/nearcore/src/config.rs +++ b/nearcore/src/config.rs @@ -40,8 +40,8 @@ use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::ShardLayout; use near_primitives::test_utils::create_test_signer; use near_primitives::types::{ - AccountId, AccountInfo, Balance, BlockHeight, BlockHeightDelta, Gas, NumSeats, NumShards, - ShardId, + new_shard_id_tmp, AccountId, AccountInfo, Balance, BlockHeight, BlockHeightDelta, Gas, + NumSeats, NumShards, ShardId, }; use near_primitives::utils::{from_timestamp, get_num_seats_per_shard}; use near_primitives::validator_signer::{InMemoryValidatorSigner, ValidatorSigner}; @@ -845,7 +845,7 @@ pub fn init_configs( let mut config = Config::default(); // Make sure node tracks all shards, see // https://github.com/near/nearcore/issues/7388 - config.tracked_shards = vec![0]; + config.tracked_shards = vec![new_shard_id_tmp(0)]; // If a config gets generated, block production times may need to be updated. set_block_production_delay(&chain_id, fast, &mut config); @@ -1073,7 +1073,7 @@ pub fn create_localnet_configs_from_seeds( num_non_validators_archival: NumSeats, num_non_validators_rpc: NumSeats, num_non_validators: NumSeats, - tracked_shards: Vec, + tracked_shards: Vec, ) -> (Vec, Vec, Vec, Genesis) { assert_eq!( seeds.len() as u64, @@ -1163,7 +1163,7 @@ pub fn create_localnet_configs_from_seeds( fn create_localnet_config( num_shards: NumShards, num_validators: NumSeats, - tracked_shards: &Vec, + tracked_shards: &Vec, network_signers: &Vec, boot_node_addr: &tcp::ListenerAddr, params: LocalnetNodeParams, @@ -1204,7 +1204,7 @@ fn create_localnet_config( // Make non-validator archival and RPC nodes track all shards. // Note that validator nodes may track all or some of the shards. config.tracked_shards = if !params.is_validator && (params.is_archival || params.is_rpc) { - (0..num_shards).collect() + (0..num_shards).map(new_shard_id_tmp).collect() } else { tracked_shards.clone() }; @@ -1232,7 +1232,7 @@ pub fn create_localnet_configs( num_non_validators_rpc: NumSeats, num_non_validators: NumSeats, prefix: &str, - tracked_shards: Vec, + tracked_shards: Vec, ) -> (Vec, Vec, Vec, Genesis, Vec) { let num_all_nodes = num_validators + num_non_validators_archival + num_non_validators_rpc + num_non_validators; @@ -1272,7 +1272,7 @@ pub fn init_localnet_configs( num_non_validators_rpc: NumSeats, num_non_validators: NumSeats, prefix: &str, - tracked_shards: Vec, + tracked_shards: Vec, ) { let (configs, validator_signers, network_signers, genesis, shard_keys) = create_localnet_configs( @@ -1522,7 +1522,7 @@ mod tests { use near_chain_configs::{GCConfig, Genesis, GenesisValidationMode}; use near_crypto::InMemorySigner; use near_primitives::shard_layout::account_id_to_shard_id; - use near_primitives::types::{AccountId, NumShards}; + use near_primitives::types::{new_shard_id_tmp, AccountId, NumShards}; use tempfile::tempdir; use crate::config::{ @@ -1562,21 +1562,21 @@ mod tests { &AccountId::from_str("foobar.near").unwrap(), &genesis.config.shard_layout, ), - 0 + new_shard_id_tmp(0) ); assert_eq!( account_id_to_shard_id( &AccountId::from_str("shard1.test.near").unwrap(), &genesis.config.shard_layout, ), - 1 + new_shard_id_tmp(1) ); assert_eq!( account_id_to_shard_id( &AccountId::from_str("shard2.test.near").unwrap(), &genesis.config.shard_layout, ), - 2 + new_shard_id_tmp(2) ); } @@ -1704,7 +1704,7 @@ mod tests { let prefix = "node"; // Validators will track single shard but archival and RPC nodes will track all shards. - let empty_tracked_shards: Vec = vec![]; + let empty_tracked_shards = vec![]; let (configs, _validator_signers, _network_signers, genesis, _shard_keys) = create_localnet_configs( @@ -1746,7 +1746,10 @@ mod tests { config.split_storage.clone().unwrap().enable_split_storage_view_client, true ); - assert_eq!(config.tracked_shards, (0..num_shards).collect::>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::>() + ); } // Check non-validator RPC nodes. @@ -1755,7 +1758,10 @@ mod tests { assert_eq!(config.archive, false); assert!(config.cold_store.is_none()); assert!(config.split_storage.is_none()); - assert_eq!(config.tracked_shards, (0..num_shards).collect::>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::>() + ); } // Check other non-validator nodes. @@ -1781,7 +1787,7 @@ mod tests { let prefix = "node"; // Validators will track 2 shards and non-validators will track all shards. - let tracked_shards: Vec = vec![1, 3]; + let tracked_shards = vec![new_shard_id_tmp(1), new_shard_id_tmp(3)]; let (configs, _validator_signers, _network_signers, genesis, _shard_keys) = create_localnet_configs( @@ -1823,7 +1829,10 @@ mod tests { config.split_storage.clone().unwrap().enable_split_storage_view_client, true ); - assert_eq!(config.tracked_shards, (0..num_shards).collect::>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::>() + ); } // Check non-validator RPC nodes. @@ -1832,7 +1841,10 @@ mod tests { assert_eq!(config.archive, false); assert!(config.cold_store.is_none()); assert!(config.split_storage.is_none()); - assert_eq!(config.tracked_shards, (0..num_shards).collect::>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::>() + ); } // Check other non-validator nodes. diff --git a/nearcore/src/config_validate.rs b/nearcore/src/config_validate.rs index 3db950d9607..a554563680a 100644 --- a/nearcore/src/config_validate.rs +++ b/nearcore/src/config_validate.rs @@ -171,6 +171,8 @@ impl<'a> ConfigValidator<'a> { #[cfg(test)] mod tests { + use near_primitives::types::new_shard_id_tmp; + use super::*; #[test] @@ -179,7 +181,7 @@ mod tests { let mut config = Config::default(); config.gc.gc_blocks_limit = 0; // set tracked_shards to be non-empty - config.tracked_shards.push(20); + config.tracked_shards.push(new_shard_id_tmp(20)); validate_config(&config).unwrap(); } @@ -192,7 +194,7 @@ mod tests { config.archive = false; config.save_trie_changes = Some(false); // set tracked_shards to be non-empty - config.tracked_shards.push(20); + config.tracked_shards.push(new_shard_id_tmp(20)); validate_config(&config).unwrap(); } @@ -206,7 +208,7 @@ mod tests { config.save_trie_changes = Some(false); config.gc.gc_blocks_limit = 0; // set tracked_shards to be non-empty - config.tracked_shards.push(20); + config.tracked_shards.push(new_shard_id_tmp(20)); validate_config(&config).unwrap(); } diff --git a/nearcore/src/entity_debug.rs b/nearcore/src/entity_debug.rs index 6f297abced2..ec633b4f3b7 100644 --- a/nearcore/src/entity_debug.rs +++ b/nearcore/src/entity_debug.rs @@ -261,9 +261,11 @@ impl EntityDebugHandlerImpl { let shard_layout = self .epoch_manager .get_shard_layout_from_prev_block(&chunk.cloned_header().prev_block_hash())?; + let shard_id = chunk.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = shard_layout .shard_uids() - .nth(chunk.shard_id() as usize) + .nth(shard_index) .ok_or_else(|| anyhow!("Shard {} not found", chunk.shard_id()))?; let node = store .get_ser::( @@ -299,9 +301,11 @@ impl EntityDebugHandlerImpl { } EntityQuery::ShardUIdByShardId { shard_id, epoch_id } => { let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let shard_uid = shard_layout .shard_uids() - .nth(shard_id as usize) + .nth(shard_index) .ok_or_else(|| anyhow!("Shard {} not found", shard_id))?; Ok(serialize_entity(&shard_uid)) } @@ -380,9 +384,11 @@ impl EntityDebugHandlerImpl { let shard_layout = self .epoch_manager .get_shard_layout_from_prev_block(&chunk.cloned_header().prev_block_hash())?; + let shard_id = chunk.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = shard_layout .shard_uids() - .nth(chunk.shard_id() as usize) + .nth(shard_index) .ok_or_else(|| anyhow!("Shard {} not found", chunk.shard_id()))?; let path = TriePath { path: vec![], shard_uid, state_root: chunk.prev_state_root() }; diff --git a/nearcore/src/metrics.rs b/nearcore/src/metrics.rs index f8cb0d85362..a8db65cc675 100644 --- a/nearcore/src/metrics.rs +++ b/nearcore/src/metrics.rs @@ -8,6 +8,7 @@ use near_o11y::metrics::{ try_create_int_gauge, try_create_int_gauge_vec, HistogramVec, IntCounterVec, IntGauge, IntGaugeVec, }; +use near_primitives::types::ShardId; use near_primitives::{shard_layout::ShardLayout, state_record::StateRecord, trie_key}; use near_store::adapter::StoreAdapter; use near_store::{ShardUId, Store, Trie, TrieDBStorage}; @@ -148,7 +149,7 @@ fn export_postponed_receipt_count(near_config: &NearConfig, store: &Store) -> an } fn get_postponed_receipt_count_for_shard( - shard_id: u64, + shard_id: ShardId, shard_layout: &ShardLayout, chain_store: &ChainStore, block: &Block, diff --git a/nearcore/src/state_sync.rs b/nearcore/src/state_sync.rs index 73c3ed5feff..75ee9e2b11e 100644 --- a/nearcore/src/state_sync.rs +++ b/nearcore/src/state_sync.rs @@ -261,17 +261,17 @@ fn get_current_state( epoch_height: new_epoch_height, sync_hash: new_sync_hash, } = latest_epoch_info.map_err(|err| { - tracing::error!(target: "state_sync_dump", shard_id, ?err, "Failed to get the latest epoch"); + tracing::error!(target: "state_sync_dump", ?shard_id, ?err, "Failed to get the latest epoch"); err })?; if Some(&new_epoch_id) == was_last_epoch_done.as_ref() { - tracing::debug!(target: "state_sync_dump", shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "latest epoch is done. No new epoch to dump. Idle"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "latest epoch is done. No new epoch to dump. Idle"); Ok(StateDumpAction::Wait) } else if epoch_manager.get_shard_layout(&prev_epoch_id) != epoch_manager.get_shard_layout(&new_epoch_id) { - tracing::debug!(target: "state_sync_dump", shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Shard layout change detected, will skip dumping for this epoch. Idle"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Shard layout change detected, will skip dumping for this epoch. Idle"); chain.chain_store().set_state_sync_dump_progress( *shard_id, Some(StateSyncDumpProgress::Skipped { @@ -287,7 +287,7 @@ fn get_current_state( sync_hash: new_sync_hash, }) } else { - tracing::debug!(target: "state_sync_dump", shard_id, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Doesn't care about the shard in the current epoch. Idle"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Doesn't care about the shard in the current epoch. Idle"); Ok(StateDumpAction::Wait) } } @@ -313,11 +313,11 @@ async fn upload_state_header( external_storage_location(&chain_id, &epoch_id, epoch_height, shard_id, &file_type); match external.put_file(file_type, &header, shard_id, &location).await { Err(err) => { - tracing::warn!(target: "state_sync_dump", shard_id, epoch_height, ?err, "Failed to put header into external storage. Will retry next iteration."); + tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, ?err, "Failed to put header into external storage. Will retry next iteration."); false } Ok(_) => { - tracing::trace!(target: "state_sync_dump", shard_id, epoch_height, "Header saved to external storage."); + tracing::trace!(target: "state_sync_dump", ?shard_id, epoch_height, "Header saved to external storage."); true } } @@ -341,17 +341,17 @@ async fn state_sync_dump( validator: MutableValidatorSigner, keep_running: Arc, ) { - tracing::info!(target: "state_sync_dump", shard_id, "Running StateSyncDump loop"); + tracing::info!(target: "state_sync_dump", ?shard_id, "Running StateSyncDump loop"); if restart_dump_for_shards.contains(&shard_id) { - tracing::debug!(target: "state_sync_dump", shard_id, "Dropped existing progress"); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Dropped existing progress"); chain.chain_store().set_state_sync_dump_progress(shard_id, None).unwrap(); } // Stop if the node is stopped. // Note that without this check the state dumping thread is unstoppable, i.e. non-interruptable. while keep_running.load(std::sync::atomic::Ordering::Relaxed) { - tracing::debug!(target: "state_sync_dump", shard_id, "Running StateSyncDump loop iteration"); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Running StateSyncDump loop iteration"); let account_id = validator.get().map(|v| v.validator_id().clone()); let current_state = get_current_state( &chain, @@ -370,7 +370,7 @@ async fn state_sync_dump( let in_progress_data = get_in_progress_data(shard_id, sync_hash, &chain); match in_progress_data { Err(err) => { - tracing::error!(target: "state_sync_dump", ?err, ? shard_id, "Failed to get in progress data"); + tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to get in progress data"); None } Ok((state_root, num_parts, sync_prev_prev_hash)) => { @@ -472,7 +472,7 @@ async fn state_sync_dump( let state_part = match state_part { Ok(state_part) => state_part, Err(err) => { - tracing::warn!(target: "state_sync_dump", shard_id, epoch_height, part_id, ?err, "Failed to obtain and store part. Will skip this part."); + tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, part_id, ?err, "Failed to obtain and store part. Will skip this part."); failures_cnt += 1; continue; } @@ -492,7 +492,7 @@ async fn state_sync_dump( { // no need to break if there's an error, we should keep dumping other parts. // reason is we are dumping random selected parts, so it's fine if we are not able to finish all of them - tracing::warn!(target: "state_sync_dump", shard_id, epoch_height, part_id, ?err, "Failed to put a store part into external storage. Will skip this part."); + tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, part_id, ?err, "Failed to put a store part into external storage. Will skip this part."); failures_cnt += 1; continue; } @@ -540,19 +540,19 @@ async fn state_sync_dump( // Record the next state of the state machine. let has_progress = match next_state { Some(next_state) => { - tracing::debug!(target: "state_sync_dump", shard_id, ?next_state); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?next_state); match chain.chain_store().set_state_sync_dump_progress(shard_id, Some(next_state)) { Ok(_) => true, Err(err) => { // This will be retried. - tracing::debug!(target: "state_sync_dump", shard_id, ?err, "Failed to set progress"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?err, "Failed to set progress"); false } } } None => { // Nothing to do, will check again later. - tracing::debug!(target: "state_sync_dump", shard_id, "Idle"); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Idle"); false } }; @@ -562,7 +562,7 @@ async fn state_sync_dump( clock.sleep(iteration_delay).await; } } - tracing::debug!(target: "state_sync_dump", shard_id, "Stopped state dump thread"); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Stopped state dump thread"); } // Extracts extra data needed for obtaining state parts. @@ -658,7 +658,7 @@ fn get_latest_epoch( epoch_manager: Arc, ) -> Result { let head = chain.head()?; - tracing::debug!(target: "state_sync_dump", shard_id, "Check if a new complete epoch is available"); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Check if a new complete epoch is available"); let hash = head.last_block_hash; let header = chain.get_block_header(&hash)?; let final_hash = header.last_final_block(); diff --git a/runtime/runtime/src/balance_checker.rs b/runtime/runtime/src/balance_checker.rs index 9bdbdc67f99..2baeabfadff 100644 --- a/runtime/runtime/src/balance_checker.rs +++ b/runtime/runtime/src/balance_checker.rs @@ -12,7 +12,7 @@ use near_primitives::hash::CryptoHash; use near_primitives::receipt::{Receipt, ReceiptEnum, ReceiptOrStateStoredReceipt}; use near_primitives::transaction::SignedTransaction; use near_primitives::trie_key::TrieKey; -use near_primitives::types::{AccountId, Balance}; +use near_primitives::types::{AccountId, Balance, ShardId}; use near_store::trie::receipts_column_helper::{ShardsOutgoingReceiptBuffer, TrieQueue}; use near_store::{ get, get_account, get_postponed_receipt, get_promise_yield_receipt, Trie, TrieAccess, @@ -141,7 +141,7 @@ fn buffered_receipts( let mut forwarded_receipts: Vec = vec![]; let mut new_buffered_receipts: Vec = vec![]; - let mut shards: BTreeSet = BTreeSet::new(); + let mut shards: BTreeSet = BTreeSet::new(); shards.extend(initial_buffers.shards().iter()); shards.extend(final_buffers.shards().iter()); for shard_id in shards { @@ -400,7 +400,7 @@ mod tests { }; use near_primitives::test_utils::account_new; use near_primitives::transaction::{Action, TransferAction}; - use near_primitives::types::{MerkleHash, StateChangeCause}; + use near_primitives::types::{new_shard_id_tmp, MerkleHash, StateChangeCause}; use near_store::test_utils::TestTriesBuilder; use near_store::{set, set_account, Trie}; use testlib::runtime_utils::{alice_account, bob_account}; @@ -706,14 +706,15 @@ mod tests { // create buffer with already a receipt in it, but a different balance let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 1 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 1 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 0 }, &existing_receipt, ); }, @@ -727,14 +728,15 @@ mod tests { // store receipt with the balance in the receipt buffer let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 2 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 1 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 1 }, &new_receipt, ); }, @@ -776,31 +778,36 @@ mod tests { |trie_update| { // store 2 receipts with balance in the receipt buffer let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 2 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 0 }, &receipt0, ); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 1 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 1 }, &receipt1, ); }, |trie_update| { // remove 1 receipt at index 0 let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 1, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 1, next_available_index: 2 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); - trie_update.remove(TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }); + trie_update.remove(TrieKey::BufferedReceipt { + receiving_shard: new_shard_id_tmp(0), + index: 0, + }); }, ); @@ -834,31 +841,36 @@ mod tests { |trie_update| { // store receipt0 with balance in the receipt buffer let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 1 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 1 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 0 }, &receipt0, ); }, |trie_update| { // pop receipt0 and push receipt1 with a different balance let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 1, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 1, next_available_index: 2 }, + ); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 1 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 1 }, &receipt1, ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); - trie_update.remove(TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }); + trie_update.remove(TrieKey::BufferedReceipt { + receiving_shard: new_shard_id_tmp(0), + index: 0, + }); }, ); diff --git a/runtime/runtime/src/congestion_control.rs b/runtime/runtime/src/congestion_control.rs index dbbd1228a20..6f0b356e72b 100644 --- a/runtime/runtime/src/congestion_control.rs +++ b/runtime/runtime/src/congestion_control.rs @@ -9,7 +9,7 @@ use near_primitives::receipt::{ Receipt, ReceiptEnum, ReceiptOrStateStoredReceipt, StateStoredReceipt, StateStoredReceiptMetadata, }; -use near_primitives::types::{EpochInfoProvider, Gas, ShardId}; +use near_primitives::types::{new_shard_id_tmp, EpochInfoProvider, Gas, ShardId}; use near_primitives::version::ProtocolFeature; use near_store::trie::receipts_column_helper::{ DelayedReceiptQueue, ReceiptIterator, ShardsOutgoingReceiptBuffer, TrieQueue, @@ -193,7 +193,7 @@ impl ReceiptSinkV2<'_> { fn forward_from_buffer_to_shard( &mut self, - shard_id: u64, + shard_id: ShardId, state_update: &mut TrieUpdate, apply_state: &ApplyState, ) -> Result<(), RuntimeError> { @@ -317,7 +317,7 @@ impl ReceiptSinkV2<'_> { size: u64, gas: u64, state_update: &mut TrieUpdate, - shard: u64, + shard: ShardId, use_state_stored_receipt: bool, ) -> Result<(), RuntimeError> { let receipt = match use_state_stored_receipt { @@ -460,7 +460,7 @@ pub fn bootstrap_congestion_info( // It is also irrelevant, since the bootstrapped value is only used at // the start of applying a chunk on this shard. Other shards will only // see and act on the first congestion info after that. - allowed_shard: shard_id as u16, + allowed_shard: new_shard_id_tmp(shard_id) as u16, })) } diff --git a/runtime/runtime/src/lib.rs b/runtime/runtime/src/lib.rs index 5b9dde14a90..1dba5601085 100644 --- a/runtime/runtime/src/lib.rs +++ b/runtime/runtime/src/lib.rs @@ -39,6 +39,7 @@ use near_primitives::transaction::{ SignedTransaction, TransferAction, }; use near_primitives::trie_key::TrieKey; +use near_primitives::types::new_shard_id_tmp; use near_primitives::types::{ validator_stake::ValidatorStake, AccountId, Balance, BlockHeight, Compute, EpochHeight, EpochId, EpochInfoProvider, Gas, RawStateChangesWithTrieKey, ShardId, StateChangeCause, @@ -1351,7 +1352,7 @@ impl Runtime { { // Note that receipts are restored only on mainnet so restored_receipts will be empty on // other chains. - migration_data.restored_receipts.get(&0u64).cloned().unwrap_or_default() + migration_data.restored_receipts.get(&new_shard_id_tmp(0)).cloned().unwrap_or_default() } else { vec![] }; @@ -2020,7 +2021,10 @@ impl Runtime { delayed_receipts.apply_congestion_changes(congestion_info)?; let all_shards = apply_state.congestion_info.all_shards(); - let congestion_seed = apply_state.block_height.wrapping_add(apply_state.shard_id); + // TODO(wacban) Using non-contiguous shard id here breaks some + // assumptions. The shard index should be used here instead. + let congestion_seed = + apply_state.block_height.wrapping_add(apply_state.shard_id.into()); congestion_info.finalize_allowed_shard( apply_state.shard_id, all_shards.as_slice(), diff --git a/runtime/runtime/src/metrics.rs b/runtime/runtime/src/metrics.rs index 182e8304bc1..7afd7a1f489 100644 --- a/runtime/runtime/src/metrics.rs +++ b/runtime/runtime/src/metrics.rs @@ -800,7 +800,7 @@ pub fn report_recorded_column_sizes(trie: &Trie, apply_state: &ApplyState) { // Tracing span to measure time spent on reporting column sizes. let _span = tracing::debug_span!( target: "runtime", "report_recorded_column_sizes", - shard_id = apply_state.shard_id, + shard_id = ?apply_state.shard_id, block_height = apply_state.block_height) .entered(); diff --git a/tools/state-viewer/src/epoch_info.rs b/tools/state-viewer/src/epoch_info.rs index ac75f749070..cd29783e481 100644 --- a/tools/state-viewer/src/epoch_info.rs +++ b/tools/state-viewer/src/epoch_info.rs @@ -90,13 +90,16 @@ fn display_block_and_chunk_producers( let block_height_range: Range = get_block_height_range(epoch_id, chain_store, epoch_manager)?; let shard_ids = epoch_manager.shard_ids(epoch_id).unwrap(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).unwrap(); for block_height in block_height_range { let bp = epoch_info.sample_block_producer(block_height); let bp = epoch_info.get_validator(bp).account_id().clone(); let cps: Vec = shard_ids .iter() .map(|&shard_id| { - let cp = epoch_info.sample_chunk_producer(block_height, shard_id).unwrap(); + let cp = epoch_info + .sample_chunk_producer(&shard_layout, shard_id, block_height) + .unwrap(); let cp = epoch_info.get_validator(cp).account_id().clone(); cp.as_str().to_string() }) @@ -274,13 +277,14 @@ fn display_validator_info( println!("Block producer for {} blocks: {bp_for_blocks:?}", bp_for_blocks.len()); let shard_ids = epoch_manager.shard_ids(epoch_id).unwrap(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).unwrap(); let cp_for_chunks: Vec<(BlockHeight, ShardId)> = block_height_range .flat_map(|block_height| { shard_ids .iter() .map(|&shard_id| (block_height, shard_id)) .filter(|&(block_height, shard_id)| { - epoch_info.sample_chunk_producer(block_height, shard_id) + epoch_info.sample_chunk_producer(&shard_layout, shard_id, block_height) == Some(*validator_id) }) .collect::>() diff --git a/tools/state-viewer/src/replay_headers.rs b/tools/state-viewer/src/replay_headers.rs index 58c713b9dfa..c830b54aa4b 100644 --- a/tools/state-viewer/src/replay_headers.rs +++ b/tools/state-viewer/src/replay_headers.rs @@ -228,6 +228,8 @@ fn get_block_info( { let block = chain_store.get_block(header.hash())?; let chunks = block.chunks(); + let epoch_id = block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id)?; let endorsement_signatures = block.chunk_endorsements().to_vec(); assert_eq!(endorsement_signatures.len(), chunks.len()); @@ -237,12 +239,12 @@ fn get_block_info( let height = header.height(); let prev_block_epoch_id = epoch_manager.get_epoch_id_from_prev_block(header.prev_hash())?; - for chunk_header in chunks.iter() { - let shard_id = chunk_header.shard_id(); + for (shard_index, chunk_header) in chunks.iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let endorsements = &endorsement_signatures[shard_id as usize]; if !chunk_header.is_new_chunk(height) { assert_eq!(endorsements.len(), 0); - bitmap.add_endorsements(shard_id, vec![]); + bitmap.add_endorsements(shard_index, vec![]); } else { let assignments = epoch_manager .get_chunk_validator_assignments( @@ -253,7 +255,7 @@ fn get_block_info( .ordered_chunk_validators(); assert_eq!(endorsements.len(), assignments.len()); bitmap.add_endorsements( - shard_id, + shard_index, endorsements.iter().map(|signature| signature.is_some()).collect_vec(), ); } From b484f1fa676267fa9e7c45e2f228d61dc9605532 Mon Sep 17 00:00:00 2001 From: Longarithm Date: Fri, 11 Oct 2024 01:56:35 +0400 Subject: [PATCH 3/3] mini fix --- chain/chain/src/stateless_validation/chunk_validation.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chain/chain/src/stateless_validation/chunk_validation.rs b/chain/chain/src/stateless_validation/chunk_validation.rs index e87c3cf873d..a387a613baa 100644 --- a/chain/chain/src/stateless_validation/chunk_validation.rs +++ b/chain/chain/src/stateless_validation/chunk_validation.rs @@ -150,7 +150,7 @@ pub fn pre_validate_chunk_state_witness( epoch_manager.get_prev_shard_id(&block_hash, current_shard_id)?; let chunks = block.chunks(); - let Some(chunk) = chunks.get(current_shard_index as usize) else { + let Some(chunk) = chunks.get(current_shard_index) else { return Err(Error::InvalidChunkStateWitness(format!( "Shard {} does not exist in block {:?}", current_shard_id, block_hash