diff --git a/proc-macros/src/render_client.rs b/proc-macros/src/render_client.rs
index 11a12680bf..c655a261ab 100644
--- a/proc-macros/src/render_client.rs
+++ b/proc-macros/src/render_client.rs
@@ -142,7 +142,7 @@ impl RpcDescription {
 	fn encode_params(
 		&self,
-		params: &Vec<(syn::PatIdent, syn::Type)>,
+		params: &[(syn::PatIdent, syn::Type)],
 		param_kind: &ParamKind,
 		signature: &syn::TraitItemMethod,
 	) -> TokenStream2 {
diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs
index f33086e445..baeb76afbf 100644
--- a/tests/tests/helpers.rs
+++ b/tests/tests/helpers.rs
@@ -81,7 +81,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle) {
 		.register_subscription("subscribe_noop", "unsubscribe_noop", |_, mut sink, _| {
 			std::thread::spawn(move || {
 				std::thread::sleep(Duration::from_secs(1));
-				sink.close("Server closed the stream because it was lazy".into())
+				sink.close("Server closed the stream because it was lazy")
 			});
 			Ok(())
 		})
diff --git a/tests/tests/resource_limiting.rs b/tests/tests/resource_limiting.rs
index 736de5123d..033cec0abe 100644
--- a/tests/tests/resource_limiting.rs
+++ b/tests/tests/resource_limiting.rs
@@ -160,8 +160,6 @@ async fn run_tests_on_ws_server(server_addr: SocketAddr, server_handle: WsServerHandle) {
 	assert!(pass_mem.is_ok());
 	assert_server_busy(fail_mem);
 
-	// Client being active prevents the server from shutting down?!
-	drop(client);
 	server_handle.stop().unwrap().await;
 }
diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml
index 377eaf1d01..a0f21f5010 100644
--- a/ws-server/Cargo.toml
+++ b/ws-server/Cargo.toml
@@ -17,7 +17,7 @@ jsonrpsee-utils = { path = "../utils", version = "0.4.1", features = ["server"] }
 tracing = "0.1"
 serde_json = { version = "1", features = ["raw_value"] }
 soketto = "0.7.1"
-tokio = { version = "1", features = ["net", "rt-multi-thread", "macros"] }
+tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] }
 tokio-util = { version = "0.6", features = ["compat"] }
 
 [dev-dependencies]
diff --git a/ws-server/src/future.rs b/ws-server/src/future.rs
index 69b35b6977..fd2964d194 100644
--- a/ws-server/src/future.rs
+++ b/ws-server/src/future.rs
@@ -36,6 +36,10 @@ use std::sync::{
 	Arc, Weak,
 };
 use std::task::{Context, Poll};
+use tokio::time::{self, Duration, Interval};
+
+/// Polling for server stop monitor interval in milliseconds.
+const STOP_MONITOR_POLLING_INTERVAL: u64 = 1000;
 
 /// This is a flexible collection of futures that need to be driven to completion
 /// alongside some other future, such as connection handlers that need to be
@@ -45,11 +49,16 @@ use std::task::{Context, Poll};
 /// `select_with` providing some other future, the result of which you need.
 pub(crate) struct FutureDriver<F> {
 	futures: Vec<F>,
+	stop_monitor_heartbeat: Interval,
 }
 
 impl<F> Default for FutureDriver<F> {
 	fn default() -> Self {
-		FutureDriver { futures: Vec::new() }
+		let mut heartbeat = time::interval(Duration::from_millis(STOP_MONITOR_POLLING_INTERVAL));
+
+		heartbeat.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
+
+		FutureDriver { futures: Vec::new(), stop_monitor_heartbeat: heartbeat }
 	}
 }
 
@@ -92,6 +101,12 @@ where
 			}
 		}
 	}
+
+	fn poll_stop_monitor_heartbeat(&mut self, cx: &mut Context) {
+		// We don't care about the ticks of the heartbeat, it's here only
+		// to periodically wake the `Waker` on `cx`.
+		let _ = self.stop_monitor_heartbeat.poll_tick(cx);
+	}
 }
 
 impl<F> Future for FutureDriver<F>
@@ -132,6 +147,7 @@ where
 		let this = Pin::into_inner(self);
 
 		this.driver.drive(cx);
+		this.driver.poll_stop_monitor_heartbeat(cx);
 
 		this.selector.poll_unpin(cx)
 	}
diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs
index 1356ccad06..389d2b13b4 100644
--- a/ws-server/src/server.rs
+++ b/ws-server/src/server.rs
@@ -90,7 +90,7 @@ impl Server {
 		let mut id = 0;
 		let mut connections = FutureDriver::default();
-		let mut incoming = Incoming::new(self.listener, &stop_monitor);
+		let mut incoming = Monitored::new(Incoming(self.listener), &stop_monitor);
 
 		loop {
 			match connections.select_with(&mut incoming).await {
@@ -122,10 +122,10 @@ impl Server {
 
 					id = id.wrapping_add(1);
 				}
-				Err(IncomingError::Io(err)) => {
+				Err(MonitoredError::Selector(err)) => {
 					tracing::error!("Error while awaiting a new connection: {:?}", err);
 				}
-				Err(IncomingError::Shutdown) => break,
+				Err(MonitoredError::Shutdown) => break,
 			}
 		}
@@ -133,35 +133,53 @@ impl Server {
 	}
 }
 
-/// This is a glorified select listening to new connections, while also checking
-/// for `stop_receiver` signal.
-struct Incoming<'a> {
-	listener: TcpListener,
+/// This is a glorified select listening for new messages, while also checking the `stop_receiver` signal.
+struct Monitored<'a, F> {
+	future: F,
 	stop_monitor: &'a StopMonitor,
 }
 
-impl<'a> Incoming<'a> {
-	fn new(listener: TcpListener, stop_monitor: &'a StopMonitor) -> Self {
-		Incoming { listener, stop_monitor }
+impl<'a, F> Monitored<'a, F> {
+	fn new(future: F, stop_monitor: &'a StopMonitor) -> Self {
+		Monitored { future, stop_monitor }
 	}
 }
 
-enum IncomingError {
+enum MonitoredError<E> {
 	Shutdown,
-	Io(std::io::Error),
+	Selector(E),
+}
+
+struct Incoming(TcpListener);
+
+impl<'a> Future for Monitored<'a, Incoming> {
+	type Output = Result<(TcpStream, SocketAddr), MonitoredError<std::io::Error>>;
+
+	fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+		let this = Pin::into_inner(self);
+
+		if this.stop_monitor.shutdown_requested() {
+			return Poll::Ready(Err(MonitoredError::Shutdown));
+		}
+
+		this.future.0.poll_accept(cx).map_err(MonitoredError::Selector)
+	}
 }
 
-impl<'a> Future for Incoming<'a> {
-	type Output = Result<(TcpStream, SocketAddr), IncomingError>;
+impl<'a, 'f, F, T, E> Future for Monitored<'a, Pin<&'f mut F>>
+where
+	F: Future<Output = Result<T, E>>,
+{
+	type Output = Result<T, MonitoredError<E>>;
 
 	fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 		let this = Pin::into_inner(self);
 
 		if this.stop_monitor.shutdown_requested() {
-			return Poll::Ready(Err(IncomingError::Shutdown));
+			return Poll::Ready(Err(MonitoredError::Shutdown));
 		}
 
-		this.listener.poll_accept(cx).map_err(IncomingError::Io)
+		this.future.poll_unpin(cx).map_err(MonitoredError::Selector)
 	}
 }
 
@@ -275,31 +293,39 @@ async fn background_task(
 	let mut data = Vec::with_capacity(100);
 	let mut method_executors = FutureDriver::default();
 
-	while !stop_server.shutdown_requested() {
+	loop {
 		data.clear();
 
-		if let Err(err) = method_executors.select_with(receiver.receive_data(&mut data)).await {
-			match err {
-				SokettoError::Closed => {
-					tracing::debug!("Remote peer terminated the connection: {}", conn_id);
-					tx.close_channel();
-					return Ok(());
-				}
-				SokettoError::MessageTooLarge { current, maximum } => {
-					tracing::warn!(
-						"WS transport error: message is too big error ({} bytes, max is {})",
-						current,
-						maximum
-					);
-					send_error(Id::Null, &tx, ErrorCode::OversizedRequest.into());
-					continue;
-				}
-				// These errors can not be gracefully handled, so just log them and terminate the connection.
-				err => {
-					tracing::error!("WS transport error: {:?} => terminating connection {}", err, conn_id);
-					tx.close_channel();
-					return Err(err.into());
-				}
+		{
+			// Need the extra scope to drop this pinned future and reclaim access to `data`
+			let receive = receiver.receive_data(&mut data);
+
+			tokio::pin!(receive);
+
+			if let Err(err) = method_executors.select_with(Monitored::new(receive, &stop_server)).await {
+				match err {
+					MonitoredError::Selector(SokettoError::Closed) => {
+						tracing::debug!("WS transport error: remote peer terminated the connection: {}", conn_id);
+						tx.close_channel();
+						return Ok(());
+					}
+					MonitoredError::Selector(SokettoError::MessageTooLarge { current, maximum }) => {
+						tracing::warn!(
+							"WS transport error: outgoing message is too big error ({} bytes, max is {})",
+							current,
+							maximum
+						);
+						send_error(Id::Null, &tx, ErrorCode::OversizedRequest.into());
+						continue;
+					}
+					// These errors can not be gracefully handled, so just log them and terminate the connection.
+					MonitoredError::Selector(err) => {
+						tracing::error!("WS transport error: {:?} => terminating connection {}", err, conn_id);
+						tx.close_channel();
+						return Err(err.into());
+					}
+					MonitoredError::Shutdown => break,
+				};
+			};
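
// ---------------------------------------------------------------------------
// Editor's note, not part of the patch: a minimal standalone sketch of the
// `Monitored` pattern the server.rs changes introduce. It assumes a plain
// `Arc<AtomicBool>` (here called `StopFlag`) in place of jsonrpsee's
// `StopMonitor` and requires `Unpin` on the wrapped future; the heartbeat
// added to `FutureDriver` is what guarantees this shutdown check gets polled
// at least once per interval even when the inner future stays pending.
// ---------------------------------------------------------------------------
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

/// Hypothetical stand-in for `StopMonitor`: set the flag to request shutdown.
#[derive(Default)]
struct StopFlag(Arc<AtomicBool>);

impl StopFlag {
	fn shutdown_requested(&self) -> bool {
		self.0.load(Ordering::SeqCst)
	}
}

/// Distinguishes "the wrapped future failed" from "the server is shutting down".
enum MonitoredError<E> {
	Shutdown,
	Selector(E),
}

/// Wraps a fallible future and short-circuits with `Shutdown` once the flag is set.
struct Monitored<'a, F> {
	future: F,
	stop: &'a StopFlag,
}

impl<'a, F, T, E> Future for Monitored<'a, F>
where
	F: Future<Output = Result<T, E>> + Unpin,
{
	type Output = Result<T, MonitoredError<E>>;

	fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
		let this = Pin::into_inner(self);

		// Check the stop signal before polling the inner future, mirroring the
		// `Monitored` impls in the patch for `Incoming` and `receive_data`.
		if this.stop.shutdown_requested() {
			return Poll::Ready(Err(MonitoredError::Shutdown));
		}

		Pin::new(&mut this.future).poll(cx).map_err(MonitoredError::Selector)
	}
}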