Skip to content

Commit

Permalink
fix(client): remove Send bounds for request Body (hyperium#3266)
Browse files Browse the repository at this point in the history
Closes hyperium#3184

Signed-off-by: Sven Pfennig <s.pfennig@reply.de>
  • Loading branch information
Ruben2424 authored and 0xE282B0 committed Jan 16, 2024
1 parent 29926df commit 5400295
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 59 deletions.
235 changes: 178 additions & 57 deletions examples/single_threaded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl HttpBody for Body {
fn main() {
pretty_env_logger::init();

let server = thread::spawn(move || {
let server_http2 = thread::spawn(move || {
// Configure a runtime for the server that runs everything on the current thread
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand All @@ -61,10 +61,10 @@ fn main() {

// Combine it with a `LocalSet, which means it can spawn !Send futures...
let local = tokio::task::LocalSet::new();
local.block_on(&rt, server()).unwrap();
local.block_on(&rt, http2_server()).unwrap();
});

let client = thread::spawn(move || {
let client_http2 = thread::spawn(move || {
// Configure a runtime for the client that runs everything on the current thread
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand All @@ -76,16 +76,137 @@ fn main() {
local
.block_on(
&rt,
client("http://localhost:3000".parse::<hyper::Uri>().unwrap()),
http2_client("http://localhost:3000".parse::<hyper::Uri>().unwrap()),
)
.unwrap();
});

server.join().unwrap();
client.join().unwrap();
let server_http1 = thread::spawn(move || {
// Configure a runtime for the server that runs everything on the current thread
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("build runtime");

// Combine it with a `LocalSet, which means it can spawn !Send futures...
let local = tokio::task::LocalSet::new();
local.block_on(&rt, http1_server()).unwrap();
});

let client_http1 = thread::spawn(move || {
// Configure a runtime for the client that runs everything on the current thread
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("build runtime");

// Combine it with a `LocalSet, which means it can spawn !Send futures...
let local = tokio::task::LocalSet::new();
local
.block_on(
&rt,
http1_client("http://localhost:3001".parse::<hyper::Uri>().unwrap()),
)
.unwrap();
});

server_http2.join().unwrap();
client_http2.join().unwrap();

server_http1.join().unwrap();
client_http1.join().unwrap();
}

async fn http1_server() -> Result<(), Box<dyn std::error::Error>> {
let addr = SocketAddr::from(([127, 0, 0, 1], 3001));

let listener = TcpListener::bind(addr).await?;

// For each connection, clone the counter to use in our service...
let counter = Rc::new(Cell::new(0));

loop {
let (stream, _) = listener.accept().await?;

let io = TokioIo::new(stream);

let cnt = counter.clone();

let service = service_fn(move |_| {
let prev = cnt.get();
cnt.set(prev + 1);
let value = cnt.get();
async move { Ok::<_, Error>(Response::new(Body::from(format!("Request #{}", value)))) }
});

tokio::task::spawn_local(async move {
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
println!("Error serving connection: {:?}", err);
}
});
}
}

async fn http1_client(url: hyper::Uri) -> Result<(), Box<dyn std::error::Error>> {
let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(80);
let addr = format!("{}:{}", host, port);
let stream = TcpStream::connect(addr).await?;

let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;

tokio::task::spawn_local(async move {
if let Err(err) = conn.await {
let mut stdout = io::stdout();
stdout
.write_all(format!("Connection failed: {:?}", err).as_bytes())
.await
.unwrap();
stdout.flush().await.unwrap();
}
});

let authority = url.authority().unwrap().clone();

// Make 4 requests
for _ in 0..4 {
let req = Request::builder()
.uri(url.clone())
.header(hyper::header::HOST, authority.as_str())
.body(Body::from("test".to_string()))?;

let mut res = sender.send_request(req).await?;

let mut stdout = io::stdout();
stdout
.write_all(format!("Response: {}\n", res.status()).as_bytes())
.await
.unwrap();
stdout
.write_all(format!("Headers: {:#?}\n", res.headers()).as_bytes())
.await
.unwrap();
stdout.flush().await.unwrap();

// Print the response body
while let Some(next) = res.frame().await {
let frame = next?;
if let Some(chunk) = frame.data_ref() {
stdout.write_all(&chunk).await.unwrap();
}
}
stdout.write_all(b"\n-----------------\n").await.unwrap();
stdout.flush().await.unwrap();
}
Ok(())
}

async fn server() -> Result<(), Box<dyn std::error::Error>> {
async fn http2_server() -> Result<(), Box<dyn std::error::Error>> {
let mut stdout = io::stdout();

let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();
Expand All @@ -102,7 +223,7 @@ async fn server() -> Result<(), Box<dyn std::error::Error>> {

loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);
let io = IOTypeNotSend::new(TokioIo::new(stream));

// For each connection, clone the counter to use in our service...
let cnt = counter.clone();
Expand Down Expand Up @@ -130,55 +251,7 @@ async fn server() -> Result<(), Box<dyn std::error::Error>> {
}
}

struct IOTypeNotSend {
_marker: PhantomData<*const ()>,
stream: TokioIo<TcpStream>,
}

impl IOTypeNotSend {
fn new(stream: TokioIo<TcpStream>) -> Self {
Self {
_marker: PhantomData,
stream,
}
}
}

impl hyper::rt::Write for IOTypeNotSend {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}
}

impl hyper::rt::Read for IOTypeNotSend {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
}
}

async fn client(url: hyper::Uri) -> Result<(), Box<dyn std::error::Error>> {
async fn http2_client(url: hyper::Uri) -> Result<(), Box<dyn std::error::Error>> {
let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(80);
let addr = format!("{}:{}", host, port);
Expand Down Expand Up @@ -250,3 +323,51 @@ where
tokio::task::spawn_local(fut);
}
}

struct IOTypeNotSend {
_marker: PhantomData<*const ()>,
stream: TokioIo<TcpStream>,
}

impl IOTypeNotSend {
fn new(stream: TokioIo<TcpStream>) -> Self {
Self {
_marker: PhantomData,
stream,
}
}
}

impl hyper::rt::Write for IOTypeNotSend {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}
}

impl hyper::rt::Read for IOTypeNotSend {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
}
}
2 changes: 1 addition & 1 deletion src/client/conn/http1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ where
impl<T, B> Future for Connection<T, B>
where
T: Read + Write + Unpin + Send + 'static,
B: Body + Send + 'static,
B: Body + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
Expand Down
2 changes: 1 addition & 1 deletion src/client/conn/http2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl<B> fmt::Debug for SendRequest<B> {
impl<T, B, E> Connection<T, B, E>
where
T: Read + Write + Unpin + 'static,
B: Body + Unpin + Send + 'static,
B: Body + Unpin + 'static,
B::Data: Send,
B::Error: Into<Box<dyn Error + Send + Sync>>,
E: ExecutorClient<B, T> + Unpin,
Expand Down

0 comments on commit 5400295

Please sign in to comment.