Skip to content

Commit

Permalink
feat: check partial witness metadata (#11441)
Browse files Browse the repository at this point in the history
This PR implements a check of the partial state witness metadata fields
(`epoch_id`, `shard_id`, `height_created`) after decoding the complete state
witness. This is needed to protect against a malicious chunk producer
providing incorrect metadata in the partial witness.

A large part of this PR is about moving state witness decoding
(decompression + borsh-deserialization) from the client to the partial witness
actor. This is required to implement the check, but it is also a nice change
on its own, since witness decompression can take several dozen
milliseconds, so it is better to avoid blocking the client actor.

Closes #11303.
  • Loading branch information
pugachAG authored Jun 3, 2024
1 parent f902679 commit 96af8f7
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 66 deletions.
7 changes: 5 additions & 2 deletions chain/chain/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ use near_primitives::state_sync::{
get_num_state_parts, BitArray, CachedParts, ReceiptProofResponse, RootProof,
ShardStateSyncResponseHeader, ShardStateSyncResponseHeaderV2, StateHeaderKey, StatePartKey,
};
use near_primitives::stateless_validation::EncodedChunkStateWitness;
use near_primitives::stateless_validation::{ChunkStateWitness, ChunkStateWitnessSize};
use near_primitives::transaction::{ExecutionOutcomeWithIdAndProof, SignedTransaction};
use near_primitives::types::chunk_extra::ChunkExtra;
use near_primitives::types::{
Expand Down Expand Up @@ -4696,7 +4696,10 @@ pub struct BlockCatchUpResponse {

#[derive(actix::Message, Debug, Clone, PartialEq, Eq)]
#[rtype(result = "()")]
pub struct ChunkStateWitnessMessage(pub EncodedChunkStateWitness);
pub struct ChunkStateWitnessMessage {
pub witness: ChunkStateWitness,
pub raw_witness_size: ChunkStateWitnessSize,
}

/// Helper to track blocks catch up
/// Lifetime of a block_hash is as follows:
Expand Down
3 changes: 2 additions & 1 deletion chain/client/src/client_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2117,7 +2117,8 @@ impl Handler<SyncMessage> for ClientActorInner {
impl Handler<ChunkStateWitnessMessage> for ClientActorInner {
#[perf]
fn handle(&mut self, msg: ChunkStateWitnessMessage) {
if let Err(err) = self.client.process_chunk_state_witness(msg.0, None) {
let ChunkStateWitnessMessage { witness, raw_witness_size } = msg;
if let Err(err) = self.client.process_chunk_state_witness(witness, raw_witness_size, None) {
tracing::error!(target: "client", ?err, "Error processing chunk state witness");
}
}
Expand Down
23 changes: 2 additions & 21 deletions chain/client/src/stateless_validation/chunk_validator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use near_primitives::receipt::Receipt;
use near_primitives::sharding::{ChunkHash, ReceiptProof, ShardChunkHeader};
use near_primitives::stateless_validation::{
ChunkEndorsement, ChunkStateWitness, ChunkStateWitnessAck, ChunkStateWitnessSize,
EncodedChunkStateWitness,
};
use near_primitives::transaction::SignedTransaction;
use near_primitives::types::chunk_extra::ChunkExtra;
Expand Down Expand Up @@ -753,11 +752,10 @@ impl Client {
/// you can use the `processing_done_tracker` argument (but it's optional, it's safe to pass None there).
pub fn process_chunk_state_witness(
&mut self,
encoded_witness: EncodedChunkStateWitness,
witness: ChunkStateWitness,
raw_witness_size: ChunkStateWitnessSize,
processing_done_tracker: Option<ProcessingDoneTracker>,
) -> Result<(), Error> {
let (witness, raw_witness_size) = self.decode_state_witness(&encoded_witness)?;

tracing::debug!(
target: "client",
chunk_hash=?witness.chunk_header.chunk_hash(),
Expand Down Expand Up @@ -814,21 +812,4 @@ impl Client {

self.chunk_validator.start_validating_chunk(witness, &self.chain, processing_done_tracker)
}

fn decode_state_witness(
&self,
encoded_witness: &EncodedChunkStateWitness,
) -> Result<(ChunkStateWitness, ChunkStateWitnessSize), Error> {
let decode_start = std::time::Instant::now();
let (witness, raw_witness_size) = encoded_witness.decode()?;
let decode_elapsed_seconds = decode_start.elapsed().as_secs_f64();
let witness_shard = witness.chunk_header.shard_id();

// Record metrics after validating the witness
metrics::CHUNK_STATE_WITNESS_DECODE_TIME
.with_label_values(&[&witness_shard.to_string()])
.observe(decode_elapsed_seconds);

Ok((witness, raw_witness_size))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use near_chain::Error;
use near_epoch_manager::EpochManagerAdapter;
use near_primitives::reed_solomon::reed_solomon_decode;
use near_primitives::stateless_validation::{
ChunkProductionKey, EncodedChunkStateWitness, PartialEncodedStateWitness,
ChunkProductionKey, ChunkStateWitness, ChunkStateWitnessSize, EncodedChunkStateWitness,
PartialEncodedStateWitness,
};
use near_primitives::types::ShardId;
use reed_solomon_erasure::galois_8::ReedSolomon;
Expand Down Expand Up @@ -200,7 +201,16 @@ impl PartialEncodedStateWitnessTracker {
.with_label_values(&[entry.shard_id.to_string().as_str()])
.observe(entry.duration_to_last_part.as_seconds_f64());

self.client_sender.send(ChunkStateWitnessMessage(encoded_witness));
let (witness, raw_witness_size) = self.decode_state_witness(&encoded_witness)?;
if witness.chunk_production_key() != key {
return Err(Error::InvalidPartialChunkStateWitness(format!(
"Decoded witness key {:?} doesn't match partial witness {:?}",
witness.chunk_production_key(),
key,
)));
}

self.client_sender.send(ChunkStateWitnessMessage { witness, raw_witness_size });
}
self.record_total_parts_cache_size_metric();
Ok(())
Expand Down Expand Up @@ -265,4 +275,16 @@ impl PartialEncodedStateWitnessTracker {
self.parts_cache.iter().map(|(_, entry)| entry.total_parts_size).sum();
metrics::PARTIAL_WITNESS_CACHE_SIZE.set(total_size as f64);
}

/// Decodes `encoded_witness` and records how long the decoding took.
///
/// Returns the decoded witness together with the raw witness size reported by
/// `EncodedChunkStateWitness::decode`. Any error from `decode` is propagated
/// to the caller via `?`, in which case no metric is recorded.
fn decode_state_witness(
&self,
encoded_witness: &EncodedChunkStateWitness,
) -> Result<(ChunkStateWitness, ChunkStateWitnessSize), Error> {
// Start the timer right before decoding so the metric covers the decode call.
let decode_start = std::time::Instant::now();
let (witness, raw_witness_size) = encoded_witness.decode()?;
// The metric is labelled with the shard id taken from the decoded witness's
// chunk header, so it can only be observed after a successful decode.
metrics::CHUNK_STATE_WITNESS_DECODE_TIME
.with_label_values(&[&witness.chunk_header.shard_id().to_string()])
.observe(decode_start.elapsed().as_secs_f64());
Ok((witness, raw_witness_size))
}
}
16 changes: 7 additions & 9 deletions chain/client/src/test_utils/test_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use near_primitives::epoch_manager::RngSeed;
use near_primitives::errors::InvalidTxError;
use near_primitives::hash::CryptoHash;
use near_primitives::sharding::{ChunkHash, PartialEncodedChunk};
use near_primitives::stateless_validation::{ChunkEndorsement, EncodedChunkStateWitness};
use near_primitives::stateless_validation::{ChunkEndorsement, ChunkStateWitness};
use near_primitives::test_utils::create_test_signer;
use near_primitives::transaction::{Action, FunctionCallAction, SignedTransaction};
use near_primitives::types::{AccountId, Balance, BlockHeight, EpochId, NumSeats, ShardId};
Expand Down Expand Up @@ -328,9 +328,8 @@ impl TestEnv {
}

fn found_differing_post_state_root_due_to_state_transitions(
encoded_witness: &EncodedChunkStateWitness,
witness: &ChunkStateWitness,
) -> bool {
let witness = encoded_witness.decode().unwrap().0;
let mut post_state_roots = HashSet::from([witness.main_state_transition.post_state_root]);
post_state_roots.extend(witness.implicit_transitions.iter().map(|t| t.post_state_root));
post_state_roots.len() >= 2
Expand Down Expand Up @@ -359,8 +358,8 @@ impl TestEnv {
while let Some(request) = partial_witness_adapter.pop_distribution_request() {
let DistributeStateWitnessRequest { epoch_id, chunk_header, state_witness } =
request;
let (encoded_witness, _) =
EncodedChunkStateWitness::encode(&state_witness).unwrap();

let raw_witness_size = borsh::to_vec(&state_witness).unwrap().len();
let chunk_validators = self.clients[client_idx]
.epoch_manager
.get_chunk_validator_assignments(
Expand All @@ -376,7 +375,8 @@ impl TestEnv {
witness_processing_done_waiters.push(processing_done_tracker.make_waiter());

let processing_result = self.client(&account_id).process_chunk_state_witness(
encoded_witness.clone(),
state_witness.clone(),
raw_witness_size,
Some(processing_done_tracker),
);
if !allow_errors {
Expand All @@ -386,9 +386,7 @@ impl TestEnv {

// Update output.
output.found_differing_post_state_root_due_to_state_transitions |=
Self::found_differing_post_state_root_due_to_state_transitions(
&encoded_witness,
);
Self::found_differing_post_state_root_due_to_state_transitions(&state_witness);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use near_primitives::sharding::ShardChunkHeaderV3;
use near_primitives::sharding::{
ChunkHash, ReceiptProof, ShardChunkHeader, ShardChunkHeaderInner, ShardProof,
};
use near_primitives::stateless_validation::EncodedChunkStateWitness;
use near_primitives::stateless_validation::ChunkStateWitness;
use near_primitives::stateless_validation::ChunkStateWitnessSize;
use near_primitives::types::AccountId;
use near_primitives_core::checked_feature;
use near_primitives_core::version::PROTOCOL_VERSION;
Expand All @@ -22,7 +23,7 @@ struct OrphanWitnessTestEnv {
env: TestEnv,
block1: Block,
block2: Block,
encoded_witness: EncodedChunkStateWitness,
witness: ChunkStateWitness,
excluded_validator: AccountId,
excluded_validator_idx: usize,
}
Expand Down Expand Up @@ -127,12 +128,12 @@ fn setup_orphan_witness_test() -> OrphanWitnessTestEnv {
// and process it on all validators except for `excluded_validator`.
// The witness isn't processed on `excluded_validator` to give users of
// `setup_orphan_witness_test()` full control over the events.
let mut encoded_witness_opt = None;
let mut witness_opt = None;
let partial_witness_adapter =
env.partial_witness_adapters[env.get_client_index(&block2_chunk_producer)].clone();
while let Some(request) = partial_witness_adapter.pop_distribution_request() {
let DistributeStateWitnessRequest { epoch_id, chunk_header, state_witness } = request;
let (encoded_witness, _) = EncodedChunkStateWitness::encode(&state_witness).unwrap();
let raw_witness_size = borsh_size(&state_witness);
let chunk_validators = env
.client(&block2_chunk_producer)
.epoch_manager
Expand All @@ -149,13 +150,17 @@ fn setup_orphan_witness_test() -> OrphanWitnessTestEnv {
let processing_done_tracker = ProcessingDoneTracker::new();
witness_processing_done_waiters.push(processing_done_tracker.make_waiter());
env.client(&account_id)
.process_chunk_state_witness(encoded_witness.clone(), Some(processing_done_tracker))
.process_chunk_state_witness(
state_witness.clone(),
raw_witness_size,
Some(processing_done_tracker),
)
.unwrap();
}
for waiter in witness_processing_done_waiters {
waiter.wait();
}
encoded_witness_opt = Some(encoded_witness);
witness_opt = Some(state_witness);
}

env.propagate_chunk_endorsements(false);
Expand All @@ -167,11 +172,8 @@ fn setup_orphan_witness_test() -> OrphanWitnessTestEnv {
block2.header().height(),
"There should be no missing chunks."
);
let encoded_witness = encoded_witness_opt.unwrap();
assert_eq!(
encoded_witness.decode().unwrap().0.chunk_header.chunk_hash(),
block2.chunks()[0].chunk_hash()
);
let witness = witness_opt.unwrap();
assert_eq!(witness.chunk_header.chunk_hash(), block2.chunks()[0].chunk_hash());

for client_idx in clients_without_excluded {
let blocks_processed = env.clients[client_idx]
Expand All @@ -189,7 +191,7 @@ fn setup_orphan_witness_test() -> OrphanWitnessTestEnv {
env,
block1,
block2,
encoded_witness,
witness,
excluded_validator,
excluded_validator_idx,
}
Expand All @@ -209,15 +211,18 @@ fn test_orphan_witness_valid() {
mut env,
block1,
block2,
encoded_witness,
witness,
excluded_validator,
excluded_validator_idx,
..
} = setup_orphan_witness_test();

// `excluded_validator` receives witness for chunk belonging to `block2`, but it doesn't have `block1`.
// The witness should become an orphaned witness and it should be saved to the orphan pool.
env.client(&excluded_validator).process_chunk_state_witness(encoded_witness, None).unwrap();
let witness_size = borsh_size(&witness);
env.client(&excluded_validator)
.process_chunk_state_witness(witness, witness_size, None)
.unwrap();

let block_processed = env
.client(&excluded_validator)
Expand All @@ -240,10 +245,9 @@ fn test_orphan_witness_too_large() {
return;
}

let OrphanWitnessTestEnv { mut env, encoded_witness, excluded_validator, .. } =
let OrphanWitnessTestEnv { mut env, witness, excluded_validator, .. } =
setup_orphan_witness_test();

let witness = encoded_witness.decode().unwrap().0;
// The witness should not be saved to the pool, as it's too big
let outcome = env
.client(&excluded_validator)
Expand All @@ -265,17 +269,16 @@ fn test_orphan_witness_far_from_head() {
return;
}

let OrphanWitnessTestEnv { mut env, mut encoded_witness, block1, excluded_validator, .. } =
let OrphanWitnessTestEnv { mut env, mut witness, block1, excluded_validator, .. } =
setup_orphan_witness_test();

let bad_height = 10000;
modify_witness_header_inner(&mut encoded_witness, |header| match &mut header.inner {
modify_witness_header_inner(&mut witness, |header| match &mut header.inner {
ShardChunkHeaderInner::V1(inner) => inner.height_created = bad_height,
ShardChunkHeaderInner::V2(inner) => inner.height_created = bad_height,
ShardChunkHeaderInner::V3(inner) => inner.height_created = bad_height,
});

let witness = encoded_witness.decode().unwrap().0;
let outcome =
env.client(&excluded_validator).handle_orphan_state_witness(witness, 2000).unwrap();
assert_eq!(
Expand All @@ -299,10 +302,9 @@ fn test_orphan_witness_not_fully_validated() {
return;
}

let OrphanWitnessTestEnv { mut env, mut encoded_witness, excluded_validator, .. } =
let OrphanWitnessTestEnv { mut env, mut witness, excluded_validator, .. } =
setup_orphan_witness_test();

let mut witness = encoded_witness.decode().unwrap().0;
// Make the witness invalid in a way that won't be detected during orphan witness validation
witness.source_receipt_proofs.insert(
ChunkHash::default(),
Expand All @@ -311,25 +313,28 @@ fn test_orphan_witness_not_fully_validated() {
ShardProof { from_shard_id: 100230230, to_shard_id: 383939, proof: vec![] },
),
);
encoded_witness = EncodedChunkStateWitness::encode(&witness).unwrap().0;

// The witness should be accepted and saved into the pool, even though it's invalid.
// There is no way to fully validate an orphan witness, so this is the correct behavior.
// The witness will later be fully validated when the required block arrives.
env.client(&excluded_validator).process_chunk_state_witness(encoded_witness, None).unwrap();
let witness_size = borsh_size(&witness);
env.client(&excluded_validator)
.process_chunk_state_witness(witness, witness_size, None)
.unwrap();
}

fn modify_witness_header_inner(
encoded_witness: &mut EncodedChunkStateWitness,
witness: &mut ChunkStateWitness,
f: impl FnOnce(&mut ShardChunkHeaderV3),
) {
let mut witness = encoded_witness.decode().unwrap().0;

match &mut witness.chunk_header {
ShardChunkHeader::V3(header) => {
f(header);
}
_ => panic!(),
_ => unreachable!(),
};
*encoded_witness = EncodedChunkStateWitness::encode(&witness).unwrap().0;
}

/// Test helper: returns the length in bytes of the borsh serialization of `witness`.
///
/// Panics (via `unwrap`) if borsh serialization fails.
fn borsh_size(witness: &ChunkStateWitness) -> ChunkStateWitnessSize {
borsh::to_vec(&witness).unwrap().len()
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use near_client::{ProcessTxResponse, ProduceChunkResult};
use near_epoch_manager::{EpochManager, EpochManagerAdapter};
use near_primitives::account::id::AccountIdRef;
use near_primitives::stateless_validation::{ChunkStateWitness, EncodedChunkStateWitness};
use near_primitives::stateless_validation::ChunkStateWitness;
use near_store::test_utils::create_test_store;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
Expand Down Expand Up @@ -339,11 +339,11 @@ fn test_chunk_state_witness_bad_shard_id() {
let previous_block = env.clients[0].chain.head().unwrap().prev_block_hash;
let invalid_shard_id = 1000000000;
let witness = ChunkStateWitness::new_dummy(upper_height, invalid_shard_id, previous_block);
let encoded_witness = EncodedChunkStateWitness::encode(&witness).unwrap().0;
let witness_size = borsh::to_vec(&witness).unwrap().len();

// Client should reject this ChunkStateWitness and the error message should mention "shard"
tracing::info!(target: "test", "Processing invalid ChunkStateWitness");
let res = env.clients[0].process_chunk_state_witness(encoded_witness, None);
let res = env.clients[0].process_chunk_state_witness(witness, witness_size, None);
let error = res.unwrap_err();
let error_message = format!("{}", error).to_lowercase();
tracing::info!(target: "test", "error message: {}", error_message);
Expand Down

0 comments on commit 96af8f7

Please sign in to comment.