diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index dfd3fce0a4..5bf9264356 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -31,7 +31,6 @@ use std::net::SocketAddr; use std::time::Duration; use futures_util::io::{BufReader, BufWriter}; -pub use futures_util::{AsyncRead, AsyncWrite}; use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT}; use jsonrpsee_core::TEN_MB_SIZE_BYTES; use jsonrpsee_core::{async_trait, Cow}; @@ -41,10 +40,12 @@ use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse}; use soketto::{connection, Data, Incoming}; use thiserror::Error; use tokio::net::TcpStream; +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri}; pub use soketto::handshake::client::Header; pub use stream::EitherStream; +pub use tokio::io::{AsyncRead, AsyncWrite}; pub use url::Url; const LOG_TARGET: &str = "jsonrpsee-client"; @@ -229,7 +230,7 @@ pub enum WsError { #[async_trait] impl TransportSenderT for Sender where - T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, + T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static, { type Error = WsError; @@ -268,7 +269,7 @@ where #[async_trait] impl TransportReceiverT for Receiver where - T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, + T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static, { type Error = WsError; @@ -295,7 +296,10 @@ impl WsTransportClientBuilder { /// Try to establish the connection. /// /// Uses the default connection over TCP. - pub async fn build(self, uri: Url) -> Result<(Sender, Receiver), WsHandshakeError> { + pub async fn build( + self, + uri: Url, + ) -> Result<(Sender>, Receiver>), WsHandshakeError> { self.try_connect_over_tcp(uri).await } @@ -304,19 +308,19 @@ impl WsTransportClientBuilder { self, uri: Url, data_stream: T, - ) -> Result<(Sender, Receiver), WsHandshakeError> + ) -> Result<(Sender>, Receiver>), WsHandshakeError> where - T: AsyncRead + AsyncWrite + Unpin, + T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { let target: Target = uri.try_into()?; - self.try_connect(&target, data_stream).await + self.try_connect(&target, data_stream.compat()).await } // Try to establish the connection over TCP. async fn try_connect_over_tcp( &self, uri: Url, - ) -> Result<(Sender, Receiver), WsHandshakeError> { + ) -> Result<(Sender>, Receiver>), WsHandshakeError> { let mut target: Target = uri.try_into()?; let mut err = None; @@ -353,7 +357,7 @@ impl WsTransportClientBuilder { } }; - match self.try_connect(&target, tcp_stream).await { + match self.try_connect(&target, tcp_stream.compat()).await { Ok(result) => return Ok(result), Err(WsHandshakeError::Redirected { status_code, location }) => { @@ -422,7 +426,7 @@ impl WsTransportClientBuilder { data_stream: T, ) -> Result<(Sender, Receiver), WsHandshakeError> where - T: AsyncRead + AsyncWrite + Unpin, + T: futures_util::AsyncRead + futures_util::AsyncWrite + Unpin, { let mut client = WsHandshakeClient::new( BufReader::new(BufWriter::new(data_stream)), diff --git a/client/transport/src/ws/stream.rs b/client/transport/src/ws/stream.rs index c9d7d40b72..55332b6f18 100644 --- a/client/transport/src/ws/stream.rs +++ b/client/transport/src/ws/stream.rs @@ -31,11 +31,9 @@ use std::pin::Pin; use std::task::Context; use std::task::Poll; -use futures_util::io::{IoSlice, IoSliceMut}; -use futures_util::*; use pin_project::pin_project; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Stream to represent either a unencrypted or encrypted socket stream. #[pin_project(project = EitherStreamProj)] @@ -50,39 +48,15 @@ pub enum EitherStream { } impl AsyncRead for EitherStream { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read(compat, cx, buf) - } - #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read(compat, cx, buf) - } - } - } - - fn poll_read_vectored( + fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - bufs: &mut [IoSliceMut], - ) -> Poll> { + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read_vectored(compat, cx, bufs) - } + EitherStreamProj::Plain(stream) => AsyncRead::poll_read(stream, cx, buf), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read_vectored(compat, cx, bufs) - } + EitherStreamProj::Tls(stream) => AsyncRead::poll_read(stream, cx, buf), } } } @@ -90,65 +64,25 @@ impl AsyncRead for EitherStream { impl AsyncWrite for EitherStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write(compat, cx, buf) - } - #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write(compat, cx, buf) - } - } - } - - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write_vectored(compat, cx, bufs) - } + EitherStreamProj::Plain(stream) => AsyncWrite::poll_write(stream, cx, buf), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write_vectored(compat, cx, bufs) - } + EitherStreamProj::Tls(stream) => AsyncWrite::poll_write(stream, cx, buf), } } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_flush(compat, cx) - } + EitherStreamProj::Plain(stream) => AsyncWrite::poll_flush(stream, cx), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_flush(compat, cx) - } + EitherStreamProj::Tls(stream) => AsyncWrite::poll_flush(stream, cx), } } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherStreamProj::Plain(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_close(compat, cx) - } + EitherStreamProj::Plain(stream) => AsyncWrite::poll_shutdown(stream, cx), #[cfg(feature = "__tls")] - EitherStreamProj::Tls(t) => { - let compat = t.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_close(compat, cx) - } + EitherStreamProj::Tls(stream) => AsyncWrite::poll_shutdown(stream, cx), } } } diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 2edfb0e4f5..0398d3ae86 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -32,7 +32,7 @@ use std::time::Duration; use fast_socks5::client::Socks5Stream; use fast_socks5::server; -use futures::{AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; +use futures::{SinkExt, Stream, StreamExt}; use jsonrpsee::core::Error; use jsonrpsee::server::middleware::http::ProxyGetRequestLayer; use jsonrpsee::server::{ @@ -40,12 +40,10 @@ use jsonrpsee::server::{ }; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::SubscriptionCloseResponse; -use pin_project::pin_project; use serde::Serialize; use tokio::net::TcpStream; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use tower_http::cors::CorsLayer; pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle) { @@ -290,54 +288,3 @@ pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream< .await .unwrap() } - -#[pin_project] -pub struct DataStream(#[pin] Socks5Stream); - -impl DataStream { - pub fn new(t: Socks5Stream) -> Self { - Self(t) - } -} - -impl AsyncRead for DataStream { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut [u8], - ) -> std::task::Poll> { - let this = self.project().0.compat(); - futures_util::pin_mut!(this); - AsyncRead::poll_read(this, cx, buf) - } -} - -impl AsyncWrite for DataStream { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - let this = self.project().0.compat_write(); - futures_util::pin_mut!(this); - AsyncWrite::poll_write(this, cx, buf) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project().0.compat_write(); - futures_util::pin_mut!(this); - AsyncWrite::poll_flush(this, cx) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project().0.compat_write(); - futures_util::pin_mut!(this); - AsyncWrite::poll_close(this, cx) - } -} diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index d4041cf6ae..38851a07e1 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -37,7 +37,7 @@ use futures::stream::FuturesUnordered; use futures::{channel::mpsc, StreamExt, TryStreamExt}; use helpers::{ connect_over_socks_stream, init_logger, pipe_from_stream_and_drop, server, server_with_cors, - server_with_health_api, server_with_subscription, server_with_subscription_and_handle, DataStream, + server_with_health_api, server_with_subscription, server_with_subscription_and_handle, }; use hyper::http::HeaderValue; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; @@ -85,9 +85,7 @@ async fn ws_subscription_works_over_proxy_stream() { let target_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); - - let client = WsClientBuilder::default().build_with_stream(target_url, data_stream).await.unwrap(); + let client = WsClientBuilder::default().build_with_stream(target_url, socks_stream).await.unwrap(); let mut hello_sub: Subscription = client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await.unwrap(); @@ -129,8 +127,8 @@ async fn ws_unsubscription_works_over_proxy_stream() { let server_addr = server_with_sleeping_subscription(tx).await; let server_url = format!("ws://{}", server_addr); - let stream = DataStream::new(connect_over_socks_stream(server_addr).await); - let client = WsClientBuilder::default().build_with_stream(&server_url, stream).await.unwrap(); + let socks_stream = connect_over_socks_stream(server_addr).await; + let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap(); let sub: Subscription = client.subscribe("subscribe_sleep", rpc_params![], "unsubscribe_sleep").await.unwrap(); @@ -166,9 +164,7 @@ async fn ws_subscription_with_input_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); - - let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); + let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap(); let mut add_one: Subscription = client.subscribe("subscribe_add_one", rpc_params![1], "unsubscribe_add_one").await.unwrap(); @@ -198,9 +194,8 @@ async fn ws_method_call_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); - let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); + let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap(); let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); assert_eq!(&response, "hello"); } @@ -224,10 +219,12 @@ async fn ws_method_call_str_id_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::new(socks_stream); - let client = - WsClientBuilder::default().id_format(IdKind::String).build_with_stream(&server_url, data_stream).await.unwrap(); + let client = WsClientBuilder::default() + .id_format(IdKind::String) + .build_with_stream(&server_url, socks_stream) + .await + .unwrap(); let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); assert_eq!(&response, "hello"); }