Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ping-pong for WebSocket server #782

Merged
merged 8 commits into from
May 27, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use crate::future::{FutureDriver, ServerHandle, StopMonitor};
use crate::types::error::{ErrorCode, ErrorObject, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG};
Expand All @@ -46,9 +48,11 @@ use jsonrpsee_core::traits::IdProvider;
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
use jsonrpsee_types::Params;
use soketto::connection::Error as SokettoError;
use soketto::data::ByteSlice125;
use soketto::handshake::{server::Response, Server as SokettoServer};
use soketto::Sender;
use soketto::{Receiver, Sender};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::select;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

/// Default maximum connections allowed.
Expand Down Expand Up @@ -275,6 +279,7 @@ where
stop_monitor.clone(),
middleware,
id_provider,
cfg.ping_interval,
))
.await;

Expand All @@ -298,6 +303,7 @@ async fn background_task(
stop_server: StopMonitor,
middleware: impl Middleware,
id_provider: Arc<dyn IdProvider>,
ping_interval: Duration,
) -> Result<(), Error> {
// And we can finally transition to a websocket background_task.
let mut builder = server.into_builder();
Expand All @@ -309,19 +315,32 @@ async fn background_task(
let stop_server2 = stop_server.clone();
let sink = MethodSink::new_with_limit(tx, max_response_body_size);

let mut submit_ping = tokio::time::interval(ping_interval);

middleware.on_connect();

// Send results back to the client.
tokio::spawn(async move {
while !stop_server2.shutdown_requested() {
if let Some(response) = rx.next().await {
// If websocket message send fail then terminate the connection.
if let Err(err) = send_ws_message(&mut sender, response).await {
tracing::warn!("WS send error: {}; terminate connection", err);
break;
select! {
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
rx_value = rx.next() => {
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
if let Some(response) = rx_value {
// If websocket message send fail then terminate the connection.
if let Err(err) = send_ws_message(&mut sender, response).await {
tracing::warn!("WS send error: {}; terminate connection", err);
break;
}
} else {
break;
}
}

_ = submit_ping.tick() => {
if let Err(err) = send_ws_ping(&mut sender).await {
tracing::warn!("WS send ping error: {}; terminate connection", err);
break;
}
}
} else {
break;
}
}

Expand All @@ -343,7 +362,6 @@ async fn background_task(
{
// Need the extra scope to drop this pinned future and reclaim access to `data`
let receive = receiver.receive_data(&mut data);

tokio::pin!(receive);

if let Err(err) = method_executors.select_with(Monitored::new(receive, &stop_server)).await {
Expand Down Expand Up @@ -668,6 +686,8 @@ struct Settings {
batch_requests_supported: bool,
/// Custom tokio runtime to run the server on.
tokio_runtime: Option<tokio::runtime::Handle>,
/// The interval at which `Ping` frames are submitted.
ping_interval: Duration,
}

impl Default for Settings {
Expand All @@ -681,6 +701,7 @@ impl Default for Settings {
allowed_origins: AllowedValue::Any,
allowed_hosts: AllowedValue::Any,
tokio_runtime: None,
ping_interval: Duration::from_secs(60),
}
}
}
Expand Down Expand Up @@ -863,6 +884,27 @@ impl<M> Builder<M> {
self
}

/// Configure the interval at which pings are submitted.
///
/// This option is used to keep the connection alive, and is just submitting `Ping` frames,
/// without making any assumptions about when a `Pong` frame should be received.
///
/// Default: 60 seconds.
///
/// # Examples
///
/// ```rust
/// use std::time::Duration;
/// use jsonrpsee_ws_server::WsServerBuilder;
///
/// // Set the ping interval to 10 seconds.
/// let builder = WsServerBuilder::default().ping_interval(Duration::from_secs(10));
/// ```
pub fn ping_interval(mut self, interval: Duration) -> Self {
self.settings.ping_interval = interval;
self
}

/// Configure custom `subscription ID` provider for the server to use
/// to when getting new subscription calls.
///
Expand Down Expand Up @@ -928,3 +970,13 @@ async fn send_ws_message(
sender.send_text_owned(response).await?;
sender.flush().await.map_err(Into::into)
}

async fn send_ws_ping(sender: &mut Sender<BufReader<BufWriter<Compat<TcpStream>>>>) -> Result<(), Error> {
tracing::debug!("submit ping");
// Submit empty slice as "optional" parameter.
let slice: &[u8] = &[];
// Byte slice fails if the provided slice is larger than 125 bytes.
let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
sender.send_ping(byte_slice).await?;
sender.flush().await.map_err(Into::into)
}