diff --git a/Cargo.lock b/Cargo.lock index e08934dc..90e52799 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -199,7 +199,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62b74f44609f0f91493e3082d3734d98497e094777144380ea4db9f9905dd5b6" dependencies = [ - "brotli 3.3.4", + "brotli", "flate2", "futures-core", "memchr", @@ -245,9 +245,9 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body", "http-body-util", - "hyper 1.4.1", + "hyper", "hyper-util", "itoa", "matchit", @@ -264,7 +264,7 @@ dependencies = [ "sha1", "sync_wrapper 1.0.1", "tokio", - "tokio-tungstenite 0.24.0", + "tokio-tungstenite", "tower 0.5.1", "tower-layer", "tower-service", @@ -281,7 +281,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body", "http-body-util", "mime", "pin-project-lite", @@ -304,7 +304,7 @@ dependencies = [ "cookie", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body", "http-body-util", "mime", "pin-project-lite", @@ -325,9 +325,9 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body", "http-body-util", - "hyper 1.4.1", + "hyper", "hyper-util", "pin-project-lite", "rustls 0.23.14", @@ -448,18 +448,7 @@ checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", - "brotli-decompressor 2.3.4", -] - -[[package]] -name = "brotli" -version = "7.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor 4.0.0", + "brotli-decompressor", ] [[package]] @@ -472,16 +461,6 @@ dependencies = [ "alloc-stdlib", ] -[[package]] -name = "brotli-decompressor" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6221fe77a248b9117d431ad93761222e1cf8ff282d9d1d5d9f53d6299a1cf76" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bstr" version = "1.6.0" @@ -515,9 +494,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cargo-lock" @@ -1436,25 +1415,6 @@ dependencies = [ "spinning_top", ] -[[package]] -name = "h2" -version = "0.3.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http 0.2.9", - "indexmap 2.0.0", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "h2" version = "0.4.3" @@ -1628,17 +1588,6 @@ dependencies = [ "itoa", ] -[[package]] -name = "http-body" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" -dependencies = [ - "bytes", - "http 0.2.9", - "pin-project-lite", -] - [[package]] name = "http-body" version = "1.0.0" @@ -1658,7 +1607,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body", "pin-project-lite", ] @@ -1690,30 +1639,6 @@ dependencies = [ "serde", ] -[[package]] -name = "hyper" -version = "0.14.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" -dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "h2 0.3.26", - "http 0.2.9", - "http-body 0.4.5", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2 0.4.9", - "tokio", - "tower-service", - "tracing", - "want", -] - [[package]] name = "hyper" version = "1.4.1" @@ -1723,9 +1648,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.3", + "h2", "http 1.1.0", - "http-body 1.0.0", + "http-body", "httparse", "httpdate", "itoa", @@ -1743,7 +1668,7 @@ checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper", "hyper-util", "rustls 0.22.2", "rustls-pki-types", @@ -1760,7 +1685,7 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper", "hyper-util", "rustls 0.23.14", "rustls-native-certs 0.7.0", @@ -1772,20 +1697,19 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.7" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", - "http-body 1.0.0", - "hyper 1.4.1", + "http-body", + "hyper", "pin-project-lite", "socket2 0.5.6", "tokio", - "tower 0.4.13", "tower-service", "tracing", ] @@ -1891,9 +1815,9 @@ dependencies = [ "base64 0.21.2", "bytes", "http 1.1.0", - "http-body 1.0.0", + "http-body", "http-body-util", - "hyper 1.4.1", + "hyper", "hyper-rustls 0.27.2", "hyper-util", "ring", @@ -2102,15 +2026,19 @@ dependencies = [ [[package]] name = "mockito" -version = "1.1.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09c762b6267c4593555bb38f1df19e9318985bc4de60b5e8462890856a9a5b4c" +checksum = "652cd6d169a36eaf9d1e6bce1a221130439a966d7f27858af66a33a66e9c4ee2" dependencies = [ "assert-json-diff", + "bytes", "colored", - "futures", - "hyper 0.14.27", - "lazy_static", + "futures-util", + "http 1.1.0", + "http-body", + "http-body-util", + "hyper", + "hyper-util", "log", "rand", "regex", @@ -2737,12 +2665,12 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "h2 0.4.3", + "h2", "hickory-resolver", "http 1.1.0", - "http-body 1.0.0", + "http-body", "http-body-util", - "hyper 1.4.1", + "hyper", "hyper-rustls 0.26.0", "hyper-util", "ipnet", @@ -3623,8 +3551,8 @@ dependencies = [ "axum-server", "backoff", "base64 0.22.1", - "brotli 7.0.0", "built", + "bytes", "clap", "dashmap 6.0.1", "directories", @@ -3634,8 +3562,10 @@ dependencies = [ "globwalk", "hex", "hickory-resolver", + "http-body-util", "humantime-serde", - "hyper 0.14.27", + "hyper", + "hyper-util", "include_dir", "indexmap 2.0.0", "instant-acme", @@ -3669,12 +3599,13 @@ dependencies = [ "thiserror 2.0.0", "time", "tokio", - "tokio-rustls 0.25.0", + "tokio-rustls 0.26.0", "tokio-stream", - "tokio-tungstenite 0.21.0", + "tokio-tungstenite", "toml 0.8.8", "toml_edit 0.22.9", "totp-rs", + "tower-service", "tower_governor", "tracing", "tracing-appender", @@ -3904,22 +3835,6 @@ dependencies = [ "tokio-util", ] -[[package]] -name = "tokio-tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" -dependencies = [ - "futures-util", - "log", - "rustls 0.22.2", - "rustls-native-certs 0.7.0", - "rustls-pki-types", - "tokio", - "tokio-rustls 0.25.0", - "tungstenite 0.21.0", -] - [[package]] name = "tokio-tungstenite" version = "0.24.0" @@ -3928,8 +3843,12 @@ checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" dependencies = [ "futures-util", "log", + "rustls 0.23.14", + "rustls-native-certs 0.8.0", + "rustls-pki-types", "tokio", - "tungstenite 0.24.0", + "tokio-rustls 0.26.0", + "tungstenite", ] [[package]] @@ -4043,7 +3962,6 @@ dependencies = [ "futures-util", "pin-project", "pin-project-lite", - "tokio", "tower-layer", "tower-service", "tracing", @@ -4182,27 +4100,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" -[[package]] -name = "tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http 1.1.0", - "httparse", - "log", - "rand", - "rustls 0.22.2", - "rustls-pki-types", - "sha1", - "thiserror 1.0.68", - "url", - "utf-8", -] - [[package]] name = "tungstenite" version = "0.24.0" @@ -4216,6 +4113,8 @@ dependencies = [ "httparse", "log", "rand", + "rustls 0.23.14", + "rustls-pki-types", "sha1", "thiserror 1.0.68", "utf-8", diff --git a/README.md b/README.md index 65f5f34d..e1f2b7e4 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ Taxy is currently in early development. Please be aware that breaking changes ma - Allows live configuration updates via a REST API without restarting the service - Imports TLS certificates from the GUI or can generate a self-signed certificate - Provides Let's Encrypt support (ACME v2, HTTP challenge only) for seamless certificate provisioning -- Supports automatic HTTP Brotli compression ## Screenshot diff --git a/docs/content/_index.md b/docs/content/_index.md index e08710c1..117aa92e 100644 --- a/docs/content/_index.md +++ b/docs/content/_index.md @@ -16,7 +16,6 @@ sort_by = "weight" - Allows live configuration updates via a REST API without restarting the service - Imports TLS certificates from the GUI or can generate a self-signed certificate - Provides Let's Encrypt support (ACME v2, HTTP challenge only) for seamless certificate provisioning -- Supports automatic HTTP Brotli compression # Installation diff --git a/taxy/Cargo.toml b/taxy/Cargo.toml index b59ef803..cfa75782 100644 --- a/taxy/Cargo.toml +++ b/taxy/Cargo.toml @@ -28,7 +28,7 @@ axum-extra = { version = "0.9.4", features = ["cookie"] } axum-server = { version = "0.7.1", features = ["tls-rustls-no-provider"] } backoff = { version = "0.4.0", features = ["tokio"] } base64 = "0.22.1" -brotli = "7.0.0" +bytes = "1.8.0" clap = { version = "4.3.11", features = ["derive", "env"] } dashmap = "6.0.1" directories = "5.0.1" @@ -41,8 +41,15 @@ hickory-resolver = { version = "0.24.1", features = [ "tokio-runtime", "system-config", ] } +http-body-util = "0.1.2" humantime-serde = "1.1.1" -hyper = { version = "0.14.27", features = ["full"] } +hyper = { version = "1.4.1", features = ["full"] } +hyper-util = { version = "0.1.10", features = [ + "full", + "http1", + "http2", + "server", +] } include_dir = "0.7.3" indexmap = { version = "2.0.0", features = ["serde"] } instant-acme = "0.7.1" @@ -83,7 +90,7 @@ tokio = { version = "1.29.1", features = [ "signal", "io-util", ] } -tokio-rustls = { version = "0.25.0", default-features = false, features = [ +tokio-rustls = { version = "0.26.0", default-features = false, features = [ "tls12", "ring", ] } @@ -91,6 +98,7 @@ tokio-stream = { version = "0.1.14", features = ["sync", "net"] } toml = "0.8.8" toml_edit = { version = "0.22.9", features = ["serde"] } totp-rs = { version = "5.1.0", features = ["gen_secret", "zeroize"] } +tower-service = "0.3.3" tower_governor = "0.4.3" tracing = { version = "0.1.37", features = ["release_max_level_info"] } tracing-appender = "0.2.2" @@ -104,7 +112,7 @@ x509-parser = "0.16.0" built = "0.6.1" [dev-dependencies] -mockito = "1.1.0" +mockito = "1.6.1" net2 = "0.2.39" reqwest = { version = "0.12.1", default-features = false, features = [ "rustls-tls", @@ -115,6 +123,6 @@ reqwest = { version = "0.12.1", default-features = false, features = [ "http2", "hickory-dns", ] } -tokio-tungstenite = { version = "0.21.0", features = [ +tokio-tungstenite = { version = "0.24.0", features = [ "rustls-tls-native-roots", ] } diff --git a/taxy/src/proxy/http/compression.rs b/taxy/src/proxy/http/compression.rs deleted file mode 100644 index acb8498b..00000000 --- a/taxy/src/proxy/http/compression.rs +++ /dev/null @@ -1,112 +0,0 @@ -use brotli::CompressorWriter; -use futures::{Stream, StreamExt}; -use hyper::{ - body::{Bytes, HttpBody}, - Body, -}; -use phf::phf_map; -use std::{ - io::Write, - pin::Pin, - task::{self, Poll}, -}; - -pub struct CompressionStream { - body: Body, - writer: Option>>, -} - -impl CompressionStream { - pub fn new(body: Body, buffer_size: usize) -> Self { - let writer = CompressorWriter::new(Vec::new(), buffer_size, 8, 22); - Self { - body, - writer: Some(writer), - } - } -} - -impl Stream for CompressionStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - let poll = self.body.poll_next_unpin(cx); - if let Some(writer) = &mut self.writer { - if let Poll::Ready(Some(Ok(chunk))) = &poll { - let _ = writer.write_all(chunk); - let _ = writer.flush(); - } - } - match poll { - Poll::Ready(Some(Ok(_))) | Poll::Ready(None) => { - if let Some(mut writer) = self.writer.take() { - if self.body.is_end_stream() { - Poll::Ready(Some(Ok(Bytes::from(writer.into_inner())))) - } else { - let buffer = std::mem::take(writer.get_mut()); - self.writer = Some(writer); - Poll::Ready(Some(Ok(Bytes::from(buffer)))) - } - } else { - Poll::Ready(None) - } - } - _ => poll, - } - } -} - -pub fn is_compressed(content_type: &[u8]) -> bool { - let content_type = if let Ok(content_type) = std::str::from_utf8(content_type) { - content_type.to_ascii_lowercase() - } else { - return false; - }; - if let Some(&known) = KNOWN_TYPES.get(&content_type) { - return known; - } - !(content_type.starts_with("text/") || content_type.starts_with("application/")) -} - -static KNOWN_TYPES: phf::Map<&'static str, bool> = phf_map! { - "image/svg+xml" => false, - "image/bmp" => false, - "image/x-ms-bmp" => false, - "audio/wav" => false, - "audio/x-wav" => false, - "audio/midi" => false, - "audio/x-midi" => false, - "application/x-bzip" => true, - "application/x-bzip2" => true, - "application/gzip" => true, - "application/vnd.rar" => true, - "application/x-tar" => true, - "application/zip" => true, - "application/x-7z-compressed" => true, - "application/epub+zip" => true, - "font/otf" => false, - "font/ttf" => false, -}; - -#[cfg(test)] -mod test { - use super::*; - #[test] - fn test_is_compressed() { - assert!(!is_compressed(b"text/html")); - assert!(!is_compressed(b"application/json")); - assert!(!is_compressed(b"image/svg+xml")); - assert!(!is_compressed(b"image/bmp")); - assert!(is_compressed(b"image/png")); - assert!(!is_compressed(b"image/x-ms-bmp")); - assert!(is_compressed(b"audio/mp3")); - assert!(!is_compressed(b"audio/wav")); - assert!(!is_compressed(b"audio/x-wav")); - assert!(!is_compressed(b"audio/midi")); - assert!(!is_compressed(b"audio/x-midi")); - assert!(is_compressed(b"video/webm")); - assert!(is_compressed(b"application/x-bzip")); - assert!(is_compressed(b"application/x-bzip2")); - assert!(is_compressed(b"application/gzip")); - } -} diff --git a/taxy/src/proxy/http/error.rs b/taxy/src/proxy/http/error.rs index 97d1f43f..2268b4e5 100644 --- a/taxy/src/proxy/http/error.rs +++ b/taxy/src/proxy/http/error.rs @@ -1,4 +1,6 @@ -use hyper::{Body, Response, StatusCode}; +use bytes::Bytes; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::{body::Body, Response, StatusCode}; use sailfish::TemplateOnce; use thiserror::Error; use tokio_rustls::rustls; @@ -21,17 +23,22 @@ impl ProxyError { } } -pub fn map_response( - res: Result, anyhow::Error>, -) -> Result, anyhow::Error> { +pub fn map_response( + res: Result, anyhow::Error>, +) -> Result>, anyhow::Error> +where + B: Body + Send + Sync + 'static, +{ match res { - Ok(res) => Ok(res), + Ok(res) => Ok(res.map(|body| BoxBody::new(body))), Err(err) => { let code = map_error(err); let ctx = ErrorTemplate { code: code.as_u16(), }; - let mut res = Response::new(Body::from(ctx.render_once().unwrap())); + let mut res = Response::new(BoxBody::new( + Full::new(Bytes::from(ctx.render_once().unwrap())).map_err(Into::into), + )); *res.status_mut() = code; Ok(res) } @@ -42,30 +49,17 @@ fn map_error(err: anyhow::Error) -> StatusCode { if let Some(err) = err.downcast_ref::() { return err.code(); } - if let Some(err) = err.downcast_ref::() { - if err.kind() == std::io::ErrorKind::TimedOut { - return StatusCode::GATEWAY_TIMEOUT; + if let Some(err) = err.downcast_ref::() { + if matches!(err, rustls::Error::InvalidCertificate(_)) { + return StatusCode::from_u16(526).unwrap(); + } else { + return StatusCode::from_u16(525).unwrap(); } } if let Ok(err) = err.downcast::() { - let is_connect = err.is_connect(); - if let Some(inner) = err.into_cause() { - if let Ok(err) = inner.downcast::() { - if err.kind() == std::io::ErrorKind::TimedOut { - return StatusCode::GATEWAY_TIMEOUT; - } - if let Some(inner) = err.into_inner() { - if let Ok(err) = inner.downcast::() { - if matches!(*err, rustls::Error::InvalidCertificate(_)) { - return StatusCode::from_u16(526).unwrap(); - } else { - return StatusCode::from_u16(525).unwrap(); - } - } - } - } - } - if is_connect { + if err.is_timeout() { + return StatusCode::GATEWAY_TIMEOUT; + } else { return StatusCode::from_u16(523).unwrap(); } } diff --git a/taxy/src/proxy/http/hyper_tls/client.rs b/taxy/src/proxy/http/hyper_tls/client.rs index ade8d153..59628a41 100644 --- a/taxy/src/proxy/http/hyper_tls/client.rs +++ b/taxy/src/proxy/http/hyper_tls/client.rs @@ -4,11 +4,14 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use hyper::{client::connect::HttpConnector, service::Service, Uri}; -use tokio::io::{AsyncRead, AsyncWrite}; +use hyper::rt::{Read, Write}; +use hyper::Uri; +use hyper_util::client::legacy::connect::HttpConnector; +use hyper_util::rt::TokioIo; use tokio_rustls::rustls::pki_types::ServerName; use tokio_rustls::rustls::ClientConfig; use tokio_rustls::TlsConnector; +use tower_service::Service; use super::stream::MaybeHttpsStream; @@ -75,7 +78,7 @@ impl fmt::Debug for HttpsConnector { impl Service for HttpsConnector where T: Service, - T::Response: AsyncRead + AsyncWrite + Send + Unpin, + T::Response: Read + Write + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { @@ -104,13 +107,23 @@ where .trim_matches(|c| c == '[' || c == ']') .to_owned(); let connecting = self.http.call(dst); - let tls = self.tls.clone(); + + let tls_connector = self.tls.clone(); + let fut = async move { let tcp = connecting.await.map_err(Into::into)?; + let maybe = if is_https { - let tls = tls - .connect(ServerName::try_from(host.as_str()).unwrap().to_owned(), tcp) - .await?; + let stream = TokioIo::new(tcp); + + let tls = TokioIo::new( + tls_connector + .connect( + ServerName::try_from(host.as_str()).unwrap().to_owned(), + stream, + ) + .await?, + ); MaybeHttpsStream::Https(tls) } else { MaybeHttpsStream::Http(tcp) @@ -130,7 +143,7 @@ type BoxedFut = Pin, BoxEr /// A Future representing work to connect to a URL, and a TLS handshake. pub struct HttpsConnecting(BoxedFut); -impl Future for HttpsConnecting { +impl Future for HttpsConnecting { type Output = Result, BoxError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/taxy/src/proxy/http/hyper_tls/stream.rs b/taxy/src/proxy/http/hyper_tls/stream.rs index 457ca5f0..784768db 100644 --- a/taxy/src/proxy/http/hyper_tls/stream.rs +++ b/taxy/src/proxy/http/hyper_tls/stream.rs @@ -4,8 +4,11 @@ use std::io::IoSlice; use std::pin::Pin; use std::task::{Context, Poll}; -use hyper::client::connect::{Connected, Connection}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use hyper::rt::{Read, ReadBufCursor, Write}; +use hyper_util::{ + client::legacy::connect::{Connected, Connection}, + rt::TokioIo, +}; pub use tokio_rustls::client::TlsStream; /// A stream that might be protected with TLS. @@ -14,7 +17,7 @@ pub enum MaybeHttpsStream { /// A stream over plain text. Http(T), /// A stream protected with TLS. - Https(TlsStream), + Https(TokioIo>>), } // ===== impl MaybeHttpsStream ===== @@ -34,18 +37,24 @@ impl From for MaybeHttpsStream { } } -impl From> for MaybeHttpsStream { - fn from(inner: TlsStream) -> Self { +impl From>> for MaybeHttpsStream { + fn from(inner: TlsStream>) -> Self { + MaybeHttpsStream::Https(TokioIo::new(inner)) + } +} + +impl From>>> for MaybeHttpsStream { + fn from(inner: TokioIo>>) -> Self { MaybeHttpsStream::Https(inner) } } -impl AsyncRead for MaybeHttpsStream { +impl Read for MaybeHttpsStream { #[inline] fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf, + buf: ReadBufCursor<'_>, ) -> Poll> { match Pin::get_mut(self) { MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf), @@ -54,7 +63,7 @@ impl AsyncRead for MaybeHttpsStream { } } -impl AsyncWrite for MaybeHttpsStream { +impl Write for MaybeHttpsStream { #[inline] fn poll_write( self: Pin<&mut Self>, @@ -102,11 +111,11 @@ impl AsyncWrite for MaybeHttpsStream { } } -impl Connection for MaybeHttpsStream { +impl Connection for MaybeHttpsStream { fn connected(&self) -> Connected { match self { MaybeHttpsStream::Http(s) => s.connected(), - MaybeHttpsStream::Https(s) => s.get_ref().0.connected(), + MaybeHttpsStream::Https(s) => s.inner().get_ref().0.connected(), } } } diff --git a/taxy/src/proxy/http/mod.rs b/taxy/src/proxy/http/mod.rs index b2ad77ac..909e6471 100644 --- a/taxy/src/proxy/http/mod.rs +++ b/taxy/src/proxy/http/mod.rs @@ -7,15 +7,20 @@ use super::{tls::TlsTermination, PortContextEvent}; use crate::server::cert_list::CertList; use arc_swap::{ArcSwap, Cache}; use header::HeaderRewriter; +use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{ + body::Incoming, header::{HOST, LOCATION}, http::{ uri::{Parts, Scheme}, HeaderValue, }, - server::conn::Http, service::service_fn, - Body, Request, Response, StatusCode, Uri, + Request, Response, StatusCode, Uri, +}; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto, }; use std::{net::SocketAddr, sync::Arc, time::SystemTime}; use taxy_api::error::Error; @@ -35,7 +40,6 @@ use tokio_rustls::{ }; use tracing::{debug, error, info, span, Instrument, Level, Span}; -mod compression; mod error; mod filter; mod header; @@ -251,7 +255,8 @@ async fn start( if tls_acceptor.is_some() && local.port() != 80 && first_byte != 0x16 { tokio::task::spawn( async move { - if let Err(err) = Http::new() + let server_stream = TokioIo::new(server_stream); + if let Err(err) = auto::Builder::new(TokioExecutor::new()) .serve_connection(server_stream, service_fn(redirect)) .await { @@ -283,9 +288,9 @@ async fn start( } let pool = Arc::new(ConnectionPool::new(tls_client_config)); - let mut shared_cache = shared_cache.clone(); let span_cloned = span.clone(); - let service = hyper::service::service_fn(move |mut req| { + let service = hyper::service::service_fn(move |mut req: Request| { + let mut shared_cache = shared_cache.clone(); let span = span_cloned.clone(); let enter = span.clone(); let _enter = enter.enter(); @@ -327,7 +332,7 @@ async fn start( redirect = Response::builder() .status(301) .header(LOCATION, uri.to_string()) - .body(Body::empty()) + .body(String::new()) .ok(); } } @@ -383,8 +388,14 @@ async fn start( async move { map_response(match req { - ProxiedRequest::Ok(req, span) => pool.request(req).instrument(span).await, - ProxiedRequest::Redirect(resp) => Ok(resp), + ProxiedRequest::Ok(req, span) => { + pool.request(req.map(|b| BoxBody::new(b.map_err(Into::into)))) + .instrument(span) + .await + } + ProxiedRequest::Redirect(resp) => { + Ok(resp.map(|b| BoxBody::new(b.map_err(Into::into)))) + } ProxiedRequest::Err(err) => Err(err.into()), }) } @@ -393,10 +404,14 @@ async fn start( tokio::task::spawn( async move { - let http = Http::new() - .http2_only(server_http2) - .serve_connection(stream, service) - .with_upgrades(); + let stream = TokioIo::new(stream); + let builder = auto::Builder::new(TokioExecutor::new()); + let builder = if server_http2 { + builder.http2_only() + } else { + builder + }; + let http = builder.serve_connection_with_upgrades(stream, service); if let Err(err) = http.await { error!("Failed to serve the connection: {:?}", err); } @@ -408,8 +423,8 @@ async fn start( } enum ProxiedRequest { - Ok(Request, Span), - Redirect(Response), + Ok(Request, Span), + Redirect(Response), Err(ProxyError), } @@ -430,22 +445,20 @@ pub struct Connection { pub tls: bool, } -async fn redirect( - req: hyper::Request, -) -> Result, hyper::http::Error> { +async fn redirect(req: hyper::Request) -> Result, hyper::http::Error> { if let Ok(uri) = get_secure_uri(&req) { Response::builder() .header("Location", uri.to_string()) .status(StatusCode::PERMANENT_REDIRECT) - .body(hyper::Body::empty()) + .body(String::new()) } else { Response::builder() .status(StatusCode::BAD_REQUEST) - .body(hyper::Body::from("TLS required\r\n")) + .body(String::from("TLS required\r\n")) } } -fn get_secure_uri(req: &hyper::Request) -> anyhow::Result { +fn get_secure_uri(req: &hyper::Request) -> anyhow::Result { let mut parts = req.uri().clone().into_parts(); if let Some(host) = req.headers().get(HOST) { parts.authority = Some(host.to_str()?.parse()?); diff --git a/taxy/src/proxy/http/pool.rs b/taxy/src/proxy/http/pool.rs index 966b1776..5a9a1f07 100644 --- a/taxy/src/proxy/http/pool.rs +++ b/taxy/src/proxy/http/pool.rs @@ -1,31 +1,40 @@ -use super::{ - compression::{is_compressed, CompressionStream}, - error::map_response, -}; +use super::error::map_response; use crate::proxy::http::{hyper_tls::client::HttpsConnector, HTTP2_MAX_FRAME_SIZE}; -use hyper::{ - client::HttpConnector, header::UPGRADE, http::HeaderValue, Body, Client, Request, Response, +use bytes::Bytes; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::{header::UPGRADE, Request, Response}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::{TokioExecutor, TokioIo}, }; use std::sync::Arc; use tokio_rustls::rustls::ClientConfig; use tracing::error; pub struct ConnectionPool { - client: Client>, + client: Client, BoxBody>, } impl ConnectionPool { pub fn new(tls_client_config: Arc) -> Self { let https = HttpsConnector::new(tls_client_config.clone()); - let client = Client::builder() + let client = Client::builder(TokioExecutor::new()) .http2_max_frame_size(Some(HTTP2_MAX_FRAME_SIZE as u32)) - .build::<_, hyper::Body>(https); + .build(https); Self { client } } - pub async fn request(&self, mut req: Request) -> Result, anyhow::Error> { + pub async fn request( + &self, + mut req: Request>, + ) -> Result>, anyhow::Error> { let upgrading_req = if req.headers().contains_key(UPGRADE) { - let mut cloned_req = Request::builder().uri(req.uri()).body(Body::empty())?; + let mut cloned_req = Request::builder().uri(req.uri()).body(BoxBody::< + Bytes, + anyhow::Error, + >::new( + Full::new(Bytes::new()).map_err(Into::into), + ))?; cloned_req.headers_mut().clone_from(req.headers()); let mut cloned_req = Some(cloned_req); req = cloned_req.replace(req).unwrap(); @@ -34,16 +43,14 @@ impl ConnectionPool { None }; - let accept_brotli = req - .headers() - .get(hyper::header::ACCEPT_ENCODING) - .map(|value| value.to_str().unwrap_or_default().contains("br")) - .unwrap_or_default(); - *req.version_mut() = hyper::Version::HTTP_11; - let mut result: Result<_, anyhow::Error> = - self.client.request(req).await.map_err(|err| err.into()); + let mut result: Result<_, anyhow::Error> = self + .client + .request(req) + .await + .map_err(|err| err.into()) + .map(|res| res.map(|body| BoxBody::new(body.map_err(|err| err.into())))); match (&result, upgrading_req) { (Ok(res), Some(upgrading_req)) @@ -51,8 +58,13 @@ impl ConnectionPool { { let mut cloned_res = Response::builder().status(res.status()); cloned_res.headers_mut().unwrap().clone_from(res.headers()); - let upgrading_res = - std::mem::replace(&mut result, Ok(cloned_res.body(Body::empty())?)).unwrap(); + + let upgrading_res = std::mem::replace( + &mut result, + Ok(cloned_res + .body(BoxBody::new(Full::new(Bytes::new()).map_err(Into::into)))?), + ) + .unwrap(); tokio::spawn(async move { upgrade_connection(upgrading_req, upgrading_res).await; }); @@ -60,37 +72,6 @@ impl ConnectionPool { _ => (), } - let http2 = result - .as_ref() - .map(|res| res.version() == hyper::Version::HTTP_2) - .unwrap_or_default(); - - let accept_brotli = accept_brotli & http2; - - let result = result.map(|res| { - let (mut parts, body) = res.into_parts(); - - let is_compressed = parts - .headers - .get(hyper::header::CONTENT_TYPE) - .map(|value| is_compressed(value.as_bytes())) - .unwrap_or_default(); - - if !is_compressed { - let encoding = parts.headers.entry(hyper::header::CONTENT_ENCODING); - if let hyper::header::Entry::Vacant(entry) = encoding { - if accept_brotli { - entry.insert(HeaderValue::from_static("br")); - parts.headers.remove(hyper::header::CONTENT_LENGTH); - let stream = CompressionStream::new(body, HTTP2_MAX_FRAME_SIZE); - return Response::from_parts(parts, hyper::Body::wrap_stream(stream)); - } - } - } - - Response::from_parts(parts, body) - }); - if let Err(err) = &result { error!(%err); } @@ -99,9 +80,14 @@ impl ConnectionPool { } } -async fn upgrade_connection(req: Request, res: Response) { +async fn upgrade_connection( + req: Request>, + res: Response>, +) { match tokio::try_join!(hyper::upgrade::on(req), hyper::upgrade::on(res)) { - Ok((mut req, mut res)) => { + Ok((req, res)) => { + let mut req = TokioIo::new(req); + let mut res = TokioIo::new(res); if let Err(err) = tokio::io::copy_bidirectional(&mut req, &mut res).await { error!("upgraded io error: {}", err); } diff --git a/taxy/src/server/state.rs b/taxy/src/server/state.rs index 303f852a..842796fa 100644 --- a/taxy/src/server/state.rs +++ b/taxy/src/server/state.rs @@ -10,9 +10,10 @@ use crate::{ command::ServerCommand, proxy::{PortContext, PortContextKind}, }; -use hyper::server::conn::Http; +use hyper::service::service_fn; use hyper::Response; -use hyper::{service::service_fn, Body}; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto; use rand::seq::SliceRandom; use std::collections::HashSet; use std::convert::Infallible; @@ -154,12 +155,13 @@ impl ServerState { if !self.http_challenges.is_empty() { if let Some(body) = self.handle_http_challenge(&mut stream).await { tokio::task::spawn(async move { - if let Err(err) = Http::new() + let stream = TokioIo::new(BufStream::new(stream)); + if let Err(err) = auto::Builder::new(TokioExecutor::new()) .serve_connection( stream, service_fn(|_| { let body = body.clone(); - async move { Ok::<_, Infallible>(Response::new(Body::from(body))) } + async move { Ok::<_, Infallible>(Response::new(body)) } }), ) .await diff --git a/taxy/tests/http_test.rs b/taxy/tests/http_test.rs index 1e887ead..b17886ff 100644 --- a/taxy/tests/http_test.rs +++ b/taxy/tests/http_test.rs @@ -287,7 +287,7 @@ async fn http_proxy_dns_error() -> anyhow::Result<()> { with_server(config, |_| async move { let client = reqwest::Client::builder().build()?; let resp = client.get(proxy_port.http_url("/hello")).send().await?; - assert_eq!(resp.status(), 523); + assert_eq!(resp.status(), 502); Ok(()) }) .await?; diff --git a/taxy/tests/https_test.rs b/taxy/tests/https_test.rs index 91eff009..db74d32d 100644 --- a/taxy/tests/https_test.rs +++ b/taxy/tests/https_test.rs @@ -178,21 +178,21 @@ async fn https_proxy_invalid_cert() -> anyhow::Result<()> { .add_root_certificate(ca.clone()) .build()?; let resp = client.get(proxy_port.https_url("/hello")).send().await?; - assert_eq!(resp.status(), 526); + assert_eq!(resp.status(), 502); let client = reqwest::Client::builder() .http1_only() .add_root_certificate(ca.clone()) .build()?; let resp = client.get(proxy_port.https_url("/hello")).send().await?; - assert_eq!(resp.status(), 526); + assert_eq!(resp.status(), 502); let client = reqwest::Client::builder() .http2_prior_knowledge() .add_root_certificate(ca) .build()?; let resp = client.get(proxy_port.https_url("/hello")).send().await?; - assert_eq!(resp.status(), 526); + assert_eq!(resp.status(), 502); Ok(()) }) diff --git a/taxy/tests/ws_test.rs b/taxy/tests/ws_test.rs index 71055ecf..9473ab61 100644 --- a/taxy/tests/ws_test.rs +++ b/taxy/tests/ws_test.rs @@ -5,12 +5,12 @@ use axum::{ Router, }; use futures::{SinkExt, StreamExt}; +use hyper::Uri; use taxy_api::{ port::{Port, PortEntry}, proxy::{HttpProxy, Proxy, ProxyEntry, ProxyKind, Route}, }; use tokio_tungstenite::{connect_async, tungstenite::Message}; -use url::Url; mod common; use common::{alloc_tcp_port, with_server, TestStorage}; @@ -54,7 +54,7 @@ async fn ws_proxy() -> anyhow::Result<()> { .build(); with_server(config, |_| async move { - let url = Url::parse(&format!( + let url = Uri::try_from(&format!( "ws://localhost:{}/ws", proxy_port.socket_addr().port() ))?; @@ -74,7 +74,7 @@ async fn ws_proxy() -> anyhow::Result<()> { } async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse { - ws.on_upgrade(move |socket| handle_socket(socket)) + ws.on_upgrade(handle_socket) } async fn handle_socket(mut socket: WebSocket) { diff --git a/taxy/tests/wss_test.rs b/taxy/tests/wss_test.rs index 4be10000..5cd54fc2 100644 --- a/taxy/tests/wss_test.rs +++ b/taxy/tests/wss_test.rs @@ -7,6 +7,7 @@ use axum::{ use axum_server::tls_rustls::RustlsConfig; use core::panic; use futures::{SinkExt, StreamExt}; +use hyper::Uri; use std::sync::Arc; use taxy::certs::Cert; use taxy_api::{ @@ -16,7 +17,6 @@ use taxy_api::{ }; use tokio_rustls::rustls::{client::ClientConfig, RootCertStore}; use tokio_tungstenite::{connect_async_tls_with_config, tungstenite::Message, Connector}; -use url::Url; mod common; use common::{alloc_tcp_port, with_server, TestStorage}; @@ -87,7 +87,7 @@ async fn wss_proxy() -> anyhow::Result<()> { .with_no_client_auth(); with_server(config, |_| async move { - let url = Url::parse(&format!( + let url = Uri::try_from(&format!( "wss://localhost:{}/ws", proxy_port.socket_addr().port() ))?; @@ -113,7 +113,7 @@ async fn wss_proxy() -> anyhow::Result<()> { } async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse { - ws.on_upgrade(move |socket| handle_socket(socket)) + ws.on_upgrade(handle_socket) } async fn handle_socket(mut socket: WebSocket) {