diff --git a/Cargo.toml b/Cargo.toml index 030e4e3..c341d9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -178,6 +178,11 @@ name = "tls" path = "tests/tls.rs" required-features = ["client", "server", "tls", "tls-ring"] +[[test]] +name = "upgrades" +path = "tests/upgrades.rs" +required-features = ["server", "client", "stream"] + [package.metadata.cargo-udeps.ignore] normal = ["rustls-native-certs", "rustls", "tokio-rustls"] diff --git a/src/client/conn/protocol/auto.rs b/src/client/conn/protocol/auto.rs index 14b91b2..02b7255 100644 --- a/src/client/conn/protocol/auto.rs +++ b/src/client/conn/protocol/auto.rs @@ -94,7 +94,7 @@ where .await .map_err(|error| ConnectionError::Handshake(error.into()))?; tokio::spawn(async { - if let Err(err) = conn.await { + if let Err(err) = conn.with_upgrades().await { tracing::error!(err = format!("{err:#}"), "h1 connection driver error"); } }); diff --git a/src/client/conn/protocol/mod.rs b/src/client/conn/protocol/mod.rs index b4aea4d..caacf23 100644 --- a/src/client/conn/protocol/mod.rs +++ b/src/client/conn/protocol/mod.rs @@ -190,7 +190,7 @@ where .map_err(|err| ConnectionError::Handshake(err.into()))?; tokio::spawn(async { - if let Err(err) = conn.await { + if let Err(err) = conn.with_upgrades().await { if err.is_user() { tracing::error!(err = format!("{err:#}"), "h1 connection driver error"); } else { diff --git a/tests/upgrades.rs b/tests/upgrades.rs new file mode 100644 index 0000000..fe991d4 --- /dev/null +++ b/tests/upgrades.rs @@ -0,0 +1,137 @@ +#![allow(missing_docs)] + +use std::pin::pin; + +use futures_util::StreamExt as _; +use hyper::upgrade::Upgraded; +use hyperdriver::{bridge::io::TokioIo, client::conn::transport::duplex::DuplexTransport}; +use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; + +const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5); + +async fn server_for_client_upgrade( +) -> Result> { + let (tx, incoming) = hyperdriver::stream::duplex::pair(); + + let acceptor: hyperdriver::server::conn::Acceptor = + hyperdriver::server::conn::Acceptor::from(incoming); + + tokio::spawn(tokio::time::timeout( + TIMEOUT, + serve_one_h1_upgrade(acceptor), + )); + + Ok(DuplexTransport::new(1024, tx)) +} + +async fn serve_one_h1_upgrade( + acceptor: hyperdriver::server::conn::Acceptor, +) -> Result<(), Box> { + let mut acceptor = pin!(acceptor); + let stream = acceptor.next().await.ok_or("no connection")??; + + let service = hyper::service::service_fn(upgrade_svc); + + let conn = + hyper::server::conn::http1::Builder::new().serve_connection(TokioIo::new(stream), service); + + conn.with_upgrades().await?; + + Ok(()) +} + +async fn upgrade_svc( + mut request: http::Request, +) -> Result, Box> { + if !request.headers().contains_key(http::header::UPGRADE) { + return Ok(http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(hyperdriver::Body::empty())?); + } + + tokio::spawn(tokio::time::timeout(TIMEOUT, async move { + let upgraded = hyper::upgrade::on(&mut request) + .await + .expect("[server] upgrade erorr"); + server_upgraded_io(upgraded) + .await + .expect("[server] upgraded protocol error"); + })); + + Ok(http::Response::builder() + .status(http::StatusCode::SWITCHING_PROTOCOLS) + .body(hyperdriver::Body::empty())?) +} + +async fn server_upgraded_io( + upgraded: Upgraded, +) -> Result<(), Box> { + let mut upgraded = TokioIo::new(upgraded); + let mut vec = vec![0; 5]; + upgraded.read_exact(&mut vec).await?; + let req = String::from_utf8_lossy(&vec); + println!("[server] client sent {req:?}"); + if req != "hello" { + println!("[server] got unexpected request"); + } + upgraded.write_all(b"world").await?; + println!("[server] sent response"); + Ok(()) +} + +async fn clinet_upgraded_io(mut response: http::Response) { + if response.status() != http::StatusCode::SWITCHING_PROTOCOLS { + panic!("Server didn't upgrade: {}", response.status()); + } + + let upgraded = hyper::upgrade::on(&mut response) + .await + .expect("upgrade error"); + let mut upgraded = TokioIo::new(upgraded); + upgraded.write_all(b"hello").await.unwrap(); + println!("[client] sent hello"); + let mut vec = vec![0; 5]; + upgraded.read_exact(&mut vec).await.unwrap(); + let res = String::from_utf8_lossy(&vec); + println!("[client] got {res:?}"); + assert_eq!(res, "world"); +} + +#[tokio::test] +async fn client_auto() { + let transport = server_for_client_upgrade().await.unwrap(); + + let mut client = hyperdriver::Client::builder() + .with_auto_http() + .with_default_pool() + .with_transport(transport) + .build(); + + let request = http::Request::get("http://example.org") + .header(http::header::UPGRADE, "test-hyperdriver") + .body(hyperdriver::Body::empty()) + .unwrap(); + + let response = client.request(request).await.unwrap(); + clinet_upgraded_io(response).await; +} + +#[tokio::test] +async fn client_http1() { + let transport = server_for_client_upgrade().await.unwrap(); + + let mut client = hyperdriver::Client::builder() + .with_protocol(hyper::client::conn::http1::Builder::new()) + .with_default_pool() + .with_transport(transport) + .build(); + + let request = http::Request::get("http://example.org") + .header(http::header::UPGRADE, "test-hyperdriver") + .body(hyperdriver::Body::empty()) + .unwrap(); + + let response = client.request(request).await.unwrap(); + + clinet_upgraded_io(response).await; +}