From 99d0655bb9cd0cbf54aa57da5f8b943bb621ba31 Mon Sep 17 00:00:00 2001 From: piaoliu <441594700@qq.com> Date: Sat, 13 Apr 2019 13:05:50 +0800 Subject: [PATCH 1/2] feat: support gracefully shutting down the service... --- secio/src/handshake/handshake_struct.rs | 14 ++- src/context.rs | 13 +++ src/service.rs | 44 +++++--- src/service/config.rs | 78 +++++++++++++++ src/service/control.rs | 11 ++ src/service/event.rs | 3 + tests/test_close.rs | 127 ++++++++++++++++++++++++ 7 files changed, 273 insertions(+), 17 deletions(-) create mode 100644 tests/test_close.rs diff --git a/secio/src/handshake/handshake_struct.rs b/secio/src/handshake/handshake_struct.rs index 0007b15e..29e6b7bc 100644 --- a/secio/src/handshake/handshake_struct.rs +++ b/secio/src/handshake/handshake_struct.rs @@ -7,6 +7,8 @@ use crate::peer_id::PeerId; use flatbuffers::FlatBufferBuilder; use flatbuffers_verifier::get_root; +use std::fmt; + #[derive(Clone, Default, PartialEq, Ord, PartialOrd, Eq, Debug)] pub struct Propose { pub(crate) rand: Vec, @@ -107,7 +109,7 @@ impl Exchange { } /// Public Key -#[derive(Clone, Debug, PartialEq, Ord, PartialOrd, Eq, Hash)] +#[derive(Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub enum PublicKey { /// Secp256k1 Secp256k1(Vec), @@ -153,6 +155,16 @@ impl PublicKey { } } +impl fmt::Debug for PublicKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "0x")?; + for byte in self.inner_ref() { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} + #[cfg(test)] mod tests { use super::{Exchange, Propose, PublicKey}; diff --git a/src/context.rs b/src/context.rs index c700753e..939ee2d9 100644 --- a/src/context.rs +++ b/src/context.rs @@ -223,6 +223,19 @@ impl ServiceContext { } } + /// Shutdown service. + /// + /// Order: + /// 1. close all listens + /// 2. try close all session's protocol stream + /// 3. try close all session + /// 4. close service + pub fn shutdown(&mut self) { + if self.inner.shutdown().is_err() { + warn!("Service is abnormally closed") + } + } + pub(crate) fn clone_self(&self) -> Self { ServiceContext { inner: self.inner.clone(), diff --git a/src/service.rs b/src/service.rs index ea17d118..2ea79bbe 100644 --- a/src/service.rs +++ b/src/service.rs @@ -21,7 +21,7 @@ use crate::{ protocol_select::ProtocolInfo, secio::{handshake::Config, PublicKey, SecioKeyPair}, service::{ - config::ServiceConfig, + config::{ServiceConfig, State}, event::ServiceTask, future_task::{BoxedFutureTask, FutureTaskManager}, }, @@ -66,9 +66,8 @@ pub struct Service { dial_protocols: HashMap, config: ServiceConfig, - /// Calculate the number of connection requests that need to be sent externally, - /// if run forever, it will default to 1, else it default to 0 - task_count: usize, + /// service state + state: State, next_session: SessionId, @@ -153,7 +152,7 @@ where listens: Vec::new(), dial_protocols: HashMap::default(), config, - task_count: if forever { 1 } else { 0 }, + state: State::new(forever), next_session: 0, write_buf: VecDeque::default(), read_service_buf: VecDeque::default(), @@ -228,7 +227,7 @@ where self.pending_tasks.push_back(ServiceTask::FutureTask { task: Box::new(task), }); - self.task_count += 1; + self.state.add(); Ok(listen_addr) } @@ -284,7 +283,7 @@ where self.pending_tasks.push_back(ServiceTask::FutureTask { task: Box::new(task), }); - self.task_count += 1; + self.state.add(); Ok(()) } @@ -693,7 +692,7 @@ where H: AsyncRead + AsyncWrite + Send + 'static, { if ty.is_outbound() { - self.task_count -= 1; + self.state.minus(); } let target = self .dial_protocols @@ -1118,7 +1117,7 @@ where } SessionEvent::HandshakeFail { ty, error, address } => { if ty.is_outbound() { - self.task_count -= 1; + self.state.minus(); self.dial_protocols.remove(&address); self.handle.handle_error( &mut self.service_context, @@ -1161,7 +1160,7 @@ where }, ), SessionEvent::DialError { address, error } => { - self.task_count -= 1; + self.state.minus(); self.dial_protocols.remove(&address); self.handle.handle_error( &mut self.service_context, @@ -1169,7 +1168,7 @@ where ) } SessionEvent::ListenError { address, error } => { - self.task_count -= 1; + self.state.minus(); self.handle.handle_error( &mut self.service_context, ServiceError::ListenError { address, error }, @@ -1207,7 +1206,7 @@ where }, ); self.listens.push((listen_address, incoming)); - self.task_count -= 1; + self.state.minus(); self.update_listens(); self.listen_poll(); } @@ -1430,6 +1429,19 @@ where session_id, proto_id, } => self.protocol_close(session_id, proto_id, Source::External), + ServiceTask::Shutdown => { + let ids = self.sessions.keys().cloned().collect::>(); + ids.into_iter() + .for_each(|i| self.session_close(i, Source::External)); + self.state.pre_shutdown(); + while let Some((address, _)) = self.listens.pop() { + self.handle.handle_event( + &mut self.service_context, + ServiceEvent::ListenClose { address }, + ) + } + self.pending_tasks.clear(); + } } } @@ -1492,7 +1504,7 @@ where fn poll(&mut self) -> Poll, Self::Error> { if self.listens.is_empty() - && self.task_count == 0 + && self.state.is_shutdown() && self.sessions.is_empty() && self.pending_tasks.is_empty() { @@ -1542,16 +1554,16 @@ where // Double check service state if self.listens.is_empty() - && self.task_count == 0 + && self.state.is_shutdown() && self.sessions.is_empty() && self.pending_tasks.is_empty() { return Ok(Async::Ready(None)); } debug!( - "listens count: {}, task_count: {}, sessions count: {}, pending task: {}", + "listens count: {}, state: {:?}, sessions count: {}, pending task: {}", self.listens.len(), - self.task_count, + self.state, self.sessions.len(), self.pending_tasks.len(), ); diff --git a/src/service/config.rs b/src/service/config.rs index 1f32f36b..17dcf307 100644 --- a/src/service/config.rs +++ b/src/service/config.rs @@ -180,3 +180,81 @@ impl ProtocolHandle { self.is_event() || self.is_both() } } + +/// Service state +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum State { + /// Calculate the number of connection requests that need to be sent externally, + /// if run forever, it will default to 1, else it default to 0 + Running(usize), + PreShutdown, +} + +impl State { + /// new + pub fn new(forever: bool) -> Self { + if forever { + State::Running(1) + } else { + State::Running(0) + } + } + + /// Can it be shutdown? + #[inline] + pub fn is_shutdown(&self) -> bool { + match self { + State::Running(num) if num == &0 => true, + State::PreShutdown => true, + State::Running(_) => false, + } + } + + /// Convert to pre shutdown state + #[inline] + pub fn pre_shutdown(&mut self) { + *self = State::PreShutdown + } + + /// Add one task count + #[inline] + pub fn add(&mut self) { + match self { + State::Running(num) => *num += 1, + State::PreShutdown => (), + } + } + + /// Reduce one task count + #[inline] + pub fn minus(&mut self) { + match self { + State::Running(num) => *num -= 1, + State::PreShutdown => (), + } + } +} + +#[cfg(test)] +mod test { + use super::State; + + #[test] + fn test_state() { + let mut state = State::new(true); + state.add(); + state.add(); + assert_eq!(state, State::Running(3)); + state.minus(); + state.minus(); + assert_eq!(state, State::Running(1)); + state.minus(); + assert_eq!(state, State::Running(0)); + state.add(); + state.add(); + state.add(); + state.add(); + state.pre_shutdown(); + assert_eq!(state, State::PreShutdown); + } +} diff --git a/src/service/control.rs b/src/service/control.rs index f2e56bd2..f26f092e 100644 --- a/src/service/control.rs +++ b/src/service/control.rs @@ -177,4 +177,15 @@ impl ServiceControl { token, }) } + + /// Shutdown service + /// + /// Order: + /// 1. close all listens + /// 2. try close all session's protocol stream + /// 3. try close all session + /// 4. close service + pub fn shutdown(&mut self) -> Result<(), Error> { + self.send(ServiceTask::Shutdown) + } } diff --git a/src/service/event.rs b/src/service/event.rs index 5cf7ed48..a248ab69 100644 --- a/src/service/event.rs +++ b/src/service/event.rs @@ -240,6 +240,8 @@ pub(crate) enum ServiceTask { /// Listen address address: Multiaddr, }, + /// Shutdown service + Shutdown, } impl fmt::Debug for ServiceTask { @@ -305,6 +307,7 @@ impl fmt::Debug for ServiceTask { session_id, proto_id, } => write!(f, "Close session [{}] proto [{}]", session_id, proto_id), + Shutdown => write!(f, "Try close service"), } } } diff --git a/tests/test_close.rs b/tests/test_close.rs new file mode 100644 index 00000000..bab5da19 --- /dev/null +++ b/tests/test_close.rs @@ -0,0 +1,127 @@ +use futures::prelude::Stream; +use tentacle::{ + builder::{MetaBuilder, ServiceBuilder}, + context::{ProtocolContext, ProtocolContextMutRef}, + secio::SecioKeyPair, + service::{DialProtocol, ProtocolHandle, ProtocolMeta, Service}, + traits::{ServiceHandle, ServiceProtocol}, + ProtocolId, +}; + +use std::{thread, time::Duration}; + +pub fn create(secio: bool, metas: impl Iterator, shandle: F) -> Service +where + F: ServiceHandle, +{ + let mut builder = ServiceBuilder::default().forever(true); + + for meta in metas { + builder = builder.insert_protocol(meta); + } + + if secio { + builder + .key_pair(SecioKeyPair::secp256k1_generated()) + .build(shandle) + } else { + builder.build(shandle) + } +} + +struct PHandle { + count: u8, +} + +impl ServiceProtocol for PHandle { + fn init(&mut self, _context: &mut ProtocolContext) {} + + fn connected(&mut self, mut context: ProtocolContextMutRef, _version: &str) { + if context.session.ty.is_inbound() && context.proto_id == 1 { + self.count += 1; + if self.count >= 4 { + let proto_id = context.proto_id; + context.set_service_notify(proto_id, Duration::from_secs(2), 0); + } + } + } + + fn notify(&mut self, context: &mut ProtocolContext, _token: u64) { + self.count += 1; + if self.count > 6 { + context.shutdown(); + } + } +} + +fn create_meta(id: ProtocolId) -> ProtocolMeta { + MetaBuilder::new() + .id(id) + .service_handle(move || { + let handle = Box::new(PHandle { count: 0 }); + ProtocolHandle::Callback(handle) + }) + .build() +} + +fn test_close(secio: bool) { + let mut service_1 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_2 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_3 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_4 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_5 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + + let listen_addr = service_1 + .listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); + + let handle = thread::spawn(|| tokio::run(service_1.for_each(|_| Ok(())))); + + service_2 + .dial(listen_addr.clone(), DialProtocol::All) + .unwrap(); + service_3 + .dial(listen_addr.clone(), DialProtocol::All) + .unwrap(); + service_4 + .dial(listen_addr.clone(), DialProtocol::All) + .unwrap(); + service_5.dial(listen_addr, DialProtocol::All).unwrap(); + + thread::spawn(|| tokio::run(service_2.for_each(|_| Ok(())))); + thread::spawn(|| tokio::run(service_3.for_each(|_| Ok(())))); + thread::spawn(|| tokio::run(service_4.for_each(|_| Ok(())))); + thread::spawn(|| tokio::run(service_5.for_each(|_| Ok(())))); + + handle.join().expect("test fail"); +} + +#[test] +fn test_close_with_secio() { + test_close(true) +} + +#[test] +fn test_close_with_no_secio() { + test_close(false) +} From b71fbaaea80ae0eaf531157905f853ecfbe0509e Mon Sep 17 00:00:00 2001 From: piaoliu <441594700@qq.com> Date: Sat, 13 Apr 2019 13:41:39 +0800 Subject: [PATCH 2/2] chore: change state --- src/service.rs | 24 ++++++++++-------- src/service/config.rs | 57 +++++++++++++++++++++++++++---------------- 2 files changed, 50 insertions(+), 31 deletions(-) diff --git a/src/service.rs b/src/service.rs index 2ea79bbe..1c896490 100644 --- a/src/service.rs +++ b/src/service.rs @@ -227,7 +227,7 @@ where self.pending_tasks.push_back(ServiceTask::FutureTask { task: Box::new(task), }); - self.state.add(); + self.state.increase(); Ok(listen_addr) } @@ -283,7 +283,7 @@ where self.pending_tasks.push_back(ServiceTask::FutureTask { task: Box::new(task), }); - self.state.add(); + self.state.increase(); Ok(()) } @@ -692,7 +692,7 @@ where H: AsyncRead + AsyncWrite + Send + 'static, { if ty.is_outbound() { - self.state.minus(); + self.state.decrease(); } let target = self .dial_protocols @@ -1117,7 +1117,7 @@ where } SessionEvent::HandshakeFail { ty, error, address } => { if ty.is_outbound() { - self.state.minus(); + self.state.decrease(); self.dial_protocols.remove(&address); self.handle.handle_error( &mut self.service_context, @@ -1160,7 +1160,7 @@ where }, ), SessionEvent::DialError { address, error } => { - self.state.minus(); + self.state.decrease(); self.dial_protocols.remove(&address); self.handle.handle_error( &mut self.service_context, @@ -1168,7 +1168,7 @@ where ) } SessionEvent::ListenError { address, error } => { - self.state.minus(); + self.state.decrease(); self.handle.handle_error( &mut self.service_context, ServiceError::ListenError { address, error }, @@ -1206,7 +1206,7 @@ where }, ); self.listens.push((listen_address, incoming)); - self.state.minus(); + self.state.decrease(); self.update_listens(); self.listen_poll(); } @@ -1430,11 +1430,15 @@ where proto_id, } => self.protocol_close(session_id, proto_id, Source::External), ServiceTask::Shutdown => { - let ids = self.sessions.keys().cloned().collect::>(); - ids.into_iter() + self.sessions + .keys() + .cloned() + .collect::>() + .into_iter() .for_each(|i| self.session_close(i, Source::External)); self.state.pre_shutdown(); - while let Some((address, _)) = self.listens.pop() { + while let Some((address, incoming)) = self.listens.pop() { + drop(incoming); self.handle.handle_event( &mut self.service_context, ServiceEvent::ListenClose { address }, diff --git a/src/service/config.rs b/src/service/config.rs index 17dcf307..32ec8805 100644 --- a/src/service/config.rs +++ b/src/service/config.rs @@ -184,9 +184,9 @@ impl ProtocolHandle { /// Service state #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub enum State { - /// Calculate the number of connection requests that need to be sent externally, - /// if run forever, it will default to 1, else it default to 0 + /// Calculate the number of connection requests that need to be sent externally Running(usize), + Forever, PreShutdown, } @@ -194,7 +194,7 @@ impl State { /// new pub fn new(forever: bool) -> Self { if forever { - State::Running(1) + State::Forever } else { State::Running(0) } @@ -206,7 +206,7 @@ impl State { match self { State::Running(num) if num == &0 => true, State::PreShutdown => true, - State::Running(_) => false, + State::Running(_) | State::Forever => false, } } @@ -218,19 +218,19 @@ impl State { /// Add one task count #[inline] - pub fn add(&mut self) { + pub fn increase(&mut self) { match self { State::Running(num) => *num += 1, - State::PreShutdown => (), + State::PreShutdown | State::Forever => (), } } /// Reduce one task count #[inline] - pub fn minus(&mut self) { + pub fn decrease(&mut self) { match self { State::Running(num) => *num -= 1, - State::PreShutdown => (), + State::PreShutdown | State::Forever => (), } } } @@ -240,20 +240,35 @@ mod test { use super::State; #[test] - fn test_state() { - let mut state = State::new(true); - state.add(); - state.add(); - assert_eq!(state, State::Running(3)); - state.minus(); - state.minus(); - assert_eq!(state, State::Running(1)); - state.minus(); + fn test_state_no_forever() { + let mut state = State::new(false); + state.increase(); + state.increase(); + assert_eq!(state, State::Running(2)); + state.decrease(); + state.decrease(); assert_eq!(state, State::Running(0)); - state.add(); - state.add(); - state.add(); - state.add(); + state.increase(); + state.increase(); + state.increase(); + state.increase(); + state.pre_shutdown(); + assert_eq!(state, State::PreShutdown); + } + + #[test] + fn test_state_forever() { + let mut state = State::new(true); + state.increase(); + state.increase(); + assert_eq!(state, State::Forever); + state.decrease(); + state.decrease(); + assert_eq!(state, State::Forever); + state.increase(); + state.increase(); + state.increase(); + state.increase(); state.pre_shutdown(); assert_eq!(state, State::PreShutdown); }