Skip to content

Commit

Permalink
feat: Client connection upgrade support
Browse files Browse the repository at this point in the history
Fix broken upgrade support for client connections.

Hyper 1.0 requires calling .with_upgrades() on connections
  • Loading branch information
alexrudy committed Dec 8, 2024
1 parent f014766 commit cfacfbb
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/client/conn/protocol/auto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ where
.await
.map_err(|error| ConnectionError::Handshake(error.into()))?;
tokio::spawn(async {
if let Err(err) = conn.await {
if let Err(err) = conn.with_upgrades().await {
tracing::error!(err = format!("{err:#}"), "h1 connection driver error");
}
});
Expand Down
2 changes: 1 addition & 1 deletion src/client/conn/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ where
.map_err(|err| ConnectionError::Handshake(err.into()))?;

tokio::spawn(async {
if let Err(err) = conn.await {
if let Err(err) = conn.with_upgrades().await {
if err.is_user() {
tracing::error!(err = format!("{err:#}"), "h1 connection driver error");
} else {
Expand Down
139 changes: 139 additions & 0 deletions tests/upgrades.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#![allow(missing_docs)]

use std::pin::pin;

use futures_util::StreamExt as _;
use hyper::upgrade::Upgraded;
use hyperdriver::{bridge::io::TokioIo, client::conn::transport::duplex::DuplexTransport};
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};

const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

async fn server_for_client_upgrade(
) -> Result<DuplexTransport, Box<dyn std::error::Error + Send + Sync>> {
let (tx, incoming) = hyperdriver::stream::duplex::pair();

let acceptor: hyperdriver::server::conn::Acceptor =
hyperdriver::server::conn::Acceptor::from(incoming);

tokio::spawn(tokio::time::timeout(
TIMEOUT,
serve_one_h1_upgrade(acceptor),
));

Ok(DuplexTransport::new(1024, tx))
}

async fn serve_one_h1_upgrade(
acceptor: hyperdriver::server::conn::Acceptor,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut acceptor = pin!(acceptor);
let stream = acceptor.next().await.ok_or("no connection")??;

let service = hyper::service::service_fn(upgrade_svc);

let conn =
hyper::server::conn::http1::Builder::new().serve_connection(TokioIo::new(stream), service);

conn.with_upgrades().await?;

Ok(())
}

async fn upgrade_svc(
mut request: http::Request<hyper::body::Incoming>,
) -> Result<http::Response<hyperdriver::Body>, Box<dyn std::error::Error + Send + Sync>> {
if !request.headers().contains_key(http::header::UPGRADE) {
return Ok(http::Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.body(hyperdriver::Body::empty())?);
}

tokio::spawn(tokio::time::timeout(TIMEOUT, async move {
let upgraded = hyper::upgrade::on(&mut request)
.await
.expect("[server] upgrade erorr");
server_upgraded_io(upgraded)
.await
.expect("[server] upgraded protocol error");
}));

Ok(http::Response::builder()
.status(http::StatusCode::SWITCHING_PROTOCOLS)
.body(hyperdriver::Body::empty())?)
}

async fn server_upgraded_io(
upgraded: Upgraded,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut upgraded = TokioIo::new(upgraded);
let mut vec = vec![0; 5];
upgraded.read_exact(&mut vec).await?;
let req = String::from_utf8_lossy(&vec);
println!("[server] client sent {req:?}");
if req != "hello" {
println!("[server] got unexpected request");
}
upgraded.write_all(b"world").await?;
println!("[server] sent response");
Ok(())
}

async fn clinet_upgraded_io(mut response: http::Response<hyperdriver::Body>) {
if response.status() != http::StatusCode::SWITCHING_PROTOCOLS {
panic!("Server didn't upgrade: {}", response.status());
}

let upgraded = hyper::upgrade::on(&mut response)
.await
.expect("upgrade error");
let mut upgraded = TokioIo::new(upgraded);
upgraded.write_all(b"hello").await.unwrap();
println!("[client] sent hello");
let mut vec = vec![0; 5];
upgraded.read_exact(&mut vec).await.unwrap();
let res = String::from_utf8_lossy(&vec);
println!("[client] got {res:?}");
assert_eq!(res, "world");
}

#[tokio::test]
async fn client_auto() {
let transport = server_for_client_upgrade().await.unwrap();

let mut client = hyperdriver::Client::builder()
.with_auto_http()
.with_default_pool()
.with_default_tls()
.with_transport(transport)
.build();

let request = http::Request::get("http://example.org")
.header(http::header::UPGRADE, "test-hyperdriver")
.body(hyperdriver::Body::empty())
.unwrap();

let response = client.request(request).await.unwrap();
clinet_upgraded_io(response).await;
}

#[tokio::test]
async fn client_http1() {
let transport = server_for_client_upgrade().await.unwrap();

let mut client = hyperdriver::Client::builder()
.with_protocol(hyper::client::conn::http1::Builder::new())
.with_default_pool()
.with_default_tls()
.with_transport(transport)
.build();

let request = http::Request::get("http://example.org")
.header(http::header::UPGRADE, "test-hyperdriver")
.body(hyperdriver::Body::empty())
.unwrap();

let response = client.request(request).await.unwrap();

clinet_upgraded_io(response).await;
}

0 comments on commit cfacfbb

Please sign in to comment.