From 475a1413e410581a4154164b93dcb1e3ea88a3ca Mon Sep 17 00:00:00 2001 From: Tadeo Hepperle <62739623+tadeohepperle@users.noreply.github.com> Date: Wed, 19 Jul 2023 19:49:08 +0200 Subject: [PATCH] Prevent bug when reusing type ids in hashing (#1075) * practice TDD * implement a hashmap 2-phases approach * use nicer types * add test for cache filling * adjust test --------- Co-authored-by: James Wilson --- metadata/src/lib.rs | 4 +- metadata/src/utils/validation.rs | 196 +++++++++++++++++++++---------- 2 files changed, 139 insertions(+), 61 deletions(-) diff --git a/metadata/src/lib.rs b/metadata/src/lib.rs index 8403b33e00..988bba1db9 100644 --- a/metadata/src/lib.rs +++ b/metadata/src/lib.rs @@ -20,7 +20,7 @@ mod from_into; mod utils; use scale_info::{form::PortableForm, PortableRegistry, Variant}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::Arc; use utils::ordered_map::OrderedMap; use utils::variant_index::VariantIndex; @@ -152,7 +152,7 @@ impl Metadata { Some(crate::utils::validation::get_type_hash( &self.types, id, - &mut HashSet::<u32>::new(), + &mut HashMap::new(), )) } } diff --git a/metadata/src/utils/validation.rs b/metadata/src/utils/validation.rs index 0ce0caf55d..4847722b77 100644 --- a/metadata/src/utils/validation.rs +++ b/metadata/src/utils/validation.rs @@ -9,10 +9,7 @@ use crate::{ RuntimeApiMethodMetadata, StorageEntryMetadata, StorageEntryType, }; use scale_info::{form::PortableForm, Field, PortableRegistry, TypeDef, TypeDefVariant, Variant}; -use std::collections::HashSet; - -/// Predefined value to be returned when we already visited a type. -const MAGIC_RECURSIVE_TYPE_VALUE: &[u8] = &[123]; +use std::collections::HashMap; // The number of bytes our `hash` function produces.
const HASH_LEN: usize = 32; @@ -75,7 +72,7 @@ concat_and_hash_n!(concat_and_hash5(a b c d e)); fn get_field_hash( registry: &PortableRegistry, field: &Field<PortableForm>, - visited_ids: &mut HashSet<u32>, + cache: &mut HashMap<u32, CachedHash>, ) -> [u8; HASH_LEN] { let field_name_bytes = match &field.name { Some(name) => hash(name.as_bytes()), @@ -84,7 +81,7 @@ fn get_field_hash( concat_and_hash2( &field_name_bytes, - &get_type_hash(registry, field.ty.id, visited_ids), + &get_type_hash(registry, field.ty.id, cache), ) } @@ -92,13 +89,13 @@ fn get_field_hash( fn get_variant_hash( registry: &PortableRegistry, var: &Variant<PortableForm>, - visited_ids: &mut HashSet<u32>, + cache: &mut HashMap<u32, CachedHash>, ) -> [u8; HASH_LEN] { let variant_name_bytes = hash(var.name.as_bytes()); let variant_field_bytes = var.fields.iter().fold([0u8; HASH_LEN], |bytes, field| { // EncodeAsType and DecodeAsType don't care about variant field ordering, // so XOR the fields to ensure that it doesn't matter. - xor(bytes, get_field_hash(registry, field, visited_ids)) + xor(bytes, get_field_hash(registry, field, cache)) }); concat_and_hash2(&variant_name_bytes, &variant_field_bytes) @@ -108,7 +105,7 @@ fn get_type_def_variant_hash( registry: &PortableRegistry, variant: &TypeDefVariant<PortableForm>, only_these_variants: Option<&[&str]>, - visited_ids: &mut HashSet<u32>, + cache: &mut HashMap<u32, CachedHash>, ) -> [u8; HASH_LEN] { let variant_id_bytes = [TypeBeingHashed::Variant as u8; HASH_LEN]; let variant_field_bytes = variant.variants.iter().fold([0u8; HASH_LEN], |bytes, var| { @@ -120,7 +117,7 @@ fn get_type_def_variant_hash( .unwrap_or(true); if should_hash { - xor(bytes, get_variant_hash(registry, var, visited_ids)) + xor(bytes, get_variant_hash(registry, var, cache)) } else { bytes } @@ -132,7 +129,7 @@ fn get_type_def_variant_hash( fn get_type_def_hash( registry: &PortableRegistry, ty_def: &TypeDef<PortableForm>, - visited_ids: &mut HashSet<u32>, + cache: &mut HashMap<u32, CachedHash>, ) -> [u8; HASH_LEN] { match ty_def { TypeDef::Composite(composite) => { @@ -144,16 +141,14 @@ fn get_type_def_hash( .fold([0u8;
HASH_LEN], |bytes, field| { // With EncodeAsType and DecodeAsType we no longer care which order the fields are in, // as long as all of the names+types are there. XOR to not care about ordering. - xor(bytes, get_field_hash(registry, field, visited_ids)) + xor(bytes, get_field_hash(registry, field, cache)) }); concat_and_hash2(&composite_id_bytes, &composite_field_bytes) } - TypeDef::Variant(variant) => { - get_type_def_variant_hash(registry, variant, None, visited_ids) - } + TypeDef::Variant(variant) => get_type_def_variant_hash(registry, variant, None, cache), TypeDef::Sequence(sequence) => concat_and_hash2( &[TypeBeingHashed::Sequence as u8; HASH_LEN], - &get_type_hash(registry, sequence.type_param.id, visited_ids), + &get_type_hash(registry, sequence.type_param.id, cache), ), TypeDef::Array(array) => { // Take length into account too; different length must lead to different hash. @@ -165,13 +160,13 @@ fn get_type_def_hash( }; concat_and_hash2( &array_id_bytes, - &get_type_hash(registry, array.type_param.id, visited_ids), + &get_type_hash(registry, array.type_param.id, cache), ) } TypeDef::Tuple(tuple) => { let mut bytes = hash(&[TypeBeingHashed::Tuple as u8]); for field in &tuple.fields { - bytes = concat_and_hash2(&bytes, &get_type_hash(registry, field.id, visited_ids)); + bytes = concat_and_hash2(&bytes, &get_type_hash(registry, field.id, cache)); } bytes } @@ -181,31 +176,64 @@ fn get_type_def_hash( } TypeDef::Compact(compact) => concat_and_hash2( &[TypeBeingHashed::Compact as u8; HASH_LEN], - &get_type_hash(registry, compact.type_param.id, visited_ids), + &get_type_hash(registry, compact.type_param.id, cache), ), TypeDef::BitSequence(bitseq) => concat_and_hash3( &[TypeBeingHashed::BitSequence as u8; HASH_LEN], - &get_type_hash(registry, bitseq.bit_order_type.id, visited_ids), - &get_type_hash(registry, bitseq.bit_store_type.id, visited_ids), + &get_type_hash(registry, bitseq.bit_order_type.id, cache), + &get_type_hash(registry, bitseq.bit_store_type.id, 
cache), ), } } +/// indicates whether a hash has been fully computed for a type or not +#[derive(Clone, Debug)] +pub enum CachedHash { + /// hash not known yet, but computation has already started + Recursive, + /// hash of the type, computation was finished + Hash([u8; HASH_LEN]), +} + +impl CachedHash { + fn hash(&self) -> [u8; HASH_LEN] { + match &self { + CachedHash::Hash(hash) => *hash, + CachedHash::Recursive => [123; HASH_LEN], // some magical value + } + } +} + /// Obtain the hash representation of a `scale_info::Type` identified by id. pub fn get_type_hash( registry: &PortableRegistry, id: u32, - visited_ids: &mut HashSet<u32>, + cache: &mut HashMap<u32, CachedHash>, ) -> [u8; HASH_LEN] { - // Guard against recursive types and return a fixed arbitrary hash - if !visited_ids.insert(id) { - return hash(MAGIC_RECURSIVE_TYPE_VALUE); + // Guard against recursive types, with a 2 step caching approach: + // if the cache has an entry for the id, just return a hash derived from it. + // if the type has not been seen yet, mark it with `CachedHash::Recursive` in the cache and proceed to `get_type_def_hash()`. + // -> During the execution of get_type_def_hash() we might get into get_type_hash(id) again for the original id + // -> in this case the `CachedHash::Recursive` provokes an early return. + // -> Once we return from `get_type_def_hash()` we need to update the cache entry: + // -> We set the cache value to `CachedHash::Hash(type_hash)`, where `type_hash` was returned from `get_type_def_hash()` + // -> It makes sure, that different types end up with different cache values. + // + // Values in the cache can be thought of as a mapping like this: + // type_id -> not contained = We haven't seen the type yet. + // -> `CachedHash::Recursive` = We have seen the type but hash calculation for it hasn't finished yet. + // -> `CachedHash::Hash(hash)` = Hash calculation for the type was completed.
+ + if let Some(cached_hash) = cache.get(&id) { + return cached_hash.hash(); } - + cache.insert(id, CachedHash::Recursive); let ty = registry .resolve(id) .expect("Type ID provided by the metadata is registered; qed"); - get_type_def_hash(registry, &ty.type_def, visited_ids) + let type_hash = get_type_def_hash(registry, &ty.type_def, cache); + cache.insert(id, CachedHash::Hash(type_hash)); + type_hash } /// Obtain the hash representation of a `frame_metadata::v15::ExtrinsicMetadata`. @@ -213,13 +241,13 @@ fn get_extrinsic_hash( registry: &PortableRegistry, extrinsic: &ExtrinsicMetadata, ) -> [u8; HASH_LEN] { - let mut visited_ids = HashSet::<u32>::new(); + let mut cache = HashMap::<u32, CachedHash>::new(); // Get the hashes of the extrinsic type. - let address_hash = get_type_hash(registry, extrinsic.address_ty, &mut visited_ids); + let address_hash = get_type_hash(registry, extrinsic.address_ty, &mut cache); // The `RuntimeCall` type is intentionally omitted and hashed by the outer enums instead. - let signature_hash = get_type_hash(registry, extrinsic.signature_ty, &mut visited_ids); - let extra_hash = get_type_hash(registry, extrinsic.extra_ty, &mut visited_ids); + let signature_hash = get_type_hash(registry, extrinsic.signature_ty, &mut cache); + let extra_hash = get_type_hash(registry, extrinsic.extra_ty, &mut cache); let mut bytes = concat_and_hash4( &address_hash, @@ -232,8 +260,8 @@ fn get_extrinsic_hash( bytes = concat_and_hash4( &bytes, &hash(signed_extension.identifier.as_bytes()), - &get_type_hash(registry, signed_extension.extra_ty, &mut visited_ids), - &get_type_hash(registry, signed_extension.additional_ty, &mut visited_ids), + &get_type_hash(registry, signed_extension.extra_ty, &mut cache), + &get_type_hash(registry, signed_extension.additional_ty, &mut cache), ) } @@ -258,9 +286,9 @@ fn get_outer_enums_hash( .expect("Metadata should contain enum type in registry"); if let TypeDef::Variant(variant) = &ty.ty.type_def { - get_type_def_variant_hash(registry, variant,
only_these_variants, &mut HashSet::new()) + get_type_def_variant_hash(registry, variant, only_these_variants, &mut HashMap::new()) } else { - get_type_hash(registry, id, &mut HashSet::new()) + get_type_hash(registry, id, &mut HashMap::new()) } } @@ -277,7 +305,7 @@ fn get_outer_enums_hash( fn get_storage_entry_hash( registry: &PortableRegistry, entry: &StorageEntryMetadata, - visited_ids: &mut HashSet<u32>, + cache: &mut HashMap<u32, CachedHash>, ) -> [u8; HASH_LEN] { let mut bytes = concat_and_hash3( &hash(entry.name.as_bytes()), @@ -288,7 +316,7 @@ fn get_storage_entry_hash( match &entry.entry_type { StorageEntryType::Plain(ty) => { - concat_and_hash2(&bytes, &get_type_hash(registry, *ty, visited_ids)) + concat_and_hash2(&bytes, &get_type_hash(registry, *ty, cache)) } StorageEntryType::Map { hashers, @@ -301,8 +329,8 @@ fn get_storage_entry_hash( } concat_and_hash3( &bytes, - &get_type_hash(registry, *key_ty, visited_ids), - &get_type_hash(registry, *value_ty, visited_ids), + &get_type_hash(registry, *key_ty, cache), + &get_type_hash(registry, *value_ty, cache), ) } } @@ -313,7 +341,7 @@ fn get_runtime_method_hash( registry: &PortableRegistry, trait_name: &str, method_metadata: &RuntimeApiMethodMetadata, - visited_ids: &mut HashSet<u32>, + cache: &mut HashMap<u32, CachedHash>, ) -> [u8; HASH_LEN] { // The trait name is part of the runtime API call that is being // generated for this method. Therefore the trait name is strongly @@ -328,13 +356,13 @@ fn get_runtime_method_hash( bytes = concat_and_hash3( &bytes, &hash(input.name.as_bytes()), - &get_type_hash(registry, input.ty, visited_ids), + &get_type_hash(registry, input.ty, cache), ); } bytes = concat_and_hash2( &bytes, - &get_type_hash(registry, method_metadata.output_ty, visited_ids), + &get_type_hash(registry, method_metadata.output_ty, cache), ); bytes @@ -342,7 +370,7 @@ fn get_runtime_method_hash( /// Obtain the hash of all of a runtime API trait, including all of its methods.
pub fn get_runtime_trait_hash(trait_metadata: RuntimeApiMetadata) -> [u8; HASH_LEN] { - let mut visited_ids = HashSet::new(); + let mut cache = HashMap::new(); let trait_name = &*trait_metadata.inner.name; let method_bytes = trait_metadata .methods() @@ -357,7 +385,7 @@ pub fn get_runtime_trait_hash(trait_metadata: RuntimeApiMetadata) -> [u8; HASH_L trait_metadata.types, trait_name, method_metadata, - &mut visited_ids, + &mut cache, ), ) }); @@ -370,7 +398,7 @@ pub fn get_storage_hash(pallet: &PalletMetadata, entry_name: &str) -> Option<[u8 let storage = pallet.storage()?; let entry = storage.entry_by_name(entry_name)?; - let hash = get_storage_entry_hash(pallet.types, entry, &mut HashSet::new()); + let hash = get_storage_entry_hash(pallet.types, entry, &mut HashMap::new()); Some(hash) } @@ -379,7 +407,7 @@ pub fn get_constant_hash(pallet: &PalletMetadata, constant_name: &str) -> Option let constant = pallet.constant_by_name(constant_name)?; // We only need to check that the type of the constant asked for matches. - let bytes = get_type_hash(pallet.types, constant.ty, &mut HashSet::new()); + let bytes = get_type_hash(pallet.types, constant.ty, &mut HashMap::new()); Some(bytes) } @@ -388,7 +416,7 @@ pub fn get_call_hash(pallet: &PalletMetadata, call_name: &str) -> Option<[u8; HA let call_variant = pallet.call_variant_by_name(call_name)?; // hash the specific variant representing the call we are interested in. - let hash = get_variant_hash(pallet.types, call_variant, &mut HashSet::new()); + let hash = get_variant_hash(pallet.types, call_variant, &mut HashMap::new()); Some(hash) } @@ -404,25 +432,25 @@ pub fn get_runtime_api_hash( runtime_apis.types, trait_name, method_metadata, - &mut HashSet::new(), + &mut HashMap::new(), )) } /// Obtain the hash representation of a `frame_metadata::v15::PalletMetadata`. 
pub fn get_pallet_hash(pallet: PalletMetadata) -> [u8; HASH_LEN] { - let mut visited_ids = HashSet::<u32>::new(); + let mut cache = HashMap::<u32, CachedHash>::new(); let registry = pallet.types; let call_bytes = match pallet.call_ty_id() { - Some(calls) => get_type_hash(registry, calls, &mut visited_ids), + Some(calls) => get_type_hash(registry, calls, &mut cache), None => [0u8; HASH_LEN], }; let event_bytes = match pallet.event_ty_id() { - Some(event) => get_type_hash(registry, event, &mut visited_ids), + Some(event) => get_type_hash(registry, event, &mut cache), None => [0u8; HASH_LEN], }; let error_bytes = match pallet.error_ty_id() { - Some(error) => get_type_hash(registry, error, &mut visited_ids), + Some(error) => get_type_hash(registry, error, &mut cache), None => [0u8; HASH_LEN], }; let constant_bytes = pallet.constants().fold([0u8; HASH_LEN], |bytes, constant| { @@ -430,7 +458,7 @@ pub fn get_pallet_hash(pallet: PalletMetadata) -> [u8; HASH_LEN] { // of (constantName, constantType) to make the order we see them irrelevant. let constant_hash = concat_and_hash2( &hash(constant.name.as_bytes()), - &get_type_hash(registry, constant.ty(), &mut visited_ids), + &get_type_hash(registry, constant.ty(), &mut cache), ); xor(bytes, constant_hash) }); @@ -443,10 +471,7 @@ pub fn get_pallet_hash(pallet: PalletMetadata) -> [u8; HASH_LEN] { .fold([0u8; HASH_LEN], |bytes, entry| { // We don't care what order the storage entries occur in, so XOR them together // to make the order irrelevant.
- xor( - bytes, - get_storage_entry_hash(registry, entry, &mut visited_ids), - ) + xor(bytes, get_storage_entry_hash(registry, entry, &mut cache)) }); concat_and_hash2(&prefix_hash, &entries_hash) } @@ -537,7 +562,7 @@ impl<'a> MetadataHasher<'a> { let extrinsic_hash = get_extrinsic_hash(&metadata.types, &metadata.extrinsic); let runtime_hash = - get_type_hash(&metadata.types, metadata.runtime_ty(), &mut HashSet::new()); + get_type_hash(&metadata.types, metadata.runtime_ty(), &mut HashMap::new()); let outer_enums_hash = get_outer_enums_hash( &metadata.types, &metadata.outer_enums(), @@ -559,7 +584,7 @@ mod tests { use super::*; use bitvec::{order::Lsb0, vec::BitVec}; use frame_metadata::v15; - use scale_info::meta_type; + use scale_info::{meta_type, Registry}; // Define recursive types. #[allow(dead_code)] @@ -743,6 +768,59 @@ mod tests { assert_eq!(hash, hash_swap); } + #[allow(dead_code)] + #[derive(scale_info::TypeInfo)] + struct Aba { + ab: (A, B), + other: A, + } + + #[allow(dead_code)] + #[derive(scale_info::TypeInfo)] + struct Abb { + ab: (A, B), + other: B, + } + + #[test] + /// Ensure ABB and ABA have a different structure: + fn do_not_reuse_visited_type_ids() { + let metadata_hash_with_type = |ty| { + let mut pallets = build_default_pallets(); + pallets[0].calls = Some(v15::PalletCallMetadata { ty }); + let metadata = pallets_to_metadata(pallets); + MetadataHasher::new(&metadata).hash() + }; + + let aba_hash = metadata_hash_with_type(meta_type::<Aba>()); + let abb_hash = metadata_hash_with_type(meta_type::<Abb>()); + + assert_ne!(aba_hash, abb_hash); + } + + #[test] + fn hash_cache_gets_filled_with_correct_hashes() { + let mut registry = Registry::new(); + let a_type_id = registry.register_type(&meta_type::<A>()).id; + let b_type_id = registry.register_type(&meta_type::<B>()).id; + let registry: PortableRegistry = registry.into(); + + let mut cache = HashMap::new(); + + let a_hash = get_type_hash(&registry, a_type_id, &mut cache); + let a_hash2 = get_type_hash(&registry,
a_type_id, &mut cache); + let b_hash = get_type_hash(&registry, b_type_id, &mut cache); + + let CachedHash::Hash(a_cache_hash) = cache[&a_type_id] else { panic!() }; + let CachedHash::Hash(b_cache_hash) = cache[&b_type_id] else { panic!() }; + + assert_eq!(a_hash, a_cache_hash); + assert_eq!(b_hash, b_cache_hash); + + assert_eq!(a_hash, a_hash2); + assert_ne!(a_hash, b_hash); + } + #[test] // Redundant clone clippy warning is a lie; https://github.com/rust-lang/rust-clippy/issues/10870 #[allow(clippy::redundant_clone)]