From 59ccd8e170737443b14cf39c97a6d41671224dcd Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Thu, 30 Nov 2023 17:37:38 +0100 Subject: [PATCH 1/2] refactor(ws client): tokio:{AsyncRead, AsyncWrite] Simplify the code by implementing `tokio::io::{AsyncWrite, AsyncRead}` for the EitherStream. However, we still need the compat because soketto requires futures::io::{AsyncRead, AsyncWrite} --- client/transport/src/ws/mod.rs | 21 +++---- client/transport/src/ws/stream.rs | 94 +++++++------------------------ tests/tests/helpers.rs | 57 +------------------ tests/tests/integration_tests.rs | 20 +++---- 4 files changed, 40 insertions(+), 152 deletions(-) diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index dfd3fce0a4..e5410f9e84 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -31,7 +31,6 @@ use std::net::SocketAddr; use std::time::Duration; use futures_util::io::{BufReader, BufWriter}; -pub use futures_util::{AsyncRead, AsyncWrite}; use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT}; use jsonrpsee_core::TEN_MB_SIZE_BYTES; use jsonrpsee_core::{async_trait, Cow}; @@ -41,7 +40,9 @@ use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse}; use soketto::{connection, Data, Incoming}; use thiserror::Error; use tokio::net::TcpStream; +use tokio_util::compat::{TokioAsyncReadCompatExt, Compat}; +pub use tokio::io::{AsyncRead, AsyncWrite}; pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri}; pub use soketto::handshake::client::Header; pub use stream::EitherStream; @@ -229,7 +230,7 @@ pub enum WsError { #[async_trait] impl TransportSenderT for Sender where - T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, + T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static, { type Error = WsError; @@ -268,7 +269,7 @@ where #[async_trait] impl TransportReceiverT for Receiver where - T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, + T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static, { type Error = WsError; @@ -295,7 +296,7 @@ impl WsTransportClientBuilder { /// Try to establish the connection. /// /// Uses the default connection over TCP. - pub async fn build(self, uri: Url) -> Result<(Sender, Receiver), WsHandshakeError> { + pub async fn build(self, uri: Url) -> Result<(Sender>, Receiver>), WsHandshakeError> { self.try_connect_over_tcp(uri).await } @@ -304,19 +305,19 @@ impl WsTransportClientBuilder { self, uri: Url, data_stream: T, - ) -> Result<(Sender, Receiver), WsHandshakeError> + ) -> Result<(Sender>, Receiver>), WsHandshakeError> where - T: AsyncRead + AsyncWrite + Unpin, + T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { let target: Target = uri.try_into()?; - self.try_connect(&target, data_stream).await + self.try_connect(&target, data_stream.compat()).await } // Try to establish the connection over TCP. async fn try_connect_over_tcp( &self, uri: Url, - ) -> Result<(Sender, Receiver), WsHandshakeError> { + ) -> Result<(Sender>, Receiver>), WsHandshakeError> { let mut target: Target = uri.try_into()?; let mut err = None; @@ -353,7 +354,7 @@ impl WsTransportClientBuilder { } }; - match self.try_connect(&target, tcp_stream).await { + match self.try_connect(&target, tcp_stream.compat()).await { Ok(result) => return Ok(result), Err(WsHandshakeError::Redirected { status_code, location }) => { @@ -422,7 +423,7 @@ impl WsTransportClientBuilder { data_stream: T, ) -> Result<(Sender, Receiver), WsHandshakeError> where - T: AsyncRead + AsyncWrite + Unpin, + T: futures_util::AsyncRead + futures_util::AsyncWrite + Unpin, { let mut client = WsHandshakeClient::new( BufReader::new(BufWriter::new(data_stream)), diff --git a/client/transport/src/ws/stream.rs b/client/transport/src/ws/stream.rs index c9d7d40b72..5defe430c4 100644 --- a/client/transport/src/ws/stream.rs +++ b/client/transport/src/ws/stream.rs @@ -31,11 +31,9 @@ use std::pin::Pin; use std::task::Context; use std::task::Poll; -use futures_util::io::{IoSlice, IoSliceMut}; -use futures_util::*; use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Stream to represent either a unencrypted or encrypted socket stream. #[pin_project(project = EitherStreamProj)] @@ -50,38 +48,14 @@ pub enum EitherStream { } impl AsyncRead for EitherStream { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read(compat, cx, buf) + EitherStreamProj::Plain(stream) => { + AsyncRead::poll_read(stream, cx, buf) } #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read(compat, cx, buf) - } - } - } - - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut Context, - bufs: &mut [IoSliceMut], - ) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read_vectored(compat, cx, bufs) - } - #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read_vectored(compat, cx, bufs) + EitherStreamProj::Tls(stream) => { + AsyncRead::poll_read(stream, cx, buf) } } } @@ -90,64 +64,36 @@ impl AsyncRead for EitherStream { impl AsyncWrite for EitherStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write(compat, cx, buf) - } - #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write(compat, cx, buf) - } - } - } - - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write_vectored(compat, cx, bufs) + EitherStreamProj::Plain(stream) => { + AsyncWrite::poll_write(stream, cx, buf) } #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write_vectored(compat, cx, bufs) + EitherStreamProj::Tls(stream) => { + AsyncWrite::poll_write(stream, cx, buf) } } } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_flush(compat, cx) + EitherStreamProj::Plain(stream) => { + AsyncWrite::poll_flush(stream, cx) } #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_flush(compat, cx) + EitherStreamProj::Tls(stream) => { + AsyncWrite::poll_flush(stream, cx) } } } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_close(compat, cx) + EitherStreamProj::Plain(stream) => { + AsyncWrite::poll_shutdown(stream, cx) } #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_close(compat, cx) + EitherStreamProj::Tls(stream) => { + AsyncWrite::poll_shutdown(stream, cx) } } } diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 2edfb0e4f5..9211c5bd0c 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -32,7 +32,7 @@ use std::time::Duration; use fast_socks5::client::Socks5Stream; use fast_socks5::server; -use futures::{AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use futures::{SinkExt, Stream, StreamExt}; use jsonrpsee::core::Error; use jsonrpsee::server::middleware::http::ProxyGetRequestLayer; use jsonrpsee::server::{ @@ -40,12 +40,10 @@ use jsonrpsee::server::{ }; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::SubscriptionCloseResponse; -use pin_project::pin_project; use serde::Serialize; use tokio::net::TcpStream; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use tower_http::cors::CorsLayer; pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle) { @@ -289,55 +287,4 @@ pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream< ) .await .unwrap() -} - -#[pin_project] -pub struct DataStream(#[pin] Socks5Stream); - -impl DataStream { - pub fn new(t: Socks5Stream) -> Self { - Self(t) - } -} - -impl AsyncRead for DataStream { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut [u8], - ) -> std::task::Poll> { - let this = self.project().0.compat(); - futures_util::pin_mut!(this); - AsyncRead::poll_read(this, cx, buf) - } -} - -impl AsyncWrite for DataStream { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - let this = self.project().0.compat_write(); - futures_util::pin_mut!(this); - AsyncWrite::poll_write(this, cx, buf) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project().0.compat_write(); - futures_util::pin_mut!(this); - AsyncWrite::poll_flush(this, cx) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project().0.compat_write(); - futures_util::pin_mut!(this); - AsyncWrite::poll_close(this, cx) - } -} +} \ No newline at end of file diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index d4041cf6ae..2bd6c34fef 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -37,7 +37,7 @@ use futures::stream::FuturesUnordered; use futures::{channel::mpsc, StreamExt, TryStreamExt}; use helpers::{ connect_over_socks_stream, init_logger, pipe_from_stream_and_drop, server, server_with_cors, - server_with_health_api, server_with_subscription, server_with_subscription_and_handle, DataStream, + server_with_health_api, server_with_subscription, server_with_subscription_and_handle }; use hyper::http::HeaderValue; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; @@ -85,9 +85,7 @@ async fn ws_subscription_works_over_proxy_stream() { let target_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); - - let client = WsClientBuilder::default().build_with_stream(target_url, data_stream).await.unwrap(); + let client = WsClientBuilder::default().build_with_stream(target_url, socks_stream).await.unwrap(); let mut hello_sub: Subscription = client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await.unwrap(); @@ -129,8 +127,8 @@ async fn ws_unsubscription_works_over_proxy_stream() { let server_addr = server_with_sleeping_subscription(tx).await; let server_url = format!("ws://{}", server_addr); - let stream = DataStream::new(connect_over_socks_stream(server_addr).await); - let client = WsClientBuilder::default().build_with_stream(&server_url, stream).await.unwrap(); + let socks_stream = connect_over_socks_stream(server_addr).await; + let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap(); let sub: Subscription = client.subscribe("subscribe_sleep", rpc_params![], "unsubscribe_sleep").await.unwrap(); @@ -166,9 +164,7 @@ async fn ws_subscription_with_input_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); - - let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); + let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap(); let mut add_one: Subscription = client.subscribe("subscribe_add_one", rpc_params![1], "unsubscribe_add_one").await.unwrap(); @@ -198,9 +194,8 @@ async fn ws_method_call_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); - let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); + let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap(); let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); assert_eq!(&response, "hello"); } @@ -224,10 +219,9 @@ async fn ws_method_call_str_id_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); let client = - WsClientBuilder::default().id_format(IdKind::String).build_with_stream(&server_url, data_stream).await.unwrap(); + WsClientBuilder::default().id_format(IdKind::String).build_with_stream(&server_url, socks_stream).await.unwrap(); let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); assert_eq!(&response, "hello"); } From ce158b0748dc1fa49080263ce06ccbfd7779146c Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Thu, 30 Nov 2023 17:49:53 +0100 Subject: [PATCH 2/2] cargo fmt --- client/transport/src/ws/mod.rs | 9 +++++--- client/transport/src/ws/stream.rs | 38 +++++++++++-------------------- tests/tests/helpers.rs | 2 +- tests/tests/integration_tests.rs | 9 +++++--- 4 files changed, 26 insertions(+), 32 deletions(-) diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index e5410f9e84..5bf9264356 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -40,12 +40,12 @@ use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse}; use soketto::{connection, Data, Incoming}; use thiserror::Error; use tokio::net::TcpStream; -use tokio_util::compat::{TokioAsyncReadCompatExt, Compat}; +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; -pub use tokio::io::{AsyncRead, AsyncWrite}; pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri}; pub use soketto::handshake::client::Header; pub use stream::EitherStream; +pub use tokio::io::{AsyncRead, AsyncWrite}; pub use url::Url; const LOG_TARGET: &str = "jsonrpsee-client"; @@ -296,7 +296,10 @@ impl WsTransportClientBuilder { /// Try to establish the connection. /// /// Uses the default connection over TCP. - pub async fn build(self, uri: Url) -> Result<(Sender>, Receiver>), WsHandshakeError> { + pub async fn build( + self, + uri: Url, + ) -> Result<(Sender>, Receiver>), WsHandshakeError> { self.try_connect_over_tcp(uri).await } diff --git a/client/transport/src/ws/stream.rs b/client/transport/src/ws/stream.rs index 5defe430c4..55332b6f18 100644 --- a/client/transport/src/ws/stream.rs +++ b/client/transport/src/ws/stream.rs @@ -48,15 +48,15 @@ pub enum EitherStream { } impl AsyncRead for EitherStream { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { match self.project() { - EitherStreamProj::Plain(stream) => { - AsyncRead::poll_read(stream, cx, buf) - } + EitherStreamProj::Plain(stream) => AsyncRead::poll_read(stream, cx, buf), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(stream) => { - AsyncRead::poll_read(stream, cx, buf) - } + EitherStreamProj::Tls(stream) => AsyncRead::poll_read(stream, cx, buf), } } } @@ -64,37 +64,25 @@ impl AsyncRead for EitherStream { impl AsyncWrite for EitherStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { match self.project() { - EitherStreamProj::Plain(stream) => { - AsyncWrite::poll_write(stream, cx, buf) - } + EitherStreamProj::Plain(stream) => AsyncWrite::poll_write(stream, cx, buf), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(stream) => { - AsyncWrite::poll_write(stream, cx, buf) - } + EitherStreamProj::Tls(stream) => AsyncWrite::poll_write(stream, cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherStreamProj::Plain(stream) => { - AsyncWrite::poll_flush(stream, cx) - } + EitherStreamProj::Plain(stream) => AsyncWrite::poll_flush(stream, cx), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(stream) => { - AsyncWrite::poll_flush(stream, cx) - } + EitherStreamProj::Tls(stream) => AsyncWrite::poll_flush(stream, cx), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherStreamProj::Plain(stream) => { - AsyncWrite::poll_shutdown(stream, cx) - } + EitherStreamProj::Plain(stream) => AsyncWrite::poll_shutdown(stream, cx), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(stream) => { - AsyncWrite::poll_shutdown(stream, cx) - } + EitherStreamProj::Tls(stream) => AsyncWrite::poll_shutdown(stream, cx), } } } diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 9211c5bd0c..0398d3ae86 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -287,4 +287,4 @@ pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream< ) .await .unwrap() -} \ No newline at end of file +} diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 2bd6c34fef..38851a07e1 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -37,7 +37,7 @@ use futures::stream::FuturesUnordered; use futures::{channel::mpsc, StreamExt, TryStreamExt}; use helpers::{ connect_over_socks_stream, init_logger, pipe_from_stream_and_drop, server, server_with_cors, - server_with_health_api, server_with_subscription, server_with_subscription_and_handle + server_with_health_api, server_with_subscription, server_with_subscription_and_handle, }; use hyper::http::HeaderValue; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; @@ -220,8 +220,11 @@ async fn ws_method_call_str_id_works_over_proxy_stream() { let socks_stream = connect_over_socks_stream(server_addr).await; - let client = - WsClientBuilder::default().id_format(IdKind::String).build_with_stream(&server_url, socks_stream).await.unwrap(); + let client = WsClientBuilder::default() + .id_format(IdKind::String) + .build_with_stream(&server_url, socks_stream) + .await + .unwrap(); let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); assert_eq!(&response, "hello"); }