From 9beb47da340a6c153a10a0725e5c4c99137e72fe Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Sat, 3 Jun 2023 12:57:57 +0200 Subject: [PATCH] Add zstd support --- .github/workflows/ci.yml | 15 ++-- Cargo.toml | 8 +++ src/async_impl/client.rs | 40 +++++++++++ src/async_impl/decoder.rs | 142 ++++++++++++++++++++++++++++++++----- src/blocking/client.rs | 31 ++++++++ src/lib.rs | 1 + tests/client.rs | 6 ++ tests/zstd.rs | 145 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 364 insertions(+), 24 deletions(-) create mode 100644 tests/zstd.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 957186a0d1..73f4f5b78e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,6 +77,7 @@ jobs: - "feat.: blocking" - "feat.: gzip" - "feat.: brotli" + - "feat.: zstd" - "feat.: deflate" - "feat.: json" - "feat.: multipart" @@ -101,25 +102,21 @@ jobs: - name: windows / stable-x86_64-msvc os: windows-latest target: x86_64-pc-windows-msvc - features: "--features blocking,gzip,brotli,deflate,json,multipart" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart" - name: windows / stable-i686-msvc os: windows-latest target: i686-pc-windows-msvc - features: "--features blocking,gzip,brotli,deflate,json,multipart" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart" - name: windows / stable-x86_64-gnu os: windows-latest rust: stable-x86_64-pc-windows-gnu target: x86_64-pc-windows-gnu - features: "--features blocking,gzip,brotli,deflate,json,multipart" - package_name: mingw-w64-x86_64-gcc - mingw64_path: "C:\\msys64\\mingw64\\bin" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart" - name: windows / stable-i686-gnu os: windows-latest rust: stable-i686-pc-windows-gnu target: i686-pc-windows-gnu - features: "--features blocking,gzip,brotli,deflate,json,multipart" - package_name: mingw-w64-i686-gcc - mingw64_path: "C:\\msys64\\mingw32\\bin" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart" - name: "feat.: default-tls disabled" features: "--no-default-features" @@ -139,6 +136,8 @@ jobs: features: "--features gzip" - name: "feat.: brotli" features: "--features brotli" + - name: "feat.: zstd" + features: "--features zstd" - name: "feat.: deflate" features: "--features deflate" - name: "feat.: json" diff --git a/Cargo.toml b/Cargo.toml index 6129126297..40673db50f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,8 @@ gzip = ["async-compression", "async-compression/gzip", "tokio-util"] brotli = ["async-compression", "async-compression/brotli", "tokio-util"] +zstd = ["async-compression", "async-compression/zstd", "tokio-util"] + deflate = ["async-compression", "async-compression/zlib", "tokio-util"] json = ["serde_json"] @@ -152,6 +154,7 @@ hyper = { version = "0.14", default-features = false, features = ["tcp", "stream serde = { version = "1.0", features = ["derive"] } libflate = "1.0" brotli_crate = { package = "brotli", version = "3.3.0" } +zstd_crate = { package = "zstd", version = "0.13" } doc-comment = "0.3" tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread"] } @@ -242,6 +245,11 @@ name = "brotli" path = "tests/brotli.rs" required-features = ["brotli"] +[[test]] +name = "zstd" +path = "tests/zstd.rs" +required-features = ["zstd"] + [[test]] name = "deflate" path = "tests/deflate.rs" diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index b6b515b89e..16a8a7734c 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -868,6 +868,29 @@ impl ClientBuilder { self } + /// Enable auto zstd decompression by checking the `Content-Encoding` response header. + /// + /// If auto zstd decompression is turned on: + /// + /// - When sending a request and if the request's headers do not already contain + /// an `Accept-Encoding` **and** `Range` values, the `Accept-Encoding` header is set to `zstd`. + /// The request body is **not** automatically compressed. + /// - When receiving a response, if its headers contain a `Content-Encoding` value of + /// `zstd`, both `Content-Encoding` and `Content-Length` are removed from the + /// headers' set. The response body is automatically decompressed. + /// + /// If the `zstd` feature is turned on, the default option is enabled. + /// + /// # Optional + /// + /// This requires the optional `zstd` feature to be enabled + #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + pub fn zstd(mut self, enable: bool) -> ClientBuilder { + self.config.accepts.zstd = enable; + self + } + /// Enable auto deflate decompression by checking the `Content-Encoding` response header. /// /// If auto deflate decompression is turned on: @@ -925,6 +948,23 @@ impl ClientBuilder { } } + /// Disable auto response body zstd decompression. + /// + /// This method exists even if the optional `zstd` feature is not enabled. + /// This can be used to ensure a `Client` doesn't use zstd decompression + /// even if another dependency were to enable the optional `zstd` feature. + pub fn no_zstd(self) -> ClientBuilder { + #[cfg(feature = "zstd")] + { + self.zstd(false) + } + + #[cfg(not(feature = "zstd"))] + { + self + } + } + /// Disable auto response body deflate decompression. /// /// This method exists even if the optional `deflate` feature is not enabled. diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index c0542cfb13..44181a7d61 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -9,6 +9,9 @@ use async_compression::tokio::bufread::GzipDecoder; #[cfg(feature = "brotli")] use async_compression::tokio::bufread::BrotliDecoder; +#[cfg(feature = "zstd")] +use async_compression::tokio::bufread::ZstdDecoder; + #[cfg(feature = "deflate")] use async_compression::tokio::bufread::ZlibDecoder; @@ -18,9 +21,19 @@ use futures_util::stream::Peekable; use http::HeaderMap; use hyper::body::HttpBody; -#[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] +#[cfg(any( + feature = "gzip", + feature = "brotli", + feature = "zstd", + feature = "deflate" +))] use tokio_util::codec::{BytesCodec, FramedRead}; -#[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] +#[cfg(any( + feature = "gzip", + feature = "brotli", + feature = "zstd", + feature = "deflate" +))] use tokio_util::io::StreamReader; use super::super::Body; @@ -32,6 +45,8 @@ pub(super) struct Accepts { pub(super) gzip: bool, #[cfg(feature = "brotli")] pub(super) brotli: bool, + #[cfg(feature = "zstd")] + pub(super) zstd: bool, #[cfg(feature = "deflate")] pub(super) deflate: bool, } @@ -45,7 +60,12 @@ pub(crate) struct Decoder { type PeekableIoStream = Peekable; -#[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] +#[cfg(any( + feature = "gzip", + feature = "zstd", + feature = "brotli", + feature = "deflate" +))] type PeekableIoStreamReader = StreamReader; enum Inner { @@ -60,12 +80,21 @@ enum Inner { #[cfg(feature = "brotli")] Brotli(Pin, BytesCodec>>>), + /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it. + #[cfg(feature = "zstd")] + Zstd(Pin, BytesCodec>>>), + /// A `Deflate` decoder will uncompress the deflated response content before returning it. #[cfg(feature = "deflate")] Deflate(Pin, BytesCodec>>>), /// A decoder that doesn't have a value yet. - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] Pending(Pin>), } @@ -79,6 +108,8 @@ enum DecoderType { Gzip, #[cfg(feature = "brotli")] Brotli, + #[cfg(feature = "zstd")] + Zstd, #[cfg(feature = "deflate")] Deflate, } @@ -136,6 +167,21 @@ impl Decoder { } } + /// A zstd decoder. + /// + /// This decoder will buffer and decompress chunks that are zstd compressed. + #[cfg(feature = "zstd")] + fn zstd(body: Body) -> Decoder { + use futures_util::StreamExt; + + Decoder { + inner: Inner::Pending(Box::pin(Pending( + IoStream(body.into_stream()).peekable(), + DecoderType::Zstd, + ))), + } + } + /// A deflate decoder. /// /// This decoder will buffer and decompress chunks that are deflated. @@ -151,7 +197,12 @@ impl Decoder { } } - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] fn detect_encoding(headers: &mut HeaderMap, encoding_str: &str) -> bool { use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; use log::warn; @@ -202,6 +253,13 @@ impl Decoder { } } + #[cfg(feature = "zstd")] + { + if _accepts.zstd && Decoder::detect_encoding(_headers, "zstd") { + return Decoder::zstd(body); + } + } + #[cfg(feature = "deflate")] { if _accepts.deflate && Decoder::detect_encoding(_headers, "deflate") { @@ -219,7 +277,12 @@ impl Stream for Decoder { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { // Do a read or poll for a pending decoder value. match self.inner { - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) { Poll::Ready(Ok(inner)) => { self.inner = inner; @@ -245,6 +308,14 @@ impl Stream for Decoder { None => Poll::Ready(None), } } + #[cfg(feature = "zstd")] + Inner::Zstd(ref mut decoder) => { + return match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + }; + } #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { @@ -279,7 +350,12 @@ impl HttpBody for Decoder { match self.inner { Inner::PlainText(ref body) => HttpBody::size_hint(body), // the rest are "unknown", so default - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] _ => http_body::SizeHint::default(), } } @@ -317,6 +393,11 @@ impl Future for Pending { BrotliDecoder::new(StreamReader::new(_body)), BytesCodec::new(), ))))), + #[cfg(feature = "zstd")] + DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new( + ZstdDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ))))), #[cfg(feature = "gzip")] DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new( GzipDecoder::new(StreamReader::new(_body)), @@ -352,21 +433,36 @@ impl Accepts { gzip: false, #[cfg(feature = "brotli")] brotli: false, + #[cfg(feature = "zstd")] + zstd: false, #[cfg(feature = "deflate")] deflate: false, } } pub(super) fn as_str(&self) -> Option<&'static str> { - match (self.is_gzip(), self.is_brotli(), self.is_deflate()) { - (true, true, true) => Some("gzip, br, deflate"), - (true, true, false) => Some("gzip, br"), - (true, false, true) => Some("gzip, deflate"), - (false, true, true) => Some("br, deflate"), - (true, false, false) => Some("gzip"), - (false, true, false) => Some("br"), - (false, false, true) => Some("deflate"), - (false, false, false) => None, + match ( + self.is_gzip(), + self.is_brotli(), + self.is_zstd(), + self.is_deflate(), + ) { + (true, true, true, true) => Some("gzip, br, zstd, deflate"), + (true, true, false, true) => Some("gzip, br, deflate"), + (true, true, true, false) => Some("gzip, br, zstd"), + (true, true, false, false) => Some("gzip, br"), + (true, false, true, true) => Some("gzip, zstd, deflate"), + (true, false, false, true) => Some("gzip, zstd, deflate"), + (false, true, true, true) => Some("br, zstd, deflate"), + (false, true, false, true) => Some("br, zstd, deflate"), + (true, false, true, false) => Some("gzip, zstd"), + (true, false, false, false) => Some("gzip"), + (false, true, true, false) => Some("br, zstd"), + (false, true, false, false) => Some("br"), + (false, false, true, true) => Some("zstd, deflate"), + (false, false, true, false) => Some("zstd"), + (false, false, false, true) => Some("deflate"), + (false, false, false, false) => None, } } @@ -394,6 +490,18 @@ impl Accepts { } } + fn is_zstd(&self) -> bool { + #[cfg(feature = "zstd")] + { + self.zstd + } + + #[cfg(not(feature = "zstd"))] + { + false + } + } + fn is_deflate(&self) -> bool { #[cfg(feature = "deflate")] { @@ -414,6 +522,8 @@ impl Default for Accepts { gzip: true, #[cfg(feature = "brotli")] brotli: true, + #[cfg(feature = "zstd")] + zstd: true, #[cfg(feature = "deflate")] deflate: true, } diff --git a/src/blocking/client.rs b/src/blocking/client.rs index d57f3a031b..d8cd52ab6f 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -260,6 +260,28 @@ impl ClientBuilder { self.with_inner(|inner| inner.brotli(enable)) } + /// Enable auto zstd decompression by checking the `Content-Encoding` response header. + /// + /// If auto zstd decompression is turned on: + /// + /// - When sending a request and if the request's headers do not already contain + /// an `Accept-Encoding` **and** `Range` values, the `Accept-Encoding` header is set to `zstd`. + /// The request body is **not** automatically compressed. + /// - When receiving a response, if its headers contain a `Content-Encoding` value of + /// `zstd`, both `Content-Encoding` and `Content-Length` are removed from the + /// headers' set. The response body is automatically decompressed. + /// + /// If the `zstd` feature is turned on, the default option is enabled. + /// + /// # Optional + /// + /// This requires the optional `zstd` feature to be enabled + #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + pub fn zstd(self, enable: bool) -> ClientBuilder { + self.with_inner(|inner| inner.zstd(enable)) + } + /// Enable auto deflate decompression by checking the `Content-Encoding` response header. /// /// If auto deflate decompresson is turned on: @@ -300,6 +322,15 @@ impl ClientBuilder { self.with_inner(|inner| inner.no_brotli()) } + /// Disable auto response body zstd decompression. + /// + /// This method exists even if the optional `zstd` feature is not enabled. + /// This can be used to ensure a `Client` doesn't use zstd decompression + /// even if another dependency were to enable the optional `zstd` feature. + pub fn no_zstd(self) -> ClientBuilder { + self.with_inner(|inner| inner.no_zstd()) + } + /// Disable auto response body deflate decompression. /// /// This method exists even if the optional `deflate` feature is not enabled. diff --git a/src/lib.rs b/src/lib.rs index 188ba4f029..d4f8abe4b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -190,6 +190,7 @@ //! - **cookies**: Provides cookie session support. //! - **gzip**: Provides response body gzip decompression. //! - **brotli**: Provides response body brotli decompression. +//! - **zstd**: Provides response body zstd decompression. //! - **deflate**: Provides response body deflate decompression. //! - **json**: Provides serialization and deserialization for JSON bodies. //! - **multipart**: Provides functionality for multipart forms. diff --git a/tests/client.rs b/tests/client.rs index e77cc6a4a9..0532d11989 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -30,6 +30,12 @@ async fn auto_headers() { .unwrap() .contains("br")); } + if cfg!(feature = "zstd") { + assert!(req.headers()["accept-encoding"] + .to_str() + .unwrap() + .contains("zstd")); + } if cfg!(feature = "deflate") { assert!(req.headers()["accept-encoding"] .to_str() diff --git a/tests/zstd.rs b/tests/zstd.rs new file mode 100644 index 0000000000..4f5e96344a --- /dev/null +++ b/tests/zstd.rs @@ -0,0 +1,145 @@ +mod support; +use support::*; + +#[tokio::test] +async fn zstd_response() { + zstd_case(10_000, 4096).await; +} + +#[tokio::test] +async fn zstd_single_byte_chunks() { + zstd_case(10, 1).await; +} + +#[tokio::test] +async fn test_zstd_empty_body() { + let server = server::http(move |req| async move { + assert_eq!(req.method(), "HEAD"); + + http::Response::builder() + .header("content-encoding", "zstd") + .header("content-length", 100) + .body(Default::default()) + .unwrap() + }); + + let client = reqwest::Client::new(); + let res = client + .head(&format!("http://{}/zstd", server.addr())) + .send() + .await + .unwrap(); + + let body = res.text().await.unwrap(); + + assert_eq!(body, ""); +} + +#[tokio::test] +async fn test_accept_header_is_not_changed_if_set() { + let server = server::http(move |req| async move { + assert_eq!(req.headers()["accept"], "application/json"); + assert!(req.headers()["accept-encoding"] + .to_str() + .unwrap() + .contains("zstd")); + http::Response::default() + }); + + let client = reqwest::Client::new(); + + let res = client + .get(&format!("http://{}/accept", server.addr())) + .header( + reqwest::header::ACCEPT, + reqwest::header::HeaderValue::from_static("application/json"), + ) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), reqwest::StatusCode::OK); +} + +#[tokio::test] +async fn test_accept_encoding_header_is_not_changed_if_set() { + let server = server::http(move |req| async move { + assert_eq!(req.headers()["accept"], "*/*"); + assert_eq!(req.headers()["accept-encoding"], "identity"); + http::Response::default() + }); + + let client = reqwest::Client::new(); + + let res = client + .get(&format!("http://{}/accept-encoding", server.addr())) + .header( + reqwest::header::ACCEPT_ENCODING, + reqwest::header::HeaderValue::from_static("identity"), + ) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), reqwest::StatusCode::OK); +} + +async fn zstd_case(response_size: usize, chunk_size: usize) { + use futures_util::stream::StreamExt; + + let content: String = (0..response_size) + .into_iter() + .map(|i| format!("test {}", i)) + .collect(); + + let zstded_content = zstd_crate::encode_all(content.as_bytes(), 3).unwrap(); + + let mut response = format!( + "\ + HTTP/1.1 200 OK\r\n\ + Server: test-accept\r\n\ + Content-Encoding: zstd\r\n\ + Content-Length: {}\r\n\ + \r\n", + &zstded_content.len() + ) + .into_bytes(); + response.extend(&zstded_content); + + let server = server::http(move |req| { + assert!(req.headers()["accept-encoding"] + .to_str() + .unwrap() + .contains("zstd")); + + let zstded = zstded_content.clone(); + async move { + let len = zstded.len(); + let stream = + futures_util::stream::unfold((zstded, 0), move |(zstded, pos)| async move { + let chunk = zstded.chunks(chunk_size).nth(pos)?.to_vec(); + + Some((chunk, (zstded, pos + 1))) + }); + + let body = hyper::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); + + http::Response::builder() + .header("content-encoding", "zstd") + .header("content-length", len) + .body(body) + .unwrap() + } + }); + + let client = reqwest::Client::new(); + + let res = client + .get(&format!("http://{}/zstd", server.addr())) + .send() + .await + .expect("response"); + + let body = res.text().await.expect("text"); + assert_eq!(body, content); +}