diff --git a/Cargo.toml b/Cargo.toml index e23068b30d..46e9ee3c68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ lru_time_cache = "~0.11.0" qp2p = "~0.9.16" rand = "~0.7.3" rand_chacha = "~0.2.2" -sn_messaging = { path= "../sn_messaging" } +sn_messaging = { git= "https://github.com/joshuef/sn_messaging/", branch = "Feat-SectionKeyOnMessage" } thiserror = "1.0.23" xor_name = "1.1.0" resource_proof = "0.8.0" diff --git a/src/routing/approved.rs b/src/routing/approved.rs index bb4ba6c1ed..d607f6ad88 100644 --- a/src/routing/approved.rs +++ b/src/routing/approved.rs @@ -40,8 +40,10 @@ use ed25519_dalek::Verifier; use itertools::Itertools; use resource_proof::ResourceProof; use sn_messaging::{ - client::Error as ClientError, - infrastructure::{GetSectionResponse, Query}, + infrastructure::{ + Error as InfrastructureError, GetSectionResponse, InfrastructureInformation, + Message as InfrastructureMessage, + }, node::NodeMessage, MessageType, }; @@ -198,25 +200,29 @@ impl Approved { Ok(commands) } - pub async fn handle_infrastructure_query( + pub async fn handle_infrastructure_message( &mut self, sender: SocketAddr, - message: Query, + message: InfrastructureMessage, ) -> Vec { match message { - Query::GetSectionRequest(name) => { + InfrastructureMessage::GetSectionRequest(name) => { debug!("Received GetSectionRequest({}) from {}", name, sender); let response = if self.section.prefix().matches(&name) { - GetSectionResponse::Success { - prefix: self.section.elders_info().prefix, - key: *self.section.chain().last_key(), - elders: self - .section - .elders_info() - .peers() - .map(|peer| (*peer.name(), *peer.addr())) - .collect(), + if let Ok(pk_set) = self.public_key_set() { + GetSectionResponse::Success(InfrastructureInformation { + prefix: self.section.elders_info().prefix, + pk_set, + elders: self + .section + .elders_info() + .peers() + .map(|peer| (*peer.name(), *peer.addr())) + .collect(), + }) + } else { + GetSectionResponse::SectionInfrastructureError(InfrastructureError::NoSectionPkSet) } } else { // If we are elder, we should know a section that is closer to `name` that us. @@ -228,28 +234,28 @@ impl Approved { let addrs = section.peers().map(Peer::addr).copied().collect(); GetSectionResponse::Redirect(addrs) }; - let response = Query::GetSectionResponse(response); + let response = InfrastructureMessage::GetSectionResponse(response); debug!("Sending {:?} to {}", response, sender); vec![Command::SendMessage { recipients: vec![sender], delivery_group_size: 1, - message: MessageType::InfrastructureQuery(response), + message: MessageType::InfrastructureMessage(response), }] } - Query::GetSectionResponse(_) => { + InfrastructureMessage::GetSectionResponse(_) => { if let Some(RelocateState::InProgress(tx)) = &mut self.relocate_state { trace!("Forwarding {:?} to the bootstrap task", message); let _ = tx - .send((MessageType::InfrastructureQuery(message), sender)) + .send((MessageType::InfrastructureMessage(message), sender)) .await; } vec![] } - Query::SectionKeyResponse(_) => { - error!("Shall not receive an error response to client"); - vec![] + InfrastructureMessage::InfrastructureError(_) => { + // TODO handle this... + unimplemented!() } } } @@ -1839,18 +1845,35 @@ impl Approved { Ok(Some(command)) } - pub fn check_key_status(&self, bls_pk: &bls::PublicKey) -> Result<(), ClientError> { + pub fn check_key_status(&self, bls_pk: &bls::PublicKey) -> Result<(), InfrastructureError> { if self.dkg_voter.has_ongoing_dkg() { - return Err(ClientError::DkgInProgress); + return Err( + InfrastructureError::DkgInProgress, + ); } if !self.section.chain().has_key(bls_pk) { - return Err(ClientError::UnrecognizedSectionKey); + return Err( + InfrastructureError::UnrecognizedSectionKey, + ); } if bls_pk != self.section.chain().last_key() { if let Ok(public_key_set) = self.public_key_set() { - return Err(ClientError::TargetSectionKeyIsNotCurrent(public_key_set)); + return Err( + InfrastructureError::TargetSectionInfoOutdated(InfrastructureInformation { + prefix: *self.section.prefix(), + pk_set: public_key_set, + elders: self + .section + .elders_info() + .peers() + .map(|peer| (*peer.name(), *peer.addr())) + .collect(), + }), + ); } else { - return Err(ClientError::DkgInProgress); + return Err( + InfrastructureError::DkgInProgress, + ); } } Ok(()) diff --git a/src/routing/bootstrap.rs b/src/routing/bootstrap.rs index 658a841ae1..2e55b09a41 100644 --- a/src/routing/bootstrap.rs +++ b/src/routing/bootstrap.rs @@ -23,7 +23,9 @@ use bytes::Bytes; use futures::future; use resource_proof::ResourceProof; use sn_messaging::{ - infrastructure::{GetSectionResponse, Query}, + infrastructure::{ + GetSectionResponse, InfrastructureInformation, Message as InfrastructureMessage, + }, node::NodeMessage, MessageType, WireMsg, }; @@ -149,11 +151,12 @@ impl<'a> State<'a> { let (response, sender) = self.receive_get_section_response().await?; match response { - GetSectionResponse::Success { + GetSectionResponse::Success(InfrastructureInformation { prefix, - key, + pk_set, elders, - } => { + }) => { + let key = pk_set.public_key(); info!( "Joining a section ({:b}), key: {:?}, elders: {:?} (given by {:?})", prefix, key, elders, sender @@ -167,6 +170,10 @@ impl<'a> State<'a> { ); bootstrap_addrs = new_bootstrap_addrs.to_vec(); } + GetSectionResponse::SectionInfrastructureError(error) => { + error!("Handle infrastructure error: {:?}", error); + // TODO: handle + } } } } @@ -186,11 +193,11 @@ impl<'a> State<'a> { None => self.node.name(), }; - let message = Query::GetSectionRequest(destination); + let message = InfrastructureMessage::GetSectionRequest(destination); let _ = self .send_tx - .send((MessageType::InfrastructureQuery(message), recipients)) + .send((MessageType::InfrastructureMessage(message), recipients)) .await; Ok(()) @@ -199,28 +206,30 @@ impl<'a> State<'a> { async fn receive_get_section_response(&mut self) -> Result<(GetSectionResponse, SocketAddr)> { while let Some((message, sender)) = self.recv_rx.next().await { match message { - MessageType::InfrastructureQuery(Query::GetSectionResponse(response)) => { - match response { - GetSectionResponse::Redirect(addrs) if addrs.is_empty() => { - error!("Invalid GetSectionResponse::Redirect: missing peers"); - continue; - } - GetSectionResponse::Success { prefix, .. } - if !prefix.matches(&self.node.name()) => - { - error!("Invalid GetSectionResponse::Success: bad prefix"); - continue; - } - GetSectionResponse::Redirect(_) | GetSectionResponse::Success { .. } => { - return Ok((response, sender)) - } + MessageType::InfrastructureMessage(InfrastructureMessage::GetSectionResponse( + response, + )) => match response { + GetSectionResponse::Redirect(addrs) if addrs.is_empty() => { + error!("Invalid GetSectionResponse::Redirect: missing peers"); + continue; } - } + GetSectionResponse::Success(InfrastructureInformation { prefix, .. }) + if !prefix.matches(&self.node.name()) => + { + error!("Invalid GetSectionResponse::Success: bad prefix"); + continue; + } + GetSectionResponse::Redirect(_) + | GetSectionResponse::Success { .. } + | GetSectionResponse::SectionInfrastructureError(_) => { + return Ok((response, sender)) + } + }, MessageType::NodeMessage(NodeMessage(msg_bytes)) => { let message = Message::from_bytes(Bytes::from(msg_bytes))?; self.backlog_message(message, sender) } - MessageType::InfrastructureQuery(_) + MessageType::InfrastructureMessage(_) | MessageType::ClientMessage(_) | MessageType::Ping => {} } @@ -377,7 +386,7 @@ impl<'a> State<'a> { } MessageType::Ping | MessageType::ClientMessage(_) - | MessageType::InfrastructureQuery(_) => continue, + | MessageType::InfrastructureMessage(_) => continue, }; match message.variant() { @@ -605,12 +614,12 @@ mod tests { let (message, recipients) = send_rx.try_recv()?; assert_eq!(recipients, [bootstrap_addr]); - assert_matches!(message, MessageType::InfrastructureQuery(Query::GetSectionRequest(name)) => { + assert_matches!(message, MessageType::InfrastructureMessage(InfrastructureMessage::GetSectionRequest(name)) => { assert_eq!(name, *peer.name()); }); // Send GetSectionResponse::Success - let message = Query::GetSectionResponse(GetSectionResponse::Success { + let message = InfrastructureMessage::GetSectionResponse(GetSectionResponse::Success { prefix: elders_info.prefix, key: pk, elders: elders_info @@ -618,7 +627,7 @@ mod tests { .map(|peer| (*peer.name(), *peer.addr())) .collect(), }); - recv_tx.try_send((MessageType::InfrastructureQuery(message), bootstrap_addr))?; + recv_tx.try_send((MessageType::InfrastructureMessage(message), bootstrap_addr))?; task::yield_now().await; // Receive JoinRequest @@ -685,17 +694,17 @@ mod tests { assert_eq!(recipients, vec![bootstrap_node.addr]); assert_matches!( message, - MessageType::InfrastructureQuery(Query::GetSectionRequest(_)) + MessageType::InfrastructureMessage(InfrastructureMessage::GetSectionRequest(_)) ); // Send GetSectionResponse::Redirect let new_bootstrap_addrs: Vec<_> = (0..ELDER_SIZE).map(|_| gen_addr()).collect(); - let message = Query::GetSectionResponse(GetSectionResponse::Redirect( + let message = InfrastructureMessage::GetSectionResponse(GetSectionResponse::Redirect( new_bootstrap_addrs.clone(), )); recv_tx.try_send(( - MessageType::InfrastructureQuery(message), + MessageType::InfrastructureMessage(message), bootstrap_node.addr, ))?; task::yield_now().await; @@ -706,7 +715,7 @@ mod tests { assert_eq!(recipients, new_bootstrap_addrs); assert_matches!( message, - MessageType::InfrastructureQuery(Query::GetSectionRequest(_)) + MessageType::InfrastructureMessage(InfrastructureMessage::GetSectionRequest(_)) ); Ok(()) @@ -739,23 +748,25 @@ mod tests { let (message, _) = send_rx.try_recv()?; assert_matches!( message, - MessageType::InfrastructureQuery(Query::GetSectionRequest(_)) + MessageType::InfrastructureMessage(InfrastructureMessage::GetSectionRequest(_)) ); - let message = Query::GetSectionResponse(GetSectionResponse::Redirect(vec![])); + let message = + InfrastructureMessage::GetSectionResponse(GetSectionResponse::Redirect(vec![])); recv_tx.try_send(( - MessageType::InfrastructureQuery(message), + MessageType::InfrastructureMessage(message), bootstrap_node.addr, ))?; task::yield_now().await; assert_matches!(send_rx.try_recv(), Err(TryRecvError::Empty)); let addrs = (0..ELDER_SIZE).map(|_| gen_addr()).collect(); - let message = Query::GetSectionResponse(GetSectionResponse::Redirect(addrs)); + let message = + InfrastructureMessage::GetSectionResponse(GetSectionResponse::Redirect(addrs)); recv_tx.try_send(( - MessageType::InfrastructureQuery(message), + MessageType::InfrastructureMessage(message), bootstrap_node.addr, ))?; task::yield_now().await; @@ -763,7 +774,7 @@ mod tests { let (message, _) = send_rx.try_recv()?; assert_matches!( message, - MessageType::InfrastructureQuery(Query::GetSectionRequest(_)) + MessageType::InfrastructureMessage(InfrastructureMessage::GetSectionRequest(_)) ); Ok(()) @@ -810,10 +821,10 @@ mod tests { let (message, _) = send_rx.try_recv()?; assert_matches!( message, - MessageType::InfrastructureQuery(Query::GetSectionRequest(_)) + MessageType::InfrastructureMessage(InfrastructureMessage::GetSectionRequest(_)) ); - let message = Query::GetSectionResponse(GetSectionResponse::Success { + let message = InfrastructureMessage::GetSectionResponse(GetSectionResponse::Success { prefix: bad_prefix, key: bls::SecretKey::random().public_key(), elders: (0..ELDER_SIZE) @@ -822,13 +833,13 @@ mod tests { }); recv_tx.try_send(( - MessageType::InfrastructureQuery(message), + MessageType::InfrastructureMessage(message), bootstrap_node.addr, ))?; task::yield_now().await; assert_matches!(send_rx.try_recv(), Err(TryRecvError::Empty)); - let message = Query::GetSectionResponse(GetSectionResponse::Success { + let message = InfrastructureMessage::GetSectionResponse(GetSectionResponse::Success { prefix: good_prefix, key: bls::SecretKey::random().public_key(), elders: (0..ELDER_SIZE) @@ -837,7 +848,7 @@ mod tests { }); recv_tx.try_send(( - MessageType::InfrastructureQuery(message), + MessageType::InfrastructureMessage(message), bootstrap_node.addr, ))?; diff --git a/src/routing/command.rs b/src/routing/command.rs index 1d0382b48b..d65c116be8 100644 --- a/src/routing/command.rs +++ b/src/routing/command.rs @@ -16,7 +16,9 @@ use crate::{ use bls_signature_aggregator::Proof; use bytes::Bytes; use hex_fmt::HexFmt; -use sn_messaging::{infrastructure::Query, node::NodeMessage, MessageType}; +use sn_messaging::{ + infrastructure::Message as InfrastructureMessage, node::NodeMessage, MessageType, +}; use std::{ fmt::{self, Debug, Formatter}, net::SocketAddr, @@ -37,7 +39,10 @@ pub(crate) enum Command { message: Message, }, /// Handle infrastructure query message. - HandleInfrastructureQuery { sender: SocketAddr, message: Query }, + HandleInfrastructureMessage { + sender: SocketAddr, + message: InfrastructureMessage, + }, /// Handle a timeout previously scheduled with `ScheduleTimeout`. HandleTimeout(u64), /// Handle lost connection to a peer. @@ -116,8 +121,8 @@ impl Debug for Command { .field("sender", sender) .field("message", message) .finish(), - Self::HandleInfrastructureQuery { sender, message } => f - .debug_struct("HandleInfrastructureQuery") + Self::HandleInfrastructureMessage { sender, message } => f + .debug_struct("HandleInfrastructureMessage") .field("sender", sender) .field("message", message) .finish(), diff --git a/src/routing/mod.rs b/src/routing/mod.rs index c300508782..2b08c4af33 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -40,7 +40,8 @@ use bytes::Bytes; use ed25519_dalek::{Keypair, PublicKey, Signature, Signer}; use itertools::Itertools; use sn_messaging::{ - client::MsgEnvelope, infrastructure::Query, node::NodeMessage, MessageType, WireMsg, + client::MsgEnvelope, infrastructure::{ErrorResponse, Message as InfrastructureMessage}, + node::NodeMessage, MessageType, WireMsg, }; use std::{net::SocketAddr, sync::Arc}; use tokio::{sync::mpsc, task}; @@ -388,8 +389,8 @@ async fn handle_message(stage: Arc, bytes: Bytes, sender: SocketAddr) { MessageType::Ping => { // Pings are not handled } - MessageType::InfrastructureQuery(message) => { - let command = Command::HandleInfrastructureQuery { sender, message }; + MessageType::InfrastructureMessage(message) => { + let command = Command::HandleInfrastructureMessage { sender, message }; let _ = task::spawn(stage.handle_commands(command)); } MessageType::NodeMessage(NodeMessage(msg_bytes)) => { @@ -413,12 +414,18 @@ async fn handle_message(stage: Arc, bytes: Bytes, sender: SocketAddr) { if let Some(client_pk) = msg_envelope.message.target_section_pk() { if let Some(bls_pk) = client_pk.bls() { if let Err(error) = stage.check_key_status(&bls_pk).await { + let incoming_msg = msg_envelope.message; + let correlation_id = incoming_msg.id(); + let command = Command::SendMessage { recipients: vec![sender], delivery_group_size: 1, - message: MessageType::InfrastructureQuery(Query::SectionKeyResponse( - error, - )), + message: MessageType::InfrastructureMessage( + InfrastructureMessage::InfrastructureError(ErrorResponse { + correlation_id, + error, + }), + ), }; let _ = task::spawn(stage.handle_commands(command)); return; diff --git a/src/routing/stage.rs b/src/routing/stage.rs index e1a3033212..f4e849ec24 100644 --- a/src/routing/stage.rs +++ b/src/routing/stage.rs @@ -8,7 +8,7 @@ use super::{bootstrap, Approved, Comm, Command}; use crate::{error::Result, event::Event, relocation::SignedRelocateDetails}; -use sn_messaging::{client::Error as ClientError, MessageType}; +use sn_messaging::{MessageType, infrastructure::Error as InfrastructureError}; use std::{net::SocketAddr, sync::Arc, time::Duration}; use tokio::{ sync::{mpsc, watch, Mutex}, @@ -101,11 +101,11 @@ impl Stage { .handle_message(sender, message) .await } - Command::HandleInfrastructureQuery { sender, message } => Ok(self + Command::HandleInfrastructureMessage { sender, message } => Ok(self .state .lock() .await - .handle_infrastructure_query(sender, message) + .handle_infrastructure_message(sender, message) .await), Command::HandleTimeout(token) => self.state.lock().await.handle_timeout(token), Command::HandleVote { vote, proof_share } => { @@ -175,7 +175,7 @@ impl Stage { let _ = tokio::spawn(self.handle_commands(command)); } - pub async fn check_key_status(&self, bls_pk: &bls::PublicKey) -> Result<(), ClientError> { + pub async fn check_key_status(&self, bls_pk: &bls::PublicKey) -> Result<(), InfrastructureError> { self.state.lock().await.check_key_status(bls_pk) } @@ -209,7 +209,7 @@ impl Stage { } vec![] } - MessageType::InfrastructureQuery(_) => { + MessageType::InfrastructureMessage(_) => { for recipient in recipients { let _ = self .comm diff --git a/src/routing/tests/mod.rs b/src/routing/tests/mod.rs index bb1ab5ed0a..839354c661 100644 --- a/src/routing/tests/mod.rs +++ b/src/routing/tests/mod.rs @@ -57,7 +57,7 @@ async fn receive_matching_get_section_request_as_elder() -> Result<()> { let message = Query::GetSectionRequest(new_node.name()); let mut commands = stage - .handle_command(Command::HandleInfrastructureQuery { + .handle_command(Command::HandleInfrastructureMessage { sender: new_node.addr, message, }) @@ -68,7 +68,7 @@ async fn receive_matching_get_section_request_as_elder() -> Result<()> { commands.next(), Some(Command::SendMessage { recipients, - message: MessageType::InfrastructureQuery(message), .. + message: MessageType::InfrastructureMessage(message), .. }) => (recipients, message) ); @@ -102,7 +102,7 @@ async fn receive_mismatching_get_section_request_as_adult() -> Result<()> { let message = Query::GetSectionRequest(new_node_name); let mut commands = stage - .handle_command(Command::HandleInfrastructureQuery { + .handle_command(Command::HandleInfrastructureMessage { sender: new_node_addr, message, }) @@ -113,7 +113,7 @@ async fn receive_mismatching_get_section_request_as_adult() -> Result<()> { commands.next(), Some(Command::SendMessage { recipients, - message: MessageType::InfrastructureQuery(message), .. + message: MessageType::InfrastructureMessage(message), .. }) => (recipients, message) );