diff --git a/secio/src/handshake/handshake_struct.rs b/secio/src/handshake/handshake_struct.rs index 0007b15e..29e6b7bc 100644 --- a/secio/src/handshake/handshake_struct.rs +++ b/secio/src/handshake/handshake_struct.rs @@ -7,6 +7,8 @@ use crate::peer_id::PeerId; use flatbuffers::FlatBufferBuilder; use flatbuffers_verifier::get_root; +use std::fmt; + #[derive(Clone, Default, PartialEq, Ord, PartialOrd, Eq, Debug)] pub struct Propose { pub(crate) rand: Vec, @@ -107,7 +109,7 @@ impl Exchange { } /// Public Key -#[derive(Clone, Debug, PartialEq, Ord, PartialOrd, Eq, Hash)] +#[derive(Clone, PartialEq, Ord, PartialOrd, Eq, Hash)] pub enum PublicKey { /// Secp256k1 Secp256k1(Vec), @@ -153,6 +155,16 @@ impl PublicKey { } } +impl fmt::Debug for PublicKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "0x")?; + for byte in self.inner_ref() { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} + #[cfg(test)] mod tests { use super::{Exchange, Propose, PublicKey}; diff --git a/src/context.rs b/src/context.rs index c700753e..939ee2d9 100644 --- a/src/context.rs +++ b/src/context.rs @@ -223,6 +223,19 @@ impl ServiceContext { } } + /// Shutdown service. + /// + /// Order: + /// 1. close all listens + /// 2. try close all session's protocol stream + /// 3. try close all session + /// 4. close service + pub fn shutdown(&mut self) { + if self.inner.shutdown().is_err() { + warn!("Service is abnormally closed") + } + } + pub(crate) fn clone_self(&self) -> Self { ServiceContext { inner: self.inner.clone(), diff --git a/src/service.rs b/src/service.rs index ea17d118..1c896490 100644 --- a/src/service.rs +++ b/src/service.rs @@ -21,7 +21,7 @@ use crate::{ protocol_select::ProtocolInfo, secio::{handshake::Config, PublicKey, SecioKeyPair}, service::{ - config::ServiceConfig, + config::{ServiceConfig, State}, event::ServiceTask, future_task::{BoxedFutureTask, FutureTaskManager}, }, @@ -66,9 +66,8 @@ pub struct Service { dial_protocols: HashMap, config: ServiceConfig, - /// Calculate the number of connection requests that need to be sent externally, - /// if run forever, it will default to 1, else it default to 0 - task_count: usize, + /// service state + state: State, next_session: SessionId, @@ -153,7 +152,7 @@ where listens: Vec::new(), dial_protocols: HashMap::default(), config, - task_count: if forever { 1 } else { 0 }, + state: State::new(forever), next_session: 0, write_buf: VecDeque::default(), read_service_buf: VecDeque::default(), @@ -228,7 +227,7 @@ where self.pending_tasks.push_back(ServiceTask::FutureTask { task: Box::new(task), }); - self.task_count += 1; + self.state.increase(); Ok(listen_addr) } @@ -284,7 +283,7 @@ where self.pending_tasks.push_back(ServiceTask::FutureTask { task: Box::new(task), }); - self.task_count += 1; + self.state.increase(); Ok(()) } @@ -693,7 +692,7 @@ where H: AsyncRead + AsyncWrite + Send + 'static, { if ty.is_outbound() { - self.task_count -= 1; + self.state.decrease(); } let target = self .dial_protocols @@ -1118,7 +1117,7 @@ where } SessionEvent::HandshakeFail { ty, error, address } => { if ty.is_outbound() { - self.task_count -= 1; + self.state.decrease(); self.dial_protocols.remove(&address); self.handle.handle_error( &mut self.service_context, @@ -1161,7 +1160,7 @@ where }, ), SessionEvent::DialError { address, error } => { - self.task_count -= 1; + self.state.decrease(); self.dial_protocols.remove(&address); self.handle.handle_error( &mut self.service_context, @@ -1169,7 +1168,7 @@ where ) } SessionEvent::ListenError { address, error } => { - self.task_count -= 1; + self.state.decrease(); self.handle.handle_error( &mut self.service_context, ServiceError::ListenError { address, error }, @@ -1207,7 +1206,7 @@ where }, ); self.listens.push((listen_address, incoming)); - self.task_count -= 1; + self.state.decrease(); self.update_listens(); self.listen_poll(); } @@ -1430,6 +1429,23 @@ where session_id, proto_id, } => self.protocol_close(session_id, proto_id, Source::External), + ServiceTask::Shutdown => { + self.sessions + .keys() + .cloned() + .collect::>() + .into_iter() + .for_each(|i| self.session_close(i, Source::External)); + self.state.pre_shutdown(); + while let Some((address, incoming)) = self.listens.pop() { + drop(incoming); + self.handle.handle_event( + &mut self.service_context, + ServiceEvent::ListenClose { address }, + ) + } + self.pending_tasks.clear(); + } } } @@ -1492,7 +1508,7 @@ where fn poll(&mut self) -> Poll, Self::Error> { if self.listens.is_empty() - && self.task_count == 0 + && self.state.is_shutdown() && self.sessions.is_empty() && self.pending_tasks.is_empty() { @@ -1542,16 +1558,16 @@ where // Double check service state if self.listens.is_empty() - && self.task_count == 0 + && self.state.is_shutdown() && self.sessions.is_empty() && self.pending_tasks.is_empty() { return Ok(Async::Ready(None)); } debug!( - "listens count: {}, task_count: {}, sessions count: {}, pending task: {}", + "listens count: {}, state: {:?}, sessions count: {}, pending task: {}", self.listens.len(), - self.task_count, + self.state, self.sessions.len(), self.pending_tasks.len(), ); diff --git a/src/service/config.rs b/src/service/config.rs index 1f32f36b..32ec8805 100644 --- a/src/service/config.rs +++ b/src/service/config.rs @@ -180,3 +180,96 @@ impl ProtocolHandle { self.is_event() || self.is_both() } } + +/// Service state +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub enum State { + /// Calculate the number of connection requests that need to be sent externally + Running(usize), + Forever, + PreShutdown, +} + +impl State { + /// new + pub fn new(forever: bool) -> Self { + if forever { + State::Forever + } else { + State::Running(0) + } + } + + /// Can it be shutdown? + #[inline] + pub fn is_shutdown(&self) -> bool { + match self { + State::Running(num) if num == &0 => true, + State::PreShutdown => true, + State::Running(_) | State::Forever => false, + } + } + + /// Convert to pre shutdown state + #[inline] + pub fn pre_shutdown(&mut self) { + *self = State::PreShutdown + } + + /// Add one task count + #[inline] + pub fn increase(&mut self) { + match self { + State::Running(num) => *num += 1, + State::PreShutdown | State::Forever => (), + } + } + + /// Reduce one task count + #[inline] + pub fn decrease(&mut self) { + match self { + State::Running(num) => *num -= 1, + State::PreShutdown | State::Forever => (), + } + } +} + +#[cfg(test)] +mod test { + use super::State; + + #[test] + fn test_state_no_forever() { + let mut state = State::new(false); + state.increase(); + state.increase(); + assert_eq!(state, State::Running(2)); + state.decrease(); + state.decrease(); + assert_eq!(state, State::Running(0)); + state.increase(); + state.increase(); + state.increase(); + state.increase(); + state.pre_shutdown(); + assert_eq!(state, State::PreShutdown); + } + + #[test] + fn test_state_forever() { + let mut state = State::new(true); + state.increase(); + state.increase(); + assert_eq!(state, State::Forever); + state.decrease(); + state.decrease(); + assert_eq!(state, State::Forever); + state.increase(); + state.increase(); + state.increase(); + state.increase(); + state.pre_shutdown(); + assert_eq!(state, State::PreShutdown); + } +} diff --git a/src/service/control.rs b/src/service/control.rs index f2e56bd2..f26f092e 100644 --- a/src/service/control.rs +++ b/src/service/control.rs @@ -177,4 +177,15 @@ impl ServiceControl { token, }) } + + /// Shutdown service + /// + /// Order: + /// 1. close all listens + /// 2. try close all session's protocol stream + /// 3. try close all session + /// 4. close service + pub fn shutdown(&mut self) -> Result<(), Error> { + self.send(ServiceTask::Shutdown) + } } diff --git a/src/service/event.rs b/src/service/event.rs index 5cf7ed48..a248ab69 100644 --- a/src/service/event.rs +++ b/src/service/event.rs @@ -240,6 +240,8 @@ pub(crate) enum ServiceTask { /// Listen address address: Multiaddr, }, + /// Shutdown service + Shutdown, } impl fmt::Debug for ServiceTask { @@ -305,6 +307,7 @@ impl fmt::Debug for ServiceTask { session_id, proto_id, } => write!(f, "Close session [{}] proto [{}]", session_id, proto_id), + Shutdown => write!(f, "Try close service"), } } } diff --git a/tests/test_close.rs b/tests/test_close.rs new file mode 100644 index 00000000..bab5da19 --- /dev/null +++ b/tests/test_close.rs @@ -0,0 +1,127 @@ +use futures::prelude::Stream; +use tentacle::{ + builder::{MetaBuilder, ServiceBuilder}, + context::{ProtocolContext, ProtocolContextMutRef}, + secio::SecioKeyPair, + service::{DialProtocol, ProtocolHandle, ProtocolMeta, Service}, + traits::{ServiceHandle, ServiceProtocol}, + ProtocolId, +}; + +use std::{thread, time::Duration}; + +pub fn create(secio: bool, metas: impl Iterator, shandle: F) -> Service +where + F: ServiceHandle, +{ + let mut builder = ServiceBuilder::default().forever(true); + + for meta in metas { + builder = builder.insert_protocol(meta); + } + + if secio { + builder + .key_pair(SecioKeyPair::secp256k1_generated()) + .build(shandle) + } else { + builder.build(shandle) + } +} + +struct PHandle { + count: u8, +} + +impl ServiceProtocol for PHandle { + fn init(&mut self, _context: &mut ProtocolContext) {} + + fn connected(&mut self, mut context: ProtocolContextMutRef, _version: &str) { + if context.session.ty.is_inbound() && context.proto_id == 1 { + self.count += 1; + if self.count >= 4 { + let proto_id = context.proto_id; + context.set_service_notify(proto_id, Duration::from_secs(2), 0); + } + } + } + + fn notify(&mut self, context: &mut ProtocolContext, _token: u64) { + self.count += 1; + if self.count > 6 { + context.shutdown(); + } + } +} + +fn create_meta(id: ProtocolId) -> ProtocolMeta { + MetaBuilder::new() + .id(id) + .service_handle(move || { + let handle = Box::new(PHandle { count: 0 }); + ProtocolHandle::Callback(handle) + }) + .build() +} + +fn test_close(secio: bool) { + let mut service_1 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_2 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_3 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_4 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + let mut service_5 = create( + secio, + vec![create_meta(0), create_meta(1), create_meta(2)].into_iter(), + (), + ); + + let listen_addr = service_1 + .listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); + + let handle = thread::spawn(|| tokio::run(service_1.for_each(|_| Ok(())))); + + service_2 + .dial(listen_addr.clone(), DialProtocol::All) + .unwrap(); + service_3 + .dial(listen_addr.clone(), DialProtocol::All) + .unwrap(); + service_4 + .dial(listen_addr.clone(), DialProtocol::All) + .unwrap(); + service_5.dial(listen_addr, DialProtocol::All).unwrap(); + + thread::spawn(|| tokio::run(service_2.for_each(|_| Ok(())))); + thread::spawn(|| tokio::run(service_3.for_each(|_| Ok(())))); + thread::spawn(|| tokio::run(service_4.for_each(|_| Ok(())))); + thread::spawn(|| tokio::run(service_5.for_each(|_| Ok(())))); + + handle.join().expect("test fail"); +} + +#[test] +fn test_close_with_secio() { + test_close(true) +} + +#[test] +fn test_close_with_no_secio() { + test_close(false) +}