Skip to content

Commit

Permalink
graceful shutdown: terminate on error (#1218)
Browse files Browse the repository at this point in the history
* graceful shutdown: terminate on error

* fix nits

* adress grumbles
  • Loading branch information
niklasad1 committed Oct 24, 2023
1 parent 5ef8a87 commit 398d8c1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 69 deletions.
122 changes: 58 additions & 64 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::server::{BatchRequestConfig, ServiceData};
use futures_util::future::{self, Either};
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::FuturesOrdered;
use futures_util::{Future, StreamExt};
use futures_util::{Future, StreamExt, TryStreamExt};
use hyper::upgrade::Upgraded;
use jsonrpsee_core::server::helpers::{
batch_response_error, prepare_error, BatchResponseBuilder, MethodResponse, MethodSink,
Expand All @@ -23,7 +23,6 @@ use jsonrpsee_types::error::{
use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification, Params, Request};
use soketto::connection::Error as SokettoError;
use soketto::data::ByteSlice125;
use soketto::Incoming;

use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::{IntervalStream, ReceiverStream};
Expand All @@ -35,6 +34,11 @@ pub(crate) type Receiver = soketto::Receiver<BufReader<BufWriter<Compat<Upgraded

type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>;

enum Incoming {
Data(Vec<u8>),
Pong,
}

pub(crate) async fn send_message(sender: &mut Sender, response: String) -> Result<(), Error> {
sender.send_text_owned(response).await?;
sender.flush().await.map_err(Into::into)
Expand Down Expand Up @@ -225,7 +229,7 @@ pub(crate) async fn execute_call<'a, L: Logger>(req: Request<'a>, call: CallData
response
}

pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Receiver, svc: ServiceData<L>) {
pub(crate) async fn background_task<L: Logger>(sender: Sender, receiver: Receiver, svc: ServiceData<L>) {
let ServiceData {
methods,
max_request_body_size,
Expand Down Expand Up @@ -257,8 +261,6 @@ pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Rec
// Spawn another task that sends out the responses on the Websocket.
let send_task_handle = tokio::spawn(send_task(rx, sender, ping_interval, conn_rx));

// Buffer for incoming data.
let mut data = Vec::with_capacity(100);
let stopped = stop_handle.clone().shutdown();

let params = Arc::new(ExecuteCallParams {
Expand All @@ -275,13 +277,26 @@ pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Rec

tokio::pin!(stopped);

let result = loop {
data.clear();
let ws_stream = futures_util::stream::unfold(receiver, |mut receiver| async {
let mut data = Vec::new();
match receiver.receive(&mut data).await {
Ok(soketto::Incoming::Data(_)) => Some((Ok(Incoming::Data(data)), receiver)),
Ok(soketto::Incoming::Pong(_)) => Some((Ok(Incoming::Pong), receiver)),
Ok(soketto::Incoming::Closed(_)) | Err(SokettoError::Closed) => None,
// The closing reason is already logged by `soketto` trace log level.
// Return the `Closed` error to avoid logging unnecessary warnings on clean shutdown.
Err(e) => Some((Err(e), receiver)),
}
});

tokio::pin!(ws_stream);

match try_recv(&mut receiver, &mut data, stopped).await {
let result = loop {
let data = match try_recv(&mut ws_stream, stopped).await {
Receive::Shutdown => break Ok(Shutdown::Stopped),
Receive::Ok(stop) => {
Receive::Ok(data, stop) => {
stopped = stop;
data
}
Receive::Err(err, stop) => {
stopped = stop;
Expand Down Expand Up @@ -311,14 +326,14 @@ pub(crate) async fn background_task<L: Logger>(sender: Sender, mut receiver: Rec
}
};

tokio::spawn(execute_unchecked_call(params.clone(), std::mem::take(&mut data), pending_calls.clone()));
tokio::spawn(execute_unchecked_call(params.clone(), data, pending_calls.clone()));
};

// Drive all running methods to completion.
// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
// proper drop behaviour.
drop(pending_calls);
graceful_shutdown(result, pending_calls_completed, receiver, data, conn_tx, send_task_handle).await;
graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await;

logger.on_disconnect(remote_addr, TransportProtocol::WebSocket);
drop(conn);
Expand Down Expand Up @@ -396,35 +411,30 @@ async fn send_task(
enum Receive<S> {
Shutdown,
Err(SokettoError, S),
Ok(S),
Ok(Vec<u8>, S),
}

/// Attempts to read data from WebSocket fails if the server was stopped.
async fn try_recv<S>(receiver: &mut Receiver, data: &mut Vec<u8>, stopped: S) -> Receive<S>
async fn try_recv<T, S>(ws_stream: &mut T, mut stopped: S) -> Receive<S>
where
S: Future<Output = ()> + Unpin,
T: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
{
let receive = async {
// Identical loop to `soketto::receive_data` with debug logs for `Pong` frames.
loop {
match receiver.receive(data).await? {
soketto::Incoming::Data(d) => break Ok(d),
soketto::Incoming::Pong(_) => tracing::debug!("Received pong"),
soketto::Incoming::Closed(_) => {
// The closing reason is already logged by `soketto` trace log level.
// Return the `Closed` error to avoid logging unnecessary warnings on clean shutdown.
break Err(SokettoError::Closed);
}
loop {
match futures_util::future::select(ws_stream.next(), stopped).await {
// The connection is closed.
Either::Left((None, _)) => break Receive::Shutdown,
// The message has been received, we are done
Either::Left((Some(Ok(Incoming::Data(d))), s)) => break Receive::Ok(d, s),
// Got a pong response, update our "last seen" timestamp.
Either::Left((Some(Ok(Incoming::Pong)), s)) => {
stopped = s;
}
// Received an error, terminate the connection.
Either::Left((Some(Err(e)), s)) => break Receive::Err(e, s),
// Server has been stopped.
Either::Right(_) => break Receive::Shutdown,
}
};

tokio::pin!(receive);

match futures_util::future::select(receive, stopped).await {
Either::Left((Ok(_), s)) => Receive::Ok(s),
Either::Left((Err(e), s)) => Receive::Err(e, s),
Either::Right(_) => Receive::Shutdown,
}
}

Expand Down Expand Up @@ -517,47 +527,31 @@ pub(crate) enum Shutdown {
/// Enforce a graceful shutdown.
///
/// This will return once the connection has been terminated or all pending calls have been executed.
async fn graceful_shutdown(
async fn graceful_shutdown<S>(
result: Result<Shutdown, SokettoError>,
pending_calls: mpsc::Receiver<()>,
receiver: Receiver,
data: Vec<u8>,
ws_stream: S,
mut conn_tx: oneshot::Sender<()>,
send_task_handle: tokio::task::JoinHandle<()>,
) {
) where
S: StreamExt<Item = Result<Incoming, SokettoError>> + Unpin,
{
let pending_calls = ReceiverStream::new(pending_calls);

match result {
Ok(Shutdown::ConnectionClosed) | Err(SokettoError::Closed) => (),
Ok(Shutdown::Stopped) | Err(_) => {
// Soketto doesn't have a way to signal when the connection is closed
// thus just throw away the data and terminate the stream once the connection has
// been terminated.
//
// The receiver is not cancel-safe such that it's used in a stream to enforce that.
let disconnect_stream = futures_util::stream::unfold((receiver, data), |(mut receiver, mut data)| async {
match receiver.receive(&mut data).await {
Ok(Incoming::Closed(_)) | Err(SokettoError::Closed) => None,
Ok(Incoming::Data(_) | Incoming::Pong(_)) => Some(((), (receiver, data))),
Err(e) => {
tracing::warn!("Graceful shutdown got WebSocket error: {e} but polling until the connection is closed or all pending calls has been executed");
Some(((), (receiver, data)))
}
}
});

let graceful_shutdown = pending_calls.for_each(|_| async {});
let disconnect = disconnect_stream.for_each(|_| async {});
if let Ok(Shutdown::Stopped) = result {
let graceful_shutdown = pending_calls.for_each(|_| async {});
let disconnect = ws_stream.try_for_each(|_| async { Ok(()) });

// All pending calls has been finished or the connection closed.
// Fine to terminate
tokio::select! {
_ = graceful_shutdown => {}
_ = disconnect => {}
_ = conn_tx.closed() => {}
tokio::select! {
_ = graceful_shutdown => {}
res = disconnect => {
if let Err(err) = res {
tracing::warn!("Graceful shutdown terminated because of error: `{err}`");
}
}
_ = conn_tx.closed() => {}
}
};
}

// Send a message to close down the "send task".
_ = conn_tx.send(());
Expand Down
10 changes: 5 additions & 5 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1144,15 +1144,15 @@ async fn run_shutdown_test_inner<C: ClientT + Send + Sync + 'static>(
call_answered: Arc<AtomicBool>,
mut call_ack: tokio::sync::mpsc::UnboundedReceiver<()>,
) {
let mut calls: FuturesUnordered<_> = (0..10)
const LEN: usize = 10;

let mut calls: FuturesUnordered<_> = (0..LEN)
.map(|_| {
let c = client.clone();
async move { c.request::<String, _>("sleep_20s", rpc_params!()).await }
})
.collect();

let calls_len = calls.len();

let res = tokio::spawn(async move {
let mut c = 0;
while let Some(Ok(_)) = calls.next().await {
Expand All @@ -1162,7 +1162,7 @@ async fn run_shutdown_test_inner<C: ClientT + Send + Sync + 'static>(
});

// All calls has been received by server => then stop.
for _ in 0..calls_len {
for _ in 0..LEN {
call_ack.recv().await.unwrap();
}

Expand All @@ -1185,7 +1185,7 @@ async fn run_shutdown_test_inner<C: ClientT + Send + Sync + 'static>(
assert!(call_after_stop.await.unwrap().is_err());

// The pending calls should be answered before shutdown.
assert_eq!(res.await.unwrap(), 10);
assert_eq!(res.await.unwrap(), LEN);

// The server should be closed now.
assert!(client.request::<String, _>("sleep_20s", rpc_params!()).await.is_err());
Expand Down

0 comments on commit 398d8c1

Please sign in to comment.