From b243fb458f13c6291544373ba4ca22bd03b5e6a5 Mon Sep 17 00:00:00 2001 From: Moritz Date: Thu, 12 Oct 2023 18:25:17 +0200 Subject: [PATCH] feat: implement chunk validator assignment --- Cargo.lock | 1 + chain/chain/src/test_utils/kv_runtime.rs | 1 + chain/epoch-manager/src/proposals.rs | 5 + chain/epoch-manager/src/test_utils.rs | 12 +- .../epoch-manager/src/validator_selection.rs | 59 +++++ core/primitives/Cargo.toml | 1 + core/primitives/src/epoch_manager.rs | 124 ++++++++- core/primitives/src/lib.rs | 1 + core/primitives/src/types.rs | 29 +++ core/primitives/src/validator_mandates.rs | 239 ++++++++++++++++++ 10 files changed, 466 insertions(+), 6 deletions(-) create mode 100644 core/primitives/src/validator_mandates.rs diff --git a/Cargo.lock b/Cargo.lock index 14e347e3f9b..18201742286 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4106,6 +4106,7 @@ dependencies = [ "once_cell", "primitive-types", "rand 0.8.5", + "rand_chacha 0.3.1", "reed-solomon-erasure", "serde", "serde_json", diff --git a/chain/chain/src/test_utils/kv_runtime.rs b/chain/chain/src/test_utils/kv_runtime.rs index 844e7796eed..3a39e960fb0 100644 --- a/chain/chain/src/test_utils/kv_runtime.rs +++ b/chain/chain/src/test_utils/kv_runtime.rs @@ -523,6 +523,7 @@ impl EpochManagerAdapter for MockEpochManager { 1, 1, RngSeed::default(), + Default::default(), ))) } diff --git a/chain/epoch-manager/src/proposals.rs b/chain/epoch-manager/src/proposals.rs index 17e6c792dc9..1120d8cc7cc 100644 --- a/chain/epoch-manager/src/proposals.rs +++ b/chain/epoch-manager/src/proposals.rs @@ -85,6 +85,7 @@ mod old_validator_selection { use near_primitives::types::{ AccountId, Balance, NumSeats, ValidatorId, ValidatorKickoutReason, }; + use near_primitives::validator_mandates::ValidatorMandates; use near_primitives::version::ProtocolVersion; use rand::{RngCore, SeedableRng}; use rand_hc::Hc128Rng; @@ -248,6 +249,9 @@ mod old_validator_selection { .map(|(index, s)| (s.account_id().clone(), index as ValidatorId)) .collect::>(); + // Old validator selection is not aware of chunk validator mandates. + let validator_mandates: ValidatorMandates = Default::default(); + Ok(EpochInfo::new( prev_epoch_info.epoch_height() + 1, final_proposals, @@ -264,6 +268,7 @@ mod old_validator_selection { threshold, next_version, rng_seed, + validator_mandates, )) } diff --git a/chain/epoch-manager/src/test_utils.rs b/chain/epoch-manager/src/test_utils.rs index 1fde5aaabc3..eb219831209 100644 --- a/chain/epoch-manager/src/test_utils.rs +++ b/chain/epoch-manager/src/test_utils.rs @@ -20,6 +20,7 @@ use near_primitives::types::{ ValidatorId, ValidatorKickoutReason, }; use near_primitives::utils::get_num_seats_per_shard; +use near_primitives::validator_mandates::{ValidatorMandates, ValidatorMandatesConfig}; use near_primitives::version::PROTOCOL_VERSION; use near_store::test_utils::create_test_store; @@ -104,9 +105,17 @@ pub fn epoch_info_with_num_seats( }) .collect() }; + let all_validators = account_to_validators(accounts); + // TODO determine required stake per mandate instead of reusing seat price. + // TODO determine `min_mandates_per_shard` + let num_shards = chunk_producers_settlement.len(); + let min_mandates_per_shard = 0; + let validator_mandates_config = + ValidatorMandatesConfig::new(seat_price, min_mandates_per_shard, num_shards); + let validator_mandates = ValidatorMandates::new(validator_mandates_config, &all_validators); EpochInfo::new( epoch_height, - account_to_validators(accounts), + all_validators, validator_to_index, block_producers_settlement, chunk_producers_settlement, @@ -120,6 +129,7 @@ pub fn epoch_info_with_num_seats( seat_price, PROTOCOL_VERSION, TEST_SEED, + validator_mandates, ) } diff --git a/chain/epoch-manager/src/validator_selection.rs b/chain/epoch-manager/src/validator_selection.rs index 41499687e22..71813fde82e 100644 --- a/chain/epoch-manager/src/validator_selection.rs +++ b/chain/epoch-manager/src/validator_selection.rs @@ -7,6 +7,7 @@ use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::types::{ AccountId, Balance, ProtocolVersion, ValidatorId, ValidatorKickoutReason, }; +use near_primitives::validator_mandates::{ValidatorMandates, ValidatorMandatesConfig}; use num_rational::Ratio; use std::cmp::{self, Ordering}; use std::collections::hash_map; @@ -96,6 +97,7 @@ pub fn proposals_to_epoch_info( } let num_chunk_producers = chunk_producers.len(); + // Constructing `all_validators` such that a validators position corresponds to its `ValidatorId`. let mut all_validators: Vec = Vec::with_capacity(num_chunk_producers); let mut validator_to_index = HashMap::new(); let mut block_producers_settlement = Vec::with_capacity(block_producers.len()); @@ -170,6 +172,16 @@ pub fn proposals_to_epoch_info( .collect() }; + // We can use `all_validators` to construct mandates Since a validator's position in + // `all_validators` corresponds to its `ValidatorId` + // TODO determine required stake per mandate instead of reusing seat price. + // TODO determine `min_mandates_per_shard` + // TODO pre chunk-validation, just pass empty vec instead of `all_validators` to avoid costs? + let min_mandates_per_shard = 0; + let validator_mandates_config = + ValidatorMandatesConfig::new(threshold, min_mandates_per_shard, num_shards as usize); + let validator_mandates = ValidatorMandates::new(validator_mandates_config, &all_validators); + let fishermen_to_index = fishermen .iter() .enumerate() @@ -192,6 +204,7 @@ pub fn proposals_to_epoch_info( threshold, next_version, rng_seed, + validator_mandates, )) } @@ -619,6 +632,52 @@ mod tests { } } + /// This test only verifies that chunk validator mandates are correctly wired up with + /// `EpochInfo`. The internals of mandate assignment are tested in the module containing + /// [`ValidatorMandates`]. + #[test] + fn test_chunk_validators_sampling() { + // When there is 1 CP per shard, they are chosen 100% of the time. + let num_shards = 4; + let epoch_config = create_epoch_config( + num_shards, + 2 * num_shards, + 0, + ValidatorSelectionConfig { + num_chunk_only_producer_seats: 0, + minimum_validators_per_shard: 1, + minimum_stake_ratio: Ratio::new(160, 1_000_000), + }, + ); + let prev_epoch_height = 7; + let prev_epoch_info = create_prev_epoch_info(prev_epoch_height, &["test1", "test2"], &[]); + let proposals = + create_proposals(&[("test1", 15), ("test2", 9), ("test3", 5), ("test4", 3)]); + + let epoch_info = proposals_to_epoch_info( + &epoch_config, + [0; 32], + &prev_epoch_info, + proposals, + Default::default(), + Default::default(), + 0, + PROTOCOL_VERSION, + PROTOCOL_VERSION, + ) + .unwrap(); + + // Given `epoch_info` and `proposals` above, the sample at a given height is deterministic. + let height = 42; + let expected_assignments: Vec> = vec![ + HashMap::from([(0, 1), (1, 5), (2, 1), (3, 1)]), + HashMap::from([(0, 6), (1, 1), (2, 1)]), + HashMap::from([(0, 5), (1, 2), (2, 1)]), + HashMap::from([(0, 3), (1, 1), (2, 2), (3, 2)]), + ]; + assert_eq!(epoch_info.sample_chunk_validators(height), expected_assignments); + } + #[test] fn test_validator_assignment_ratio_condition() { // There are more seats than proposals, however the diff --git a/core/primitives/Cargo.toml b/core/primitives/Cargo.toml index 4334d4f3447..e9a9fcb6476 100644 --- a/core/primitives/Cargo.toml +++ b/core/primitives/Cargo.toml @@ -24,6 +24,7 @@ num-rational.workspace = true once_cell.workspace = true primitive-types.workspace = true rand.workspace = true +rand_chacha.workspace = true reed-solomon-erasure.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/core/primitives/src/epoch_manager.rs b/core/primitives/src/epoch_manager.rs index 9744e39a1d8..2b5fd9a4d91 100644 --- a/core/primitives/src/epoch_manager.rs +++ b/core/primitives/src/epoch_manager.rs @@ -474,12 +474,15 @@ pub mod epoch_info { use crate::epoch_manager::ValidatorWeight; use crate::types::validator_stake::{ValidatorStake, ValidatorStakeIter}; use crate::types::{BlockChunkValidatorStats, ValidatorKickoutReason}; + use crate::validator_mandates::ValidatorMandates; use crate::version::PROTOCOL_VERSION; use borsh::{BorshDeserialize, BorshSerialize}; use near_primitives_core::hash::CryptoHash; use near_primitives_core::types::{ AccountId, Balance, EpochHeight, ProtocolVersion, ValidatorId, }; + use rand::SeedableRng; + use rand_chacha::ChaCha20Rng; use smart_default::SmartDefault; use std::collections::{BTreeMap, HashMap}; @@ -499,6 +502,7 @@ pub mod epoch_info { V1(EpochInfoV1), V2(EpochInfoV2), V3(EpochInfoV3), + V4(EpochInfoV4), } impl Default for EpochInfo { @@ -585,6 +589,41 @@ pub mod epoch_info { chunk_producers_sampler: Vec, } + // V3 -> V4: Add structures and methods for stateless validator assignment. + #[derive( + SmartDefault, + BorshSerialize, + BorshDeserialize, + Clone, + Debug, + PartialEq, + Eq, + serde::Serialize, + )] + pub struct EpochInfoV4 { + pub epoch_height: EpochHeight, + pub validators: Vec, + pub validator_to_index: HashMap, + pub block_producers_settlement: Vec, + pub chunk_producers_settlement: Vec>, + pub hidden_validators_settlement: Vec, + pub fishermen: Vec, + pub fishermen_to_index: HashMap, + pub stake_change: BTreeMap, + pub validator_reward: HashMap, + pub validator_kickout: HashMap, + pub minted_amount: Balance, + pub seat_price: Balance, + #[default(PROTOCOL_VERSION)] + pub protocol_version: ProtocolVersion, + // stuff for selecting validators at each height + rng_seed: RngSeed, + block_producers_sampler: WeightedIndex, + chunk_producers_sampler: Vec, + /// Contains the epoch's validator mandates. Used to sample chunk validators. + validator_mandates: ValidatorMandates, + } + impl EpochInfo { pub fn new( epoch_height: EpochHeight, @@ -602,6 +641,7 @@ pub mod epoch_info { seat_price: Balance, protocol_version: ProtocolVersion, rng_seed: RngSeed, + validator_mandates: ValidatorMandates, ) -> Self { if checked_feature!("stable", AliasValidatorSelectionAlgorithm, protocol_version) { let stake_weights = |ids: &[ValidatorId]| -> WeightedIndex { @@ -615,7 +655,7 @@ pub mod epoch_info { let block_producers_sampler = stake_weights(&block_producers_settlement); let chunk_producers_sampler = chunk_producers_settlement.iter().map(|vs| stake_weights(vs)).collect(); - Self::V3(EpochInfoV3 { + Self::V4(EpochInfoV4 { epoch_height, validators, fishermen, @@ -633,6 +673,7 @@ pub mod epoch_info { rng_seed, block_producers_sampler, chunk_producers_sampler, + validator_mandates, }) } else { Self::V2(EpochInfoV2 { @@ -694,6 +735,7 @@ pub mod epoch_info { Self::V1(v1) => &mut v1.epoch_height, Self::V2(v2) => &mut v2.epoch_height, Self::V3(v3) => &mut v3.epoch_height, + Self::V4(v4) => &mut v4.epoch_height, } } @@ -703,6 +745,7 @@ pub mod epoch_info { Self::V1(v1) => v1.epoch_height, Self::V2(v2) => v2.epoch_height, Self::V3(v3) => v3.epoch_height, + Self::V4(v4) => v4.epoch_height, } } @@ -712,6 +755,7 @@ pub mod epoch_info { Self::V1(v1) => v1.seat_price, Self::V2(v2) => v2.seat_price, Self::V3(v3) => v3.seat_price, + Self::V4(v4) => v4.seat_price, } } @@ -721,6 +765,7 @@ pub mod epoch_info { Self::V1(v1) => v1.minted_amount, Self::V2(v2) => v2.minted_amount, Self::V3(v3) => v3.minted_amount, + Self::V4(v4) => v4.minted_amount, } } @@ -730,6 +775,7 @@ pub mod epoch_info { Self::V1(v1) => &v1.block_producers_settlement, Self::V2(v2) => &v2.block_producers_settlement, Self::V3(v3) => &v3.block_producers_settlement, + Self::V4(v4) => &v4.block_producers_settlement, } } @@ -739,6 +785,7 @@ pub mod epoch_info { Self::V1(v1) => &v1.chunk_producers_settlement, Self::V2(v2) => &v2.chunk_producers_settlement, Self::V3(v3) => &v3.chunk_producers_settlement, + Self::V4(v4) => &v4.chunk_producers_settlement, } } @@ -748,6 +795,7 @@ pub mod epoch_info { Self::V1(v1) => &v1.validator_kickout, Self::V2(v2) => &v2.validator_kickout, Self::V3(v3) => &v3.validator_kickout, + Self::V4(v4) => &v4.validator_kickout, } } @@ -757,6 +805,7 @@ pub mod epoch_info { Self::V1(v1) => v1.protocol_version, Self::V2(v2) => v2.protocol_version, Self::V3(v3) => v3.protocol_version, + Self::V4(v4) => v4.protocol_version, } } @@ -766,6 +815,7 @@ pub mod epoch_info { Self::V1(v1) => &v1.stake_change, Self::V2(v2) => &v2.stake_change, Self::V3(v3) => &v3.stake_change, + Self::V4(v4) => &v4.stake_change, } } @@ -775,6 +825,7 @@ pub mod epoch_info { Self::V1(v1) => &v1.validator_reward, Self::V2(v2) => &v2.validator_reward, Self::V3(v3) => &v3.validator_reward, + Self::V4(v4) => &v4.validator_reward, } } @@ -784,6 +835,7 @@ pub mod epoch_info { Self::V1(v1) => ValidatorStakeIter::v1(&v1.validators), Self::V2(v2) => ValidatorStakeIter::new(&v2.validators), Self::V3(v3) => ValidatorStakeIter::new(&v3.validators), + Self::V4(v4) => ValidatorStakeIter::new(&v4.validators), } } @@ -793,6 +845,7 @@ pub mod epoch_info { Self::V1(v1) => ValidatorStakeIter::v1(&v1.fishermen), Self::V2(v2) => ValidatorStakeIter::new(&v2.fishermen), Self::V3(v3) => ValidatorStakeIter::new(&v3.fishermen), + Self::V4(v4) => ValidatorStakeIter::new(&v4.fishermen), } } @@ -802,6 +855,7 @@ pub mod epoch_info { Self::V1(v1) => v1.validators[validator_id as usize].stake, Self::V2(v2) => v2.validators[validator_id as usize].stake(), Self::V3(v3) => v3.validators[validator_id as usize].stake(), + Self::V4(v4) => v4.validators[validator_id as usize].stake(), } } @@ -811,6 +865,7 @@ pub mod epoch_info { Self::V1(v1) => &v1.validators[validator_id as usize].account_id, Self::V2(v2) => v2.validators[validator_id as usize].account_id(), Self::V3(v3) => v3.validators[validator_id as usize].account_id(), + Self::V4(v4) => v4.validators[validator_id as usize].account_id(), } } @@ -820,6 +875,7 @@ pub mod epoch_info { Self::V1(v1) => v1.validator_to_index.contains_key(account_id), Self::V2(v2) => v2.validator_to_index.contains_key(account_id), Self::V3(v3) => v3.validator_to_index.contains_key(account_id), + Self::V4(v4) => v4.validator_to_index.contains_key(account_id), } } @@ -828,6 +884,7 @@ pub mod epoch_info { Self::V1(v1) => v1.validator_to_index.get(account_id), Self::V2(v2) => v2.validator_to_index.get(account_id), Self::V3(v3) => v3.validator_to_index.get(account_id), + Self::V4(v4) => v4.validator_to_index.get(account_id), } } @@ -844,6 +901,10 @@ pub mod epoch_info { .validator_to_index .get(account_id) .map(|validator_id| v3.validators[*validator_id as usize].clone()), + Self::V4(v4) => v4 + .validator_to_index + .get(account_id) + .map(|validator_id| v4.validators[*validator_id as usize].clone()), } } @@ -853,6 +914,7 @@ pub mod epoch_info { Self::V1(v1) => ValidatorStake::V1(v1.validators[validator_id as usize].clone()), Self::V2(v2) => v2.validators[validator_id as usize].clone(), Self::V3(v3) => v3.validators[validator_id as usize].clone(), + Self::V4(v4) => v4.validators[validator_id as usize].clone(), } } @@ -862,6 +924,7 @@ pub mod epoch_info { Self::V1(v1) => v1.fishermen_to_index.contains_key(account_id), Self::V2(v2) => v2.fishermen_to_index.contains_key(account_id), Self::V3(v3) => v3.fishermen_to_index.contains_key(account_id), + Self::V4(v4) => v4.fishermen_to_index.contains_key(account_id), } } @@ -878,6 +941,10 @@ pub mod epoch_info { .fishermen_to_index .get(account_id) .map(|validator_id| v3.fishermen[*validator_id as usize].clone()), + Self::V4(v4) => v4 + .fishermen_to_index + .get(account_id) + .map(|validator_id| v4.fishermen[*validator_id as usize].clone()), } } @@ -887,6 +954,7 @@ pub mod epoch_info { Self::V1(v1) => ValidatorStake::V1(v1.fishermen[fisherman_id as usize].clone()), Self::V2(v2) => v2.fishermen[fisherman_id as usize].clone(), Self::V3(v3) => v3.fishermen[fisherman_id as usize].clone(), + Self::V4(v4) => v4.fishermen[fisherman_id as usize].clone(), } } @@ -896,6 +964,7 @@ pub mod epoch_info { Self::V1(v1) => v1.validators.len(), Self::V2(v2) => v2.validators.len(), Self::V3(v3) => v3.validators.len(), + Self::V4(v4) => v4.validators.len(), } } @@ -913,6 +982,10 @@ pub mod epoch_info { let seed = Self::block_produce_seed(height, &v3.rng_seed); v3.block_producers_settlement[v3.block_producers_sampler.sample(seed)] } + Self::V4(v4) => { + let seed = Self::block_produce_seed(height, &v4.rng_seed); + v4.block_producers_settlement[v4.block_producers_sampler.sample(seed)] + } } } @@ -930,11 +1003,36 @@ pub mod epoch_info { } Self::V3(v3) => { let protocol_version = self.protocol_version(); - let seed = Self::chunk_produce_seed(protocol_version, v3, height, shard_id); + let seed = + Self::chunk_produce_seed(protocol_version, &v3.rng_seed, height, shard_id); let shard_id = shard_id as usize; let sample = v3.chunk_producers_sampler[shard_id].sample(seed); v3.chunk_producers_settlement[shard_id][sample] } + Self::V4(v4) => { + let protocol_version = self.protocol_version(); + let seed = + Self::chunk_produce_seed(protocol_version, &v4.rng_seed, height, shard_id); + let shard_id = shard_id as usize; + let sample = v4.chunk_producers_sampler[shard_id].sample(seed); + v4.chunk_producers_settlement[shard_id][sample] + } + } + } + + pub fn sample_chunk_validators( + &self, + height: BlockHeight, + ) -> Vec> { + // Chunk validator assignment was introduced with `V4`. + match &self { + Self::V1(_v1) => Default::default(), + Self::V2(_v2) => Default::default(), + Self::V3(_v4) => Default::default(), + Self::V4(v4) => { + let mut rng = Self::chunk_validate_rng(&v4.rng_seed, height); + v4.validator_mandates.sample(&mut rng) + } } } @@ -948,7 +1046,7 @@ pub mod epoch_info { fn chunk_produce_seed( protocol_version: ProtocolVersion, - epoch_info_v3: &EpochInfoV3, + seed: &RngSeed, height: BlockHeight, shard_id: ShardId, ) -> [u8; 32] { @@ -959,16 +1057,32 @@ pub mod epoch_info { // producer. This seed does not contain the shard id // so all shards will be produced by the same // validator. - Self::block_produce_seed(height, &epoch_info_v3.rng_seed) + Self::block_produce_seed(height, seed) } else { // 32 bytes from epoch_seed, 8 bytes from height, 8 bytes from shard_id let mut buffer = [0u8; 48]; - buffer[0..32].copy_from_slice(&epoch_info_v3.rng_seed); + buffer[0..32].copy_from_slice(seed); buffer[32..40].copy_from_slice(&height.to_le_bytes()); buffer[40..48].copy_from_slice(&shard_id.to_le_bytes()); hash(&buffer).0 } } + + /// Returns a new RNG obtained from combining the provided `seed` and `height`. + /// + /// The returned RNG can be used to shuffle slices via [`rand::seq::SliceRandom`]. + fn chunk_validate_rng(seed: &RngSeed, height: BlockHeight) -> ChaCha20Rng { + let mut buffer = [0u8; 40]; + buffer[0..32].copy_from_slice(seed); + buffer[32..40].copy_from_slice(&height.to_le_bytes()); + + // The recommended seed for cryptographic RNG's is `[u8; 32]` and some required traits + // are not implemented for larger seeds, see + // https://docs.rs/rand_core/0.6.2/rand_core/trait.SeedableRng.html#associated-types + // Therefore `buffer` is hashed to obtain a `[u8; 32]`. + let seed = hash(&buffer); + SeedableRng::from_seed(seed.0) + } } #[derive(BorshSerialize, BorshDeserialize)] diff --git a/core/primitives/src/lib.rs b/core/primitives/src/lib.rs index c5645c14e6c..315173ad97e 100644 --- a/core/primitives/src/lib.rs +++ b/core/primitives/src/lib.rs @@ -34,6 +34,7 @@ pub mod trie_key; pub mod types; mod upgrade_schedule; pub mod utils; +pub mod validator_mandates; pub mod validator_signer; pub mod version; pub mod views; diff --git a/core/primitives/src/types.rs b/core/primitives/src/types.rs index 1343c5aa51e..fd3898486c8 100644 --- a/core/primitives/src/types.rs +++ b/core/primitives/src/types.rs @@ -651,6 +651,35 @@ pub mod validator_stake { stake_next_epoch: if is_next_epoch { self.stake() } else { 0 }, } } + + // TODO add unit tests + // TODO if `ValidatorStake` is recalculated every epoch, this should be a field of (new) `ValidatorStakeV2`? + /// Returns the validator's number of mandates (rounded down) at `stake_per_seat`. + /// + /// It returns `u16` since it allows infallible conversion to `usize` and with [`u16::MAX`] + /// equalling 65_535 it should be sufficient to hold the number of mandates per validator. + /// + /// # Why `u16` should be sufficient + /// + /// As of October 2023, a [recommended lower bound] for the stake required per mandate is + /// 25k $NEAR. At this price, the validator with highest stake would have 1_888 mandates, + /// which is well below `u16::MAX`. + /// + /// From another point of view, with more than `u16::MAX` mandates for validators, sampling + /// mandates might become computationally too expensive. This might trigger an increase in + /// the required stake per mandate, bringing down the number of mandates per validator. + /// + /// [recommended lower bound]: https://near.zulipchat.com/#narrow/stream/407237-pagoda.2Fcore.2Fstateless-validation/topic/validator.20seat.20assignment/near/393792901 + /// + /// # Panics + /// + /// Panics if the number of mandates overflows `u16`. + pub fn num_mandates(&self, stake_per_mandate: Balance) -> u16 { + // Integer division in Rust returns the floor as described here + // https://doc.rust-lang.org/std/primitive.u64.html#method.div_euclid + u16::try_from(self.stake() / stake_per_mandate) + .expect("number of mandats should fit u16") + } } } diff --git a/core/primitives/src/validator_mandates.rs b/core/primitives/src/validator_mandates.rs new file mode 100644 index 00000000000..67cbbe20a31 --- /dev/null +++ b/core/primitives/src/validator_mandates.rs @@ -0,0 +1,239 @@ +use std::collections::HashMap; + +use crate::types::{validator_stake::ValidatorStake, ValidatorId}; +use borsh::{BorshDeserialize, BorshSerialize}; +use near_primitives_core::types::Balance; +use rand::{seq::SliceRandom, Rng}; + +/// Represents the configuration of [`ValidatorMandates`]. Its parameters are expected to remain +/// valid for one epoch. +#[derive( + BorshSerialize, BorshDeserialize, Default, Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, +)] +pub struct ValidatorMandatesConfig { + /// The amount of stake that corresponds to one mandate. + stake_per_mandate: Balance, + /// The minimum number of mandates required per shard. + min_mandates_per_shard: usize, + /// The number of shards for the referenced epoch. + num_shards: usize, +} + +impl ValidatorMandatesConfig { + /// Constructs a new configuration. + /// + /// # Panics + /// + /// Panics in the following cases: + /// + /// - If `stake_per_mandate` is 0 as this would lead to division by 0. + /// - If `num_shards` is zero. + pub fn new( + stake_per_mandate: Balance, + min_mandates_per_shard: usize, + num_shards: usize, + ) -> Self { + assert!(stake_per_mandate > 0, "stake_per_mandate of 0 would lead to division by 0"); + assert!(num_shards > 0, "there should be at least one shard"); + Self { stake_per_mandate, min_mandates_per_shard, num_shards } + } +} + +/// The mandates for a set of validators given a [`ValidatorMandatesConfig`]. +#[derive( + BorshSerialize, BorshDeserialize, Default, Clone, Debug, PartialEq, Eq, serde::Serialize, +)] +pub struct ValidatorMandates { + /// The configuration applied to the mandates. + config: ValidatorMandatesConfig, + /// The id of a validator who holds `n >= 0` mandates occurs `n` times in the vector. + mandates: Vec, +} + +impl ValidatorMandates { + /// Initiates mandates corresponding to the provided `validators`. The validators must be sorted + /// by id in ascending order, so the validator with `ValidatorId` equal to `i` is given by + /// `validators[i]`. + /// + /// Only full mandates are assigned, partial mandates are dropped. For example, when the stake + /// required for a mandate is 5 and a validator has staked 12, then it will obtain 2 mandates. + pub fn new(config: ValidatorMandatesConfig, validators: &[ValidatorStake]) -> Self { + let num_mandates_per_validator: Vec = + validators.iter().map(|v| v.num_mandates(config.stake_per_mandate)).collect(); + let num_total_mandates = + num_mandates_per_validator.iter().map(|&num| usize::from(num)).sum(); + let mut mandates: Vec = Vec::with_capacity(num_total_mandates); + + for i in 0..validators.len() { + for _ in 0..num_mandates_per_validator[i] { + // Each validator's position corresponds to its id. + mandates.push(i as ValidatorId); + } + } + + let required_mandates = config.min_mandates_per_shard * config.num_shards; + if mandates.len() < required_mandates { + // TODO(chunk-validator-assignment) lower `stake_per_mandate` to reach enough mandates + panic!( + "not enough validator mandates: got {}, need {}", + mandates.len(), + required_mandates + ); + } + + Self { config, mandates } + } + + /// Returns a validator assignment obtained by shuffling mandates. + /// + /// It clones mandates since [`ValidatorMandates`] is supposed to be valid for an epoch, while a + /// new assignment is calculated at every height. + /// + /// TODO rewrite below docs + /// Returns the validator mandates assigned to `shard_id`, assuming the shard ids for all shards + /// are `[0, .., num_shards)`. + /// + /// The returned value maps `ValidatorId`s to the number of mandates they have been assigned for + /// `shard_id`. A validator whose id is not in the map has not been assigned to the shard. + pub fn sample(&self, seed: &mut R) -> Vec> + where + R: Rng + ?Sized, + { + // TODO put shuffling into a different fn and test it. + let mut shuffled_mandates = self.mandates.clone(); + shuffled_mandates.shuffle(seed); + + // Assign shuffled seat at position `i` to the shard with id `i % num_shards`. + let mut assignments_per_shard = Vec::with_capacity(self.config.num_shards); + for shard_id in 0..self.config.num_shards { + let mut assignments = HashMap::new(); + let mut idx = shard_id; + + while idx < shuffled_mandates.len() { + let id = shuffled_mandates[idx]; + assignments.entry(id).and_modify(|counter| *counter += 1).or_insert(1); + idx += self.config.num_shards; + } + + assignments_per_shard.push(assignments) + } + + assignments_per_shard + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use near_crypto::PublicKey; + use near_primitives_core::types::Balance; + use rand::SeedableRng; + use rand_chacha::ChaCha8Rng; + + use crate::{ + types::validator_stake::ValidatorStake, types::ValidatorId, + validator_mandates::ValidatorMandatesConfig, + }; + + use super::ValidatorMandates; + + /// Returns a new, fixed RNG to be used only in tests. Using a fixed RNG facilitates testing as + /// it makes outcomes based on that RNG deterministic. + fn new_fixed_rng() -> ChaCha8Rng { + ChaCha8Rng::seed_from_u64(42) + } + + #[test] + fn test_validator_mandates_config_new() { + let stake_per_mandate = 10; + let min_mandates_per_shard = 400; + let num_shards = 4; + assert_eq!( + ValidatorMandatesConfig::new(stake_per_mandate, min_mandates_per_shard, num_shards), + ValidatorMandatesConfig { stake_per_mandate, min_mandates_per_shard, num_shards }, + ) + } + + // TODO test config::new panics + + /// Constructs some `ValidatorStakes` for usage in tests. + fn new_validator_stakes() -> Vec { + let new_vs = |account_id: &str, balance: Balance| -> ValidatorStake { + ValidatorStake::new( + account_id.parse().unwrap(), + PublicKey::empty(near_crypto::KeyType::ED25519), + balance, + ) + }; + + vec![ + new_vs("account_0", 30), + new_vs("account_1", 27), + new_vs("account_2", 9), + new_vs("account_3", 12), + new_vs("account_4", 35), + ] + } + + #[test] + fn test_validator_mandates_new() { + let validators = new_validator_stakes(); + let config = ValidatorMandatesConfig::new(10, 1, 4); + let mandates = ValidatorMandates::new(config, &validators); + + // At 10 stake per mandate, the first validator holds three mandates, and so on. + // Note that "account_2" holds no mandate as its stake is below the threshold. + let expected_mandates: Vec = vec![0, 0, 0, 1, 1, 3, 4, 4, 4]; + assert_eq!(mandates.mandates, expected_mandates); + } + + /* + // TODO remove when there is equivalent test for the private `shuffle` method. + #[test] + fn test_validator_mandates_sample() { + let validators = new_validator_stakes(); + let config = ValidatorMandatesConfig::new(10, 1, 4); + let mandates = ValidatorMandates::new(config, &validators); + let mut rng = new_fixed_rng(); + let assignment = mandates.sample(&mut rng); + let expected_assignment: Vec = vec![0, 1, 1, 4, 4, 4, 0, 3, 0]; + assert_eq!(assignment.shuffled_mandates, expected_assignment); + } + */ + + /// Test mandates per shard are collected correctly if `num_mandates % num_shards == 0`. + #[test] + fn test_assigned_validator_mandates_get_mandates_for_shard_even() { + let config = ValidatorMandatesConfig::new(10, 1, 3); + let expected_mandates_per_shards: Vec> = vec![ + HashMap::from([(0, 2), (4, 1)]), + HashMap::from([(1, 1), (3, 1), (4, 1)]), + HashMap::from([(0, 1), (1, 1), (4, 1)]), + ]; + assert_validator_mandates_sample(config, expected_mandates_per_shards); + } + + /// Test mandates per shard are collected correctly if `num_mandates % num_shards != 0`. + #[test] + fn test_assigned_validator_mandates_get_mandates_for_shard_uneven() { + let config = ValidatorMandatesConfig::new(10, 1, 2); + let expected_mandates_per_shards: Vec> = + vec![HashMap::from([(0, 3), (1, 1), (4, 1)]), HashMap::from([(1, 1), (4, 2), (3, 1)])]; + assert_validator_mandates_sample(config, expected_mandates_per_shards); + } + + /// Asserts mandates are per shard are collected correctly. + fn assert_validator_mandates_sample( + config: ValidatorMandatesConfig, + expected_mandates_per_shards: Vec>, + ) { + let validators = new_validator_stakes(); + let mandates = ValidatorMandates::new(config, &validators); + + let mut rng = new_fixed_rng(); + let mandates_per_shards = mandates.sample(&mut rng); + + assert_eq!(mandates_per_shards, expected_mandates_per_shards); + } +}