diff --git a/network/src/protocol/mod.rs b/network/src/protocol/mod.rs index 8ee81129a74e..e9253b6b5a19 100644 --- a/network/src/protocol/mod.rs +++ b/network/src/protocol/mod.rs @@ -335,9 +335,15 @@ enum CollatorState { impl CollatorState { fn send_key(&mut self, key: ValidatorId, mut f: F) { f(Message::ValidatorId(key)); - if let CollatorState::RolePending(role) = *self { - f(Message::CollatorRole(role)); - *self = CollatorState::Primed(Some(role)); + match self { + CollatorState::RolePending(role) => { + f(Message::CollatorRole(*role)); + *self = CollatorState::Primed(Some(*role)); + }, + CollatorState::Fresh => { + *self = CollatorState::Primed(None); + }, + CollatorState::Primed(_) => {}, } } diff --git a/network/src/protocol/tests.rs b/network/src/protocol/tests.rs index cefc8f126f03..8d7ed3bbb695 100644 --- a/network/src/protocol/tests.rs +++ b/network/src/protocol/tests.rs @@ -579,7 +579,7 @@ fn fetches_pov_block_from_gossip() { } #[test] -fn validator_sends_key_to_collator_on_status() { +fn validator_sends_key_and_role_to_collator_on_status() { let (service, _gossip, mut pool, worker_task) = test_setup(Config { collating_for: None }); let peer = PeerId::random(); @@ -602,7 +602,35 @@ fn validator_sends_key_to_collator_on_status() { }); let expected_msg = Message::ValidatorId(validator_id.clone()); - assert!(service.network_service.recorded.lock().notifications.iter().any(|(p, notification)| { + let validator_id_pos = service.network_service.recorded.lock().notifications.iter().position(|(p, notification)| { peer == *p && *notification == expected_msg - })); + }); + + let expected_msg = Message::CollatorRole(CollatorRole::Primary); + let collator_role_pos = service.network_service.recorded.lock().notifications.iter().position(|(p, notification)| { + peer == *p && *notification == expected_msg + }); + + assert!(validator_id_pos < collator_role_pos); +} + +#[test] +fn collator_state_send_key_updates_state_correctly() { + let mut state = CollatorState::Fresh; + state.send_key(Sr25519Keyring::Alice.public().into(), |_| {}); + assert!(matches!(state, CollatorState::Primed(None))); + + let mut state = CollatorState::RolePending(CollatorRole::Primary); + + let mut counter = 0; + state.send_key(Sr25519Keyring::Alice.public().into(), |msg| { + match (counter, msg) { + (0, Message::ValidatorId(_)) => { + counter += 1; + }, + (1, Message::CollatorRole(CollatorRole::Primary)) => {}, + err @ _ => panic!("Unexpected message: {:?}", err), + } + }); + assert!(matches!(state, CollatorState::Primed(Some(CollatorRole::Primary)))); }