Skip to content

Commit

Permalink
feat: State sync from local filesystem (#8913)
Browse files Browse the repository at this point in the history
Also adds a unit test for the state dump, and an integration test covering a state dump followed by state sync from that dump.
  • Loading branch information
nikurt committed Apr 25, 2023
1 parent 8ef2542 commit 6f273b1
Showing 1 changed file with 62 additions and 111 deletions.
173 changes: 62 additions & 111 deletions tools/state-viewer/src/state_parts.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::epoch_info::iterate_and_filter;
use borsh::BorshDeserialize;
use near_chain::{Chain, ChainGenesis, ChainStoreAccess, DoomslugThresholdMode};
use near_client::sync::state::StateSync;
use near_primitives::challenge::PartialState;
use near_client::sync::state::{
get_num_parts_from_filename, is_part_filename, location_prefix, part_filename, StateSync,
};
use near_primitives::epoch_manager::epoch_info::EpochInfo;
use near_primitives::state_part::PartId;
use near_primitives::state_record::StateRecord;
Expand All @@ -19,26 +20,13 @@ use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::time::Instant;

/// What to do with each state part in the `apply` subcommand.
#[derive(clap::ArgEnum, Debug, Clone)]
pub(crate) enum ApplyAction {
    /// Write the part into the DB and apply it via the runtime adapter.
    Apply,
    /// Validate the part against the state root without writing it to the DB.
    Validate,
    /// Print the partial trie contained in the part.
    Print,
}

impl Default for ApplyAction {
fn default() -> Self {
ApplyAction::Apply
}
}

#[derive(clap::Subcommand, Debug, Clone)]
pub(crate) enum StatePartsSubCommand {
/// Apply all or a single state part of a shard.
Apply {
/// Apply, validate or print.
#[clap(arg_enum, long)]
action: ApplyAction,
/// If true, validate the state part but don't write it to the DB.
#[clap(long)]
dry_run: bool,
/// If provided, this value will be used instead of looking it up in the headers.
/// Use if those headers or blocks are not available.
#[clap(long)]
Expand Down Expand Up @@ -93,12 +81,12 @@ impl StatePartsSubCommand {
.unwrap();
let chain_id = &near_config.genesis.config.chain_id;
match self {
StatePartsSubCommand::Apply { action, state_root, part_id, epoch_selection } => {
StatePartsSubCommand::Apply { dry_run, state_root, part_id, epoch_selection } => {
apply_state_parts(
action,
epoch_selection,
shard_id,
part_id,
dry_run,
state_root,
&mut chain,
chain_id,
Expand Down Expand Up @@ -146,7 +134,7 @@ impl EpochSelection {
chain.runtime_adapter.get_epoch_id(&chain.head().unwrap().last_block_hash).unwrap()
}
EpochSelection::EpochId { epoch_id } => {
EpochId(CryptoHash::from_str(&epoch_id).unwrap())
EpochId(CryptoHash::from_str(epoch_id).unwrap())
}
EpochSelection::EpochHeight { epoch_height } => {
// Fetch epochs at the given height.
Expand All @@ -159,7 +147,7 @@ impl EpochSelection {
epoch_ids[0].clone()
}
EpochSelection::BlockHash { block_hash } => {
let block_hash = CryptoHash::from_str(&block_hash).unwrap();
let block_hash = CryptoHash::from_str(block_hash).unwrap();
chain.runtime_adapter.get_epoch_id(&block_hash).unwrap()
}
EpochSelection::BlockHeight { block_height } => {
Expand Down Expand Up @@ -233,10 +221,10 @@ fn get_any_block_hash_of_epoch(epoch_info: &EpochInfo, chain: &Chain) -> CryptoH
}

fn apply_state_parts(
action: ApplyAction,
epoch_selection: EpochSelection,
shard_id: ShardId,
part_id: Option<u64>,
dry_run: bool,
maybe_state_root: Option<StateRoot>,
chain: &mut Chain,
chain_id: &str,
Expand All @@ -249,19 +237,19 @@ fn apply_state_parts(
{
(state_root, *epoch_height, None, None)
} else {
let epoch_id = epoch_selection.to_epoch_id(store, &chain);
let epoch_id = epoch_selection.to_epoch_id(store, chain);
let epoch = chain.runtime_adapter.get_epoch_info(&epoch_id).unwrap();

let sync_hash = get_any_block_hash_of_epoch(&epoch, &chain);
let sync_hash = StateSync::get_epoch_start_sync_hash(&chain, &sync_hash).unwrap();
let sync_hash = get_any_block_hash_of_epoch(&epoch, chain);
let sync_hash = StateSync::get_epoch_start_sync_hash(chain, &sync_hash).unwrap();

let state_header = chain.get_state_response_header(shard_id, sync_hash).unwrap();
let state_root = state_header.chunk_prev_state_root();

(state_root, epoch.epoch_height(), Some(epoch_id), Some(sync_hash))
};

let part_storage = get_state_part_reader(location, &chain_id, epoch_height, shard_id);
let part_storage = get_state_part_reader(location, chain_id, epoch_height, shard_id);

let num_parts = part_storage.num_parts();
assert_ne!(num_parts, 0, "Too few num_parts: {}", num_parts);
Expand All @@ -282,50 +270,38 @@ fn apply_state_parts(
assert!(part_id < num_parts, "part_id: {}, num_parts: {}", part_id, num_parts);
let part = part_storage.read(part_id, num_parts);

match action {
ApplyAction::Apply => {
chain
.set_state_part(
shard_id,
sync_hash.unwrap(),
PartId::new(part_id, num_parts),
&part,
)
.unwrap();
chain
.runtime_adapter
.apply_state_part(
shard_id,
&state_root,
PartId::new(part_id, num_parts),
&part,
epoch_id.as_ref().unwrap(),
)
.unwrap();
tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Applied a state part");
}
ApplyAction::Validate => {
assert!(chain.runtime_adapter.validate_state_part(
if dry_run {
assert!(chain.runtime_adapter.validate_state_part(
&state_root,
PartId::new(part_id, num_parts),
&part
));
tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Validated a state part");
} else {
chain
.set_state_part(
shard_id,
sync_hash.unwrap(),
PartId::new(part_id, num_parts),
&part,
)
.unwrap();
chain
.runtime_adapter
.apply_state_part(
shard_id,
&state_root,
PartId::new(part_id, num_parts),
&part
));
tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Validated a state part");
}
ApplyAction::Print => {
print_state_part(&state_root, PartId::new(part_id, num_parts), &part)
}
&part,
epoch_id.as_ref().unwrap(),
)
.unwrap();
tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Applied a state part");
}
}
tracing::info!(target: "state-parts", total_elapsed_sec = timer.elapsed().as_secs_f64(), "Applied all requested state parts");
}

/// Deserializes a single state part and prints its partial trie to stdout.
///
/// `data` is the borsh-serialized `PartialState` of the part; `state_root`
/// is the root the recovered partial trie hangs off. Panics if `data` does
/// not deserialize — acceptable for a debugging tool.
fn print_state_part(state_root: &StateRoot, _part_id: PartId, data: &[u8]) {
    let trie_nodes: PartialState = BorshDeserialize::try_from_slice(data).unwrap();
    let trie = Trie::from_recorded_storage(PartialStorage { nodes: trie_nodes }, *state_root);
    // `state_root` is already a reference; pass it directly instead of the
    // original needless `&state_root` double borrow (clippy: needless_borrow).
    trie.print_recursive(&mut std::io::stdout().lock(), state_root, u32::MAX);
}

fn dump_state_parts(
epoch_selection: EpochSelection,
shard_id: ShardId,
Expand All @@ -336,10 +312,12 @@ fn dump_state_parts(
store: Store,
location: Location,
) {
let epoch_id = epoch_selection.to_epoch_id(store, &chain);
let epoch_id = epoch_selection.to_epoch_id(store, chain);
let epoch = chain.runtime_adapter.get_epoch_info(&epoch_id).unwrap();
let sync_hash = get_any_block_hash_of_epoch(&epoch, &chain);
let sync_hash = StateSync::get_epoch_start_sync_hash(&chain, &sync_hash).unwrap();
let sync_hash = get_any_block_hash_of_epoch(&epoch, chain);
let sync_hash = StateSync::get_epoch_start_sync_hash(chain, &sync_hash).unwrap();
let sync_block = chain.get_block_header(&sync_hash).unwrap();
let sync_prev_hash = sync_block.prev_hash();

let state_header = chain.compute_state_response_header(shard_id, sync_hash).unwrap();
let state_root = state_header.chunk_prev_state_root();
Expand All @@ -366,7 +344,12 @@ fn dump_state_parts(
assert!(part_id < num_parts, "part_id: {}, num_parts: {}", part_id, num_parts);
let state_part = chain
.runtime_adapter
.obtain_state_part(shard_id, &sync_hash, &state_root, PartId::new(part_id, num_parts))
.obtain_state_part(
shard_id,
&sync_prev_hash,
&state_root,
PartId::new(part_id, num_parts),
)
.unwrap();
part_storage.write(&state_part, part_id, num_parts);
let elapsed_sec = timer.elapsed().as_secs_f64();
Expand All @@ -376,7 +359,7 @@ fn dump_state_parts(
part_id,
part_length = state_part.len(),
elapsed_sec,
first_state_record = ?first_state_record.map(|sr| format!("{}", sr)),
?first_state_record,
"Wrote a state part");
}
tracing::info!(target: "state-parts", total_elapsed_sec = timer.elapsed().as_secs_f64(), "Wrote all requested state parts");
Expand All @@ -387,11 +370,9 @@ fn get_first_state_record(state_root: &StateRoot, data: &[u8]) -> Option<StateRe
let trie_nodes = BorshDeserialize::try_from_slice(data).unwrap();
let trie = Trie::from_recorded_storage(PartialStorage { nodes: trie_nodes }, *state_root);

for item in trie.iter().unwrap() {
if let Ok((key, value)) = item {
if let Some(sr) = StateRecord::from_raw_key_value(key, value) {
return Some(sr);
}
for (key, value) in trie.iter().unwrap().flatten() {
if let Some(sr) = StateRecord::from_raw_key_value(key, value) {
return Some(sr);
}
}
None
Expand All @@ -404,11 +385,11 @@ fn read_state_header(
chain: &Chain,
store: Store,
) {
let epoch_id = epoch_selection.to_epoch_id(store, &chain);
let epoch_id = epoch_selection.to_epoch_id(store, chain);
let epoch = chain.runtime_adapter.get_epoch_info(&epoch_id).unwrap();

let sync_hash = get_any_block_hash_of_epoch(&epoch, &chain);
let sync_hash = StateSync::get_epoch_start_sync_hash(&chain, &sync_hash).unwrap();
let sync_hash = get_any_block_hash_of_epoch(&epoch, chain);
let sync_hash = StateSync::get_epoch_start_sync_hash(chain, &sync_hash).unwrap();

let state_header = chain.store().get_state_header(shard_id, sync_hash);
tracing::info!(target: "state-parts", ?epoch_id, ?sync_hash, ?state_header);
Expand All @@ -418,35 +399,6 @@ fn get_part_ids(part_from: Option<u64>, part_to: Option<u64>, num_parts: u64) ->
part_from.unwrap_or(0)..part_to.unwrap_or(num_parts)
}

/// Common location prefix for all state parts of one (chain, epoch height,
/// shard) triple. Needs to be in sync with `fn s3_location()`.
fn location_prefix(chain_id: &str, epoch_height: u64, shard_id: u64) -> String {
    format!(
        "chain_id={chain}/epoch_height={epoch}/shard_id={shard}",
        chain = chain_id,
        epoch = epoch_height,
        shard = shard_id
    )
}

/// Runs the state-part filename pattern against `s`, yielding the capture
/// groups (part id, total number of parts) on a full match.
// NOTE: the regex is recompiled on every call; fine for a tooling code path.
fn match_filename(s: &str) -> Option<regex::Captures> {
    regex::Regex::new(r"^state_part_(\d{6})_of_(\d{6})$")
        .expect("hard-coded state-part filename regex must be valid")
        .captures(s)
}

/// True iff `s` is a state-part file name (`state_part_XXXXXX_of_XXXXXX`).
fn is_part_filename(s: &str) -> bool {
    matches!(match_filename(s), Some(_))
}

/// Extracts the total number of parts from a state-part file name.
///
/// Returns `None` when the name does not match the expected pattern or when
/// the second capture does not parse as `u64` (which cannot happen for a
/// `\d{6}` capture, but is handled defensively as in the original).
fn get_num_parts_from_filename(s: &str) -> Option<u64> {
    // Capture group 2 of `match_filename` holds the `of_NNNNNN` count;
    // the `?` chain replaces three nested `if let`s with identical semantics.
    match_filename(s)?.get(2)?.as_str().parse::<u64>().ok()
}

/// File name for part `part_id` out of `num_parts`, e.g.
/// `state_part_000007_of_000100`. Both numbers are zero-padded to six digits
/// so that names sort lexicographically in part order.
fn part_filename(part_id: u64, num_parts: u64) -> String {
    format!("state_part_{id:06}_of_{total:06}", id = part_id, total = num_parts)
}

trait StatePartWriter {
fn write(&self, state_part: &[u8], part_id: u64, num_parts: u64);
}
Expand Down Expand Up @@ -511,7 +463,7 @@ impl FileSystemStorage {
}

fn get_location(&self, part_id: u64, num_parts: u64) -> PathBuf {
(&self.state_parts_dir).join(part_filename(part_id, num_parts))
self.state_parts_dir.join(part_filename(part_id, num_parts))
}
}

Expand All @@ -527,8 +479,7 @@ impl StatePartReader for FileSystemStorage {
fn read(&self, part_id: u64, num_parts: u64) -> Vec<u8> {
let filename = self.get_location(part_id, num_parts);
tracing::debug!(target: "state-parts", part_id, num_parts, ?filename, "Reading state part file");
let part = std::fs::read(filename).unwrap();
part
std::fs::read(filename).unwrap()
}

fn num_parts(&self) -> u64 {
Expand Down Expand Up @@ -576,7 +527,7 @@ impl S3Storage {
) -> Self {
let location = location_prefix(chain_id, epoch_height, shard_id);
let bucket = s3::Bucket::new(
&s3_bucket,
s3_bucket,
s3_region.parse::<s3::Region>().unwrap(),
s3::creds::Credentials::default().unwrap(),
)
Expand All @@ -594,7 +545,7 @@ impl S3Storage {
impl StatePartWriter for S3Storage {
fn write(&self, state_part: &[u8], part_id: u64, num_parts: u64) {
let location = self.get_location(part_id, num_parts);
self.bucket.put_object_blocking(&location, &state_part).unwrap();
self.bucket.put_object_blocking(&location, state_part).unwrap();
tracing::info!(target: "state-parts", part_id, part_length = state_part.len(), ?location, "Wrote a state part to S3");
}
}
Expand Down

0 comments on commit 6f273b1

Please sign in to comment.