Skip to content

Commit

Permalink
refactor(ws client): impl tokio:{AsyncRead, AsyncWrite} for EitherStr…
Browse files Browse the repository at this point in the history
…eam (#1249)

* refactor(ws client): tokio:{AsyncRead, AsyncWrite]

Simplify the code by implementing `tokio::io::{AsyncWrite, AsyncRead}`
for the EitherStream.

However, we still need the compat because soketto requires
futures::io::{AsyncRead, AsyncWrite}

* cargo fmt
  • Loading branch information
niklasad1 authored Dec 5, 2023
1 parent ce61f7e commit bb5780c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 158 deletions.
24 changes: 14 additions & 10 deletions client/transport/src/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use std::net::SocketAddr;
use std::time::Duration;

use futures_util::io::{BufReader, BufWriter};
pub use futures_util::{AsyncRead, AsyncWrite};
use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT};
use jsonrpsee_core::TEN_MB_SIZE_BYTES;
use jsonrpsee_core::{async_trait, Cow};
Expand All @@ -41,10 +40,12 @@ use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse};
use soketto::{connection, Data, Incoming};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};

pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri};
pub use soketto::handshake::client::Header;
pub use stream::EitherStream;
pub use tokio::io::{AsyncRead, AsyncWrite};
pub use url::Url;

const LOG_TARGET: &str = "jsonrpsee-client";
Expand Down Expand Up @@ -229,7 +230,7 @@ pub enum WsError {
#[async_trait]
impl<T> TransportSenderT for Sender<T>
where
T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static,
T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static,
{
type Error = WsError;

Expand Down Expand Up @@ -268,7 +269,7 @@ where
#[async_trait]
impl<T> TransportReceiverT for Receiver<T>
where
T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static,
T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static,
{
type Error = WsError;

Expand All @@ -295,7 +296,10 @@ impl WsTransportClientBuilder {
/// Try to establish the connection.
///
/// Uses the default connection over TCP.
pub async fn build(self, uri: Url) -> Result<(Sender<EitherStream>, Receiver<EitherStream>), WsHandshakeError> {
pub async fn build(
self,
uri: Url,
) -> Result<(Sender<Compat<EitherStream>>, Receiver<Compat<EitherStream>>), WsHandshakeError> {
self.try_connect_over_tcp(uri).await
}

Expand All @@ -304,19 +308,19 @@ impl WsTransportClientBuilder {
self,
uri: Url,
data_stream: T,
) -> Result<(Sender<T>, Receiver<T>), WsHandshakeError>
) -> Result<(Sender<Compat<T>>, Receiver<Compat<T>>), WsHandshakeError>
where
T: AsyncRead + AsyncWrite + Unpin,
T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
let target: Target = uri.try_into()?;
self.try_connect(&target, data_stream).await
self.try_connect(&target, data_stream.compat()).await
}

// Try to establish the connection over TCP.
async fn try_connect_over_tcp(
&self,
uri: Url,
) -> Result<(Sender<EitherStream>, Receiver<EitherStream>), WsHandshakeError> {
) -> Result<(Sender<Compat<EitherStream>>, Receiver<Compat<EitherStream>>), WsHandshakeError> {
let mut target: Target = uri.try_into()?;
let mut err = None;

Expand Down Expand Up @@ -353,7 +357,7 @@ impl WsTransportClientBuilder {
}
};

match self.try_connect(&target, tcp_stream).await {
match self.try_connect(&target, tcp_stream.compat()).await {
Ok(result) => return Ok(result),

Err(WsHandshakeError::Redirected { status_code, location }) => {
Expand Down Expand Up @@ -422,7 +426,7 @@ impl WsTransportClientBuilder {
data_stream: T,
) -> Result<(Sender<T>, Receiver<T>), WsHandshakeError>
where
T: AsyncRead + AsyncWrite + Unpin,
T: futures_util::AsyncRead + futures_util::AsyncWrite + Unpin,
{
let mut client = WsHandshakeClient::new(
BufReader::new(BufWriter::new(data_stream)),
Expand Down
94 changes: 14 additions & 80 deletions client/transport/src/ws/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

use futures_util::io::{IoSlice, IoSliceMut};
use futures_util::*;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

/// Stream to represent either a unencrypted or encrypted socket stream.
#[pin_project(project = EitherStreamProj)]
Expand All @@ -50,105 +48,41 @@ pub enum EitherStream {
}

impl AsyncRead for EitherStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat();
futures_util::pin_mut!(compat);
AsyncRead::poll_read(compat, cx, buf)
}
#[cfg(feature = "__tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat();
futures_util::pin_mut!(compat);
AsyncRead::poll_read(compat, cx, buf)
}
}
}

fn poll_read_vectored(
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
bufs: &mut [IoSliceMut],
) -> Poll<Result<usize, IoError>> {
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat();
futures_util::pin_mut!(compat);
AsyncRead::poll_read_vectored(compat, cx, bufs)
}
EitherStreamProj::Plain(stream) => AsyncRead::poll_read(stream, cx, buf),
#[cfg(feature = "__tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat();
futures_util::pin_mut!(compat);
AsyncRead::poll_read_vectored(compat, cx, bufs)
}
EitherStreamProj::Tls(stream) => AsyncRead::poll_read(stream, cx, buf),
}
}
}

impl AsyncWrite for EitherStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_write(compat, cx, buf)
}
#[cfg(feature = "__tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_write(compat, cx, buf)
}
}
}

fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll<Result<usize, IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_write_vectored(compat, cx, bufs)
}
EitherStreamProj::Plain(stream) => AsyncWrite::poll_write(stream, cx, buf),
#[cfg(feature = "__tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_write_vectored(compat, cx, bufs)
}
EitherStreamProj::Tls(stream) => AsyncWrite::poll_write(stream, cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_flush(compat, cx)
}
EitherStreamProj::Plain(stream) => AsyncWrite::poll_flush(stream, cx),
#[cfg(feature = "__tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_flush(compat, cx)
}
EitherStreamProj::Tls(stream) => AsyncWrite::poll_flush(stream, cx),
}
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), IoError>> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_close(compat, cx)
}
EitherStreamProj::Plain(stream) => AsyncWrite::poll_shutdown(stream, cx),
#[cfg(feature = "__tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures_util::pin_mut!(compat);
AsyncWrite::poll_close(compat, cx)
}
EitherStreamProj::Tls(stream) => AsyncWrite::poll_shutdown(stream, cx),
}
}
}
55 changes: 1 addition & 54 deletions tests/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,18 @@ use std::time::Duration;

use fast_socks5::client::Socks5Stream;
use fast_socks5::server;
use futures::{AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt};
use futures::{SinkExt, Stream, StreamExt};
use jsonrpsee::core::Error;
use jsonrpsee::server::middleware::http::ProxyGetRequestLayer;
use jsonrpsee::server::{
PendingSubscriptionSink, RpcModule, Server, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError,
};
use jsonrpsee::types::{ErrorObject, ErrorObjectOwned};
use jsonrpsee::SubscriptionCloseResponse;
use pin_project::pin_project;
use serde::Serialize;
use tokio::net::TcpStream;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use tower_http::cors::CorsLayer;

pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle) {
Expand Down Expand Up @@ -290,54 +288,3 @@ pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream<
.await
.unwrap()
}

#[pin_project]
pub struct DataStream<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin>(#[pin] Socks5Stream<T>);

impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin> DataStream<T> {
pub fn new(t: Socks5Stream<T>) -> Self {
Self(t)
}
}

impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> AsyncRead for DataStream<T> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> std::task::Poll<std::io::Result<usize>> {
let this = self.project().0.compat();
futures_util::pin_mut!(this);
AsyncRead::poll_read(this, cx, buf)
}
}

impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin> AsyncWrite for DataStream<T> {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
let this = self.project().0.compat_write();
futures_util::pin_mut!(this);
AsyncWrite::poll_write(this, cx, buf)
}

fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project().0.compat_write();
futures_util::pin_mut!(this);
AsyncWrite::poll_flush(this, cx)
}

fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project().0.compat_write();
futures_util::pin_mut!(this);
AsyncWrite::poll_close(this, cx)
}
}
25 changes: 11 additions & 14 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use futures::stream::FuturesUnordered;
use futures::{channel::mpsc, StreamExt, TryStreamExt};
use helpers::{
connect_over_socks_stream, init_logger, pipe_from_stream_and_drop, server, server_with_cors,
server_with_health_api, server_with_subscription, server_with_subscription_and_handle, DataStream,
server_with_health_api, server_with_subscription, server_with_subscription_and_handle,
};
use hyper::http::HeaderValue;
use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT};
Expand Down Expand Up @@ -85,9 +85,7 @@ async fn ws_subscription_works_over_proxy_stream() {
let target_url = format!("ws://{}", server_addr);

let socks_stream = connect_over_socks_stream(server_addr).await;
let data_stream = DataStream::new(socks_stream);

let client = WsClientBuilder::default().build_with_stream(target_url, data_stream).await.unwrap();
let client = WsClientBuilder::default().build_with_stream(target_url, socks_stream).await.unwrap();

let mut hello_sub: Subscription<String> =
client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await.unwrap();
Expand Down Expand Up @@ -129,8 +127,8 @@ async fn ws_unsubscription_works_over_proxy_stream() {
let server_addr = server_with_sleeping_subscription(tx).await;
let server_url = format!("ws://{}", server_addr);

let stream = DataStream::new(connect_over_socks_stream(server_addr).await);
let client = WsClientBuilder::default().build_with_stream(&server_url, stream).await.unwrap();
let socks_stream = connect_over_socks_stream(server_addr).await;
let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap();

let sub: Subscription<usize> =
client.subscribe("subscribe_sleep", rpc_params![], "unsubscribe_sleep").await.unwrap();
Expand Down Expand Up @@ -166,9 +164,7 @@ async fn ws_subscription_with_input_works_over_proxy_stream() {
let server_url = format!("ws://{}", server_addr);

let socks_stream = connect_over_socks_stream(server_addr).await;
let data_stream = DataStream::new(socks_stream);

let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap();
let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap();

let mut add_one: Subscription<u64> =
client.subscribe("subscribe_add_one", rpc_params![1], "unsubscribe_add_one").await.unwrap();
Expand Down Expand Up @@ -198,9 +194,8 @@ async fn ws_method_call_works_over_proxy_stream() {
let server_url = format!("ws://{}", server_addr);

let socks_stream = connect_over_socks_stream(server_addr).await;
let data_stream = DataStream::new(socks_stream);

let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap();
let client = WsClientBuilder::default().build_with_stream(&server_url, socks_stream).await.unwrap();
let response: String = client.request("say_hello", rpc_params![]).await.unwrap();
assert_eq!(&response, "hello");
}
Expand All @@ -224,10 +219,12 @@ async fn ws_method_call_str_id_works_over_proxy_stream() {
let server_url = format!("ws://{}", server_addr);

let socks_stream = connect_over_socks_stream(server_addr).await;
let data_stream = DataStream::new(socks_stream);

let client =
WsClientBuilder::default().id_format(IdKind::String).build_with_stream(&server_url, data_stream).await.unwrap();
let client = WsClientBuilder::default()
.id_format(IdKind::String)
.build_with_stream(&server_url, socks_stream)
.await
.unwrap();
let response: String = client.request("say_hello", rpc_params![]).await.unwrap();
assert_eq!(&response, "hello");
}
Expand Down

0 comments on commit bb5780c

Please sign in to comment.