From e9e0eaaafc8e1b92d38267a7833b666652f1f3df Mon Sep 17 00:00:00 2001 From: JESS IZEN Date: Thu, 12 Dec 2024 16:46:51 +0000 Subject: [PATCH] feat: allow pluggable tower layers in connector service stack --- Cargo.toml | 11 +- ...onnect_via_lower_priority_tokio_runtime.rs | 265 ++++++++++++++ src/async_impl/client.rs | 294 +++++++++------ src/blocking/client.rs | 230 ++++++++---- src/connect.rs | 196 ++++++++-- src/error.rs | 12 + tests/connector_layers.rs | 344 ++++++++++++++++++ tests/support/delay_layer.rs | 120 ++++++ tests/support/mod.rs | 1 + tests/timeouts.rs | 18 + 10 files changed, 1267 insertions(+), 224 deletions(-) create mode 100644 examples/connect_via_lower_priority_tokio_runtime.rs create mode 100644 tests/connector_layers.rs create mode 100644 tests/support/delay_layer.rs diff --git a/Cargo.toml b/Cargo.toml index 39ff48424..01fa5339c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -105,6 +105,7 @@ url = "2.4" bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7.1" +tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } tower-service = "0.3" futures-core = { version = "0.3.28", default-features = false } futures-util = { version = "0.3.28", default-features = false } @@ -169,7 +170,6 @@ quinn = { version = "0.11.1", default-features = false, features = ["rustls", "r slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn futures-channel = { version = "0.3", optional = true } - [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } @@ -222,6 +222,11 @@ features = [ wasm-bindgen = { version = "0.2.89", features = ["serde-serialize"] } wasm-bindgen-test = "0.3" +[dev-dependencies] +tower = { version = "0.5.1", default-features = false, features = ["limit"] } +num_cpus = "1.0" +libc = "0" + [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] } @@ -253,6 +258,10 @@ path = "examples/form.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "connect_via_lower_priority_tokio_runtime" +path = "examples/connect_via_lower_priority_tokio_runtime.rs" + [[test]] name = "blocking" path = "tests/blocking.rs" diff --git a/examples/connect_via_lower_priority_tokio_runtime.rs b/examples/connect_via_lower_priority_tokio_runtime.rs new file mode 100644 index 000000000..88ed7af8f --- /dev/null +++ b/examples/connect_via_lower_priority_tokio_runtime.rs @@ -0,0 +1,265 @@ +#![deny(warnings)] +// This example demonstrates how to delegate the connect calls, which contain TLS handshakes, +// to a secondary tokio runtime of lower OS thread priority using a custom tower layer. +// This helps to ensure that long-running futures during handshake crypto operations don't block other I/O futures. +// +// This does introduce overhead of additional threads, channels, extra vtables, etc, +// so it is best suited to services with large numbers of incoming connections or that +// are otherwise very sensitive to any blocking futures. Or, you might want fewer threads +// and/or to use the current_thread runtime. +// +// This is using the `tokio` runtime and certain other dependencies: +// +// `tokio = { version = "1", features = ["full"] }` +// `num_cpus = "1.0"` +// `libc = "0"` +// `pin-project-lite = "0.2"` +// `tower = { version = "0.5", default-features = false}` +use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + sync::OnceLock, + task::{Context, Poll}, + time::Duration, +}; + +use futures_util::TryFutureExt; +use pin_project_lite::pin_project; +use tokio::{runtime::Handle, select, sync::mpsc::error::TrySendError}; +use tower::{BoxError, Layer, Service}; + +static CPU_HEAVY_THREAD_POOL: OnceLock< + tokio::sync::mpsc::Sender + Send + 'static>>>, +> = OnceLock::new(); + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { + init_background_runtime(); + tokio::time::sleep(Duration::from_millis(10)).await; + + let client = reqwest::Client::builder() + .connector_layer(BackgroundProcessorLayer::::new()) + .build() + .expect("should be able to build reqwest client"); + + let url = if let Some(url) = std::env::args().nth(1) { + url + } else { + println!("No CLI URL provided, using default."); + "https://hyper.rs".into() + }; + + eprintln!("Fetching {url:?}..."); + + let res = client.get(url).send().await?; + + eprintln!("Response: {:?} {}", res.version(), res.status()); + eprintln!("Headers: {:#?}\n", res.headers()); + + let body = res.text().await?; + + println!("{body}"); + + Ok(()) +} + +fn init_background_runtime() { + std::thread::Builder::new() + .name("cpu-heavy-background-threadpool".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("cpu-heavy-background-pool-thread") + .worker_threads(num_cpus::get() as usize) + // ref: https://github.com/tokio-rs/tokio/issues/4941 + // consider uncommenting if seeing heavy task contention + // .disable_lifo_slot() + .on_thread_start(move || unsafe { + // Reduce thread pool thread niceness, so they are lower priority + // than the foreground executor and don't interfere with I/O tasks + #[cfg(target_os = "linux")] + { + *libc::__errno_location() = 0; + if libc::nice(10) == -1 && *libc::__errno_location() != 0 { + let error = std::io::Error::last_os_error(); + log::error!("failed to set threadpool niceness: {}", error); + } + } + }) + .enable_all() + .build() + .unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e)); + rt.block_on(async { + log::debug!("starting background cpu-heavy work"); + process_cpu_work().await; + }); + }) + .unwrap_or_else(|e| panic!("cpu heavy thread failed_to_initialize: {}", e)); +} + +async fn process_cpu_work() { + // we only use this channel for routing work, it should move pretty quick, it can be small + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + // share the handle to the background channel globally + CPU_HEAVY_THREAD_POOL.set(tx).unwrap(); + + while let Some(work) = rx.recv().await { + tokio::task::spawn(work); + } +} + +// retrieve the sender to the background channel, and send the future over to it for execution +fn send_to_background_runtime(future: impl Future + Send + 'static) { + let tx = CPU_HEAVY_THREAD_POOL + .get() + .expect("start up the secondary tokio runtime before sending to `CPU_HEAVY_THREAD_POOL`"); + + match tx.try_send(Box::pin(future)) { + Ok(_) => (), + Err(TrySendError::Closed(_)) => { + panic!("background cpu heavy runtime channel is closed") + } + Err(TrySendError::Full(msg)) => { + log::warn!("background cpu heavy runtime channel is full, task spawning loop delayed"); + let tx = tx.clone(); + Handle::current().spawn(async move { + tx.send(msg) + .await + .expect("background cpu heavy runtime channel is closed") + }); + } + } +} + +// This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing. +// We don't use the Buffer service because that is intended to process sequentially on a single task, whereas we want to +// spawn a new task per call. +pub struct BackgroundProcessorLayer { + _p: PhantomData, +} +impl BackgroundProcessorLayer { + pub fn new() -> Self { + Self { _p: PhantomData } + } +} +impl Layer for BackgroundProcessorLayer { + type Service = BackgroundProcessor; + fn layer(&self, service: S) -> Self::Service { + BackgroundProcessor::new(service) + } +} + +impl std::fmt::Debug for BackgroundProcessorLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("BackgroundProcessorLayer").finish() + } +} + +impl Clone for BackgroundProcessorLayer { + fn clone(&self) -> Self { + Self { _p: PhantomData } + } +} + +impl Copy for BackgroundProcessorLayer {} + +// This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing. +#[derive(Debug, Clone)] +pub struct BackgroundProcessor { + inner: S, +} + +impl BackgroundProcessor { + pub fn new(inner: S) -> Self { + BackgroundProcessor { inner } + } +} + +impl Service for BackgroundProcessor +where + S: Service, + S::Response: Send + 'static, + S::Error: Into + Send, + S::Future: Send + 'static, +{ + type Response = S::Response; + + type Error = BoxError; + + type Future = BackgroundResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + + // wrap our inner service's future with a future that writes to this oneshot channel + let (mut tx, rx) = tokio::sync::oneshot::channel(); + let future = async move { + select!( + _ = tx.closed() => { + // receiver already dropped, don't need to do anything + } + result = response.map_err(|err| Into::::into(err)) => { + // if this fails, the receiver already dropped, so we don't need to do anything + let _ = tx.send(result); + } + ) + }; + // send the wrapped future to the background + send_to_background_runtime(future); + + BackgroundResponseFuture::new(rx) + } +} + +// `BackgroundProcessor` response future +pin_project! { + #[derive(Debug)] + pub struct BackgroundResponseFuture { + #[pin] + rx: tokio::sync::oneshot::Receiver>, + } +} + +impl BackgroundResponseFuture { + pub(crate) fn new(rx: tokio::sync::oneshot::Receiver>) -> Self { + BackgroundResponseFuture { rx } + } +} + +impl Future for BackgroundResponseFuture +where + S: Send + 'static, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // now poll on the receiver end of the oneshot to get the result + match this.rx.poll(cx) { + Poll::Ready(v) => match v { + Ok(v) => Poll::Ready(v.map_err(Into::into)), + Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)), + }, + Poll::Pending => Poll::Pending, + } + } +} + +// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function +// for wasm32 target, because tokio isn't compatible with wasm32. +// If you aren't building for wasm32, you don't need that line. +// The two lines below avoid the "'main' function not found" error when building for wasm32 target. +#[cfg(any(target_arch = "wasm32"))] +fn main() {} diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 579050041..51a260824 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1,27 +1,14 @@ #[cfg(any(feature = "native-tls", feature = "__rustls",))] use std::any::Any; +use std::future::Future; use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use std::{collections::HashMap, convert::TryInto, net::SocketAddr}; use std::{fmt, str}; -use bytes::Bytes; -use http::header::{ - Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, - CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, -}; -use http::uri::Scheme; -use http::Uri; -use hyper_util::client::legacy::connect::HttpConnector; -#[cfg(feature = "default-tls")] -use native_tls_crate::TlsConnector; -use pin_project_lite::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::time::Sleep; - use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; use super::response::Response; @@ -30,13 +17,13 @@ use super::Body; use crate::async_impl::h3_client::connect::H3Connector; #[cfg(feature = "http3")] use crate::async_impl::h3_client::{H3Client, H3ResponseFuture}; -use crate::connect::Connector; +use crate::connect::{Conn, Connector, ConnectorBuilder, ConnectorLayerBuilder, ConnectorService}; #[cfg(feature = "cookies")] use crate::cookie; #[cfg(feature = "hickory-dns")] use crate::dns::hickory::HickoryDnsResolver; use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve}; -use crate::error; +use crate::error::{self, BoxError}; use crate::into_url::try_uri; use crate::redirect::{self, remove_sensitive_headers}; #[cfg(feature = "__rustls")] @@ -48,11 +35,24 @@ use crate::Certificate; #[cfg(any(feature = "native-tls", feature = "__rustls"))] use crate::Identity; use crate::{IntoUrl, Method, Proxy, StatusCode, Url}; +use bytes::Bytes; +use http::header::{ + Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, + CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, +}; +use http::uri::Scheme; +use http::Uri; +use hyper_util::client::legacy::connect::HttpConnector; use log::debug; +#[cfg(feature = "default-tls")] +use native_tls_crate::TlsConnector; +use pin_project_lite::pin_project; #[cfg(feature = "http3")] use quinn::TransportConfig; #[cfg(feature = "http3")] use quinn::VarInt; +use tokio::time::Sleep; +use tower::{layer::util::Stack, Layer, Service, ServiceBuilder}; type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; @@ -76,8 +76,10 @@ pub struct Client { /// A `ClientBuilder` can be used to create a `Client` with custom configuration. #[must_use] -pub struct ClientBuilder { +pub struct ClientBuilder { config: Config, + // separated out to simplify casting between generic types while copying config + connector_layers: ConnectorLayerBuilder, } enum HttpVersionPref { @@ -175,17 +177,26 @@ struct Config { dns_resolver: Option>, } -impl Default for ClientBuilder { +impl Default for ClientBuilder { fn default() -> Self { Self::new() } } -impl ClientBuilder { +#[allow(private_bounds)] +impl ClientBuilder +where + CL1: Layer, + CL2: Layer<>::Service>, + >::Service>>::Service: + Service + Clone + Send + Sync + 'static, + <>::Service>>::Service as Service>::Future: + Send + 'static, +{ /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> ClientBuilder { let mut headers: HeaderMap = HeaderMap::with_capacity(2); headers.insert(ACCEPT, HeaderValue::from_static("*/*")); @@ -276,6 +287,10 @@ impl ClientBuilder { quic_send_window: None, dns_resolver: None, }, + connector_layers: ConnectorLayerBuilder { + builder: ServiceBuilder::new().layer(tower::layer::util::Identity::new()), + has_custom_layers: false, + }, } } @@ -302,7 +317,7 @@ impl ClientBuilder { #[cfg(feature = "http3")] let mut h3_connector = None; - let mut connector = { + let mut connector_builder = { #[cfg(feature = "__tls")] fn user_agent(headers: &HeaderMap) -> Option { headers.get(USER_AGENT).cloned() @@ -445,7 +460,7 @@ impl ClientBuilder { tls.max_protocol_version(Some(protocol)); } - Connector::new_default_tls( + ConnectorBuilder::new_default_tls( http, tls, proxies.clone(), @@ -462,7 +477,7 @@ impl ClientBuilder { )? } #[cfg(feature = "native-tls")] - TlsBackend::BuiltNativeTls(conn) => Connector::from_built_default_tls( + TlsBackend::BuiltNativeTls(conn) => ConnectorBuilder::from_built_default_tls( http, conn, proxies.clone(), @@ -489,7 +504,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, conn, proxies.clone(), @@ -684,7 +699,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, tls, proxies.clone(), @@ -709,7 +724,7 @@ impl ClientBuilder { } #[cfg(not(feature = "__tls"))] - Connector::new( + ConnectorBuilder::new( http, proxies.clone(), config.local_address, @@ -719,8 +734,9 @@ impl ClientBuilder { ) }; - connector.set_timeout(config.connect_timeout); - connector.set_verbose(config.connection_verbose); + connector_builder.set_timeout(config.connect_timeout); + connector_builder.set_verbose(config.connection_verbose); + connector_builder.set_keepalive(config.tcp_keepalive); let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); @@ -763,7 +779,6 @@ impl ClientBuilder { builder.pool_timer(hyper_util::rt::TokioTimer::new()); builder.pool_idle_timeout(config.pool_idle_timeout); builder.pool_max_idle_per_host(config.pool_max_idle_per_host); - connector.set_keepalive(config.tcp_keepalive); if config.http09_responses { builder.http09_responses(true); @@ -801,7 +816,7 @@ impl ClientBuilder { } None => None, }, - hyper: builder.build(connector), + hyper: builder.build(connector_builder.build(self.connector_layers)), headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, @@ -836,7 +851,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn user_agent(mut self, value: V) -> ClientBuilder + pub fn user_agent(mut self, value: V) -> ClientBuilder where V: TryInto, V::Error: Into, @@ -874,7 +889,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn default_headers(mut self, headers: HeaderMap) -> ClientBuilder { + pub fn default_headers(mut self, headers: HeaderMap) -> ClientBuilder { for (key, value) in headers.iter() { self.config.headers.insert(key, value.clone()); } @@ -897,7 +912,7 @@ impl ClientBuilder { /// This requires the optional `cookies` feature to be enabled. #[cfg(feature = "cookies")] #[cfg_attr(docsrs, doc(cfg(feature = "cookies")))] - pub fn cookie_store(mut self, enable: bool) -> ClientBuilder { + pub fn cookie_store(mut self, enable: bool) -> ClientBuilder { if enable { self.cookie_provider(Arc::new(cookie::Jar::default())) } else { @@ -924,7 +939,7 @@ impl ClientBuilder { pub fn cookie_provider( mut self, cookie_store: Arc, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.cookie_store = Some(cookie_store as _); self } @@ -947,7 +962,7 @@ impl ClientBuilder { /// This requires the optional `gzip` feature to be enabled #[cfg(feature = "gzip")] #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] - pub fn gzip(mut self, enable: bool) -> ClientBuilder { + pub fn gzip(mut self, enable: bool) -> ClientBuilder { self.config.accepts.gzip = enable; self } @@ -970,7 +985,7 @@ impl ClientBuilder { /// This requires the optional `brotli` feature to be enabled #[cfg(feature = "brotli")] #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))] - pub fn brotli(mut self, enable: bool) -> ClientBuilder { + pub fn brotli(mut self, enable: bool) -> ClientBuilder { self.config.accepts.brotli = enable; self } @@ -993,7 +1008,7 @@ impl ClientBuilder { /// This requires the optional `zstd` feature to be enabled #[cfg(feature = "zstd")] #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] - pub fn zstd(mut self, enable: bool) -> ClientBuilder { + pub fn zstd(mut self, enable: bool) -> ClientBuilder { self.config.accepts.zstd = enable; self } @@ -1016,7 +1031,7 @@ impl ClientBuilder { /// This requires the optional `deflate` feature to be enabled #[cfg(feature = "deflate")] #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))] - pub fn deflate(mut self, enable: bool) -> ClientBuilder { + pub fn deflate(mut self, enable: bool) -> ClientBuilder { self.config.accepts.deflate = enable; self } @@ -1026,7 +1041,7 @@ impl ClientBuilder { /// This method exists even if the optional `gzip` feature is not enabled. /// This can be used to ensure a `Client` doesn't use gzip decompression /// even if another dependency were to enable the optional `gzip` feature. - pub fn no_gzip(self) -> ClientBuilder { + pub fn no_gzip(self) -> ClientBuilder { #[cfg(feature = "gzip")] { self.gzip(false) @@ -1043,7 +1058,7 @@ impl ClientBuilder { /// This method exists even if the optional `brotli` feature is not enabled. /// This can be used to ensure a `Client` doesn't use brotli decompression /// even if another dependency were to enable the optional `brotli` feature. - pub fn no_brotli(self) -> ClientBuilder { + pub fn no_brotli(self) -> ClientBuilder { #[cfg(feature = "brotli")] { self.brotli(false) @@ -1060,7 +1075,7 @@ impl ClientBuilder { /// This method exists even if the optional `zstd` feature is not enabled. /// This can be used to ensure a `Client` doesn't use zstd decompression /// even if another dependency were to enable the optional `zstd` feature. - pub fn no_zstd(self) -> ClientBuilder { + pub fn no_zstd(self) -> ClientBuilder { #[cfg(feature = "zstd")] { self.zstd(false) @@ -1077,7 +1092,7 @@ impl ClientBuilder { /// This method exists even if the optional `deflate` feature is not enabled. /// This can be used to ensure a `Client` doesn't use deflate decompression /// even if another dependency were to enable the optional `deflate` feature. - pub fn no_deflate(self) -> ClientBuilder { + pub fn no_deflate(self) -> ClientBuilder { #[cfg(feature = "deflate")] { self.deflate(false) @@ -1094,7 +1109,7 @@ impl ClientBuilder { /// Set a `RedirectPolicy` for this client. /// /// Default will follow redirects up to a maximum of 10. - pub fn redirect(mut self, policy: redirect::Policy) -> ClientBuilder { + pub fn redirect(mut self, policy: redirect::Policy) -> ClientBuilder { self.config.redirect_policy = policy; self } @@ -1102,7 +1117,7 @@ impl ClientBuilder { /// Enable or disable automatic setting of the `Referer` header. /// /// Default is `true`. - pub fn referer(mut self, enable: bool) -> ClientBuilder { + pub fn referer(mut self, enable: bool) -> ClientBuilder { self.config.referer = enable; self } @@ -1114,7 +1129,7 @@ impl ClientBuilder { /// # Note /// /// Adding a proxy will disable the automatic usage of the "system" proxy. - pub fn proxy(mut self, proxy: Proxy) -> ClientBuilder { + pub fn proxy(mut self, proxy: Proxy) -> ClientBuilder { self.config.proxies.push(proxy); self.config.auto_sys_proxy = false; self @@ -1127,7 +1142,7 @@ impl ClientBuilder { /// on all desired proxies instead. /// /// This also disables the automatic usage of the "system" proxy. - pub fn no_proxy(mut self) -> ClientBuilder { + pub fn no_proxy(mut self) -> ClientBuilder { self.config.proxies.clear(); self.config.auto_sys_proxy = false; self @@ -1141,7 +1156,7 @@ impl ClientBuilder { /// response body has finished. Also considered a total deadline. /// /// Default is no timeout. - pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.timeout = Some(timeout); self } @@ -1153,7 +1168,7 @@ impl ClientBuilder { /// connections when the size isn't known beforehand. /// /// Default is no timeout. - pub fn read_timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn read_timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.read_timeout = Some(timeout); self } @@ -1166,7 +1181,7 @@ impl ClientBuilder { /// /// This **requires** the futures be executed in a tokio runtime with /// a tokio timer enabled. - pub fn connect_timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn connect_timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.connect_timeout = Some(timeout); self } @@ -1177,7 +1192,7 @@ impl ClientBuilder { /// for read and write operations on connections. /// /// [log]: https://crates.io/crates/log - pub fn connection_verbose(mut self, verbose: bool) -> ClientBuilder { + pub fn connection_verbose(mut self, verbose: bool) -> ClientBuilder { self.config.connection_verbose = verbose; self } @@ -1189,7 +1204,7 @@ impl ClientBuilder { /// Pass `None` to disable timeout. /// /// Default is 90 seconds. - pub fn pool_idle_timeout(mut self, val: D) -> ClientBuilder + pub fn pool_idle_timeout(mut self, val: D) -> ClientBuilder where D: Into>, { @@ -1198,13 +1213,13 @@ impl ClientBuilder { } /// Sets the maximum idle connection per host allowed in the pool. - pub fn pool_max_idle_per_host(mut self, max: usize) -> ClientBuilder { + pub fn pool_max_idle_per_host(mut self, max: usize) -> ClientBuilder { self.config.pool_max_idle_per_host = max; self } /// Send headers as title case instead of lowercase. - pub fn http1_title_case_headers(mut self) -> ClientBuilder { + pub fn http1_title_case_headers(mut self) -> ClientBuilder { self.config.http1_title_case_headers = true; self } @@ -1217,14 +1232,17 @@ impl ClientBuilder { pub fn http1_allow_obsolete_multiline_headers_in_responses( mut self, value: bool, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config .http1_allow_obsolete_multiline_headers_in_responses = value; self } /// Sets whether invalid header lines should be silently ignored in HTTP/1 responses. - pub fn http1_ignore_invalid_headers_in_responses(mut self, value: bool) -> ClientBuilder { + pub fn http1_ignore_invalid_headers_in_responses( + mut self, + value: bool, + ) -> ClientBuilder { self.config.http1_ignore_invalid_headers_in_responses = value; self } @@ -1237,20 +1255,20 @@ impl ClientBuilder { pub fn http1_allow_spaces_after_header_name_in_responses( mut self, value: bool, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config .http1_allow_spaces_after_header_name_in_responses = value; self } /// Only use HTTP/1. - pub fn http1_only(mut self) -> ClientBuilder { + pub fn http1_only(mut self) -> ClientBuilder { self.config.http_version_pref = HttpVersionPref::Http1; self } /// Allow HTTP/0.9 responses - pub fn http09_responses(mut self) -> ClientBuilder { + pub fn http09_responses(mut self) -> ClientBuilder { self.config.http09_responses = true; self } @@ -1258,7 +1276,7 @@ impl ClientBuilder { /// Only use HTTP/2. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_prior_knowledge(mut self) -> ClientBuilder { + pub fn http2_prior_knowledge(mut self) -> ClientBuilder { self.config.http_version_pref = HttpVersionPref::Http2; self } @@ -1266,7 +1284,7 @@ impl ClientBuilder { /// Only use HTTP/3. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_prior_knowledge(mut self) -> ClientBuilder { + pub fn http3_prior_knowledge(mut self) -> ClientBuilder { self.config.http_version_pref = HttpVersionPref::Http3; self } @@ -1276,7 +1294,10 @@ impl ClientBuilder { /// Default is currently 65,535 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_initial_stream_window_size(mut self, sz: impl Into>) -> ClientBuilder { + pub fn http2_initial_stream_window_size( + mut self, + sz: impl Into>, + ) -> ClientBuilder { self.config.http2_initial_stream_window_size = sz.into(); self } @@ -1289,7 +1310,7 @@ impl ClientBuilder { pub fn http2_initial_connection_window_size( mut self, sz: impl Into>, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.http2_initial_connection_window_size = sz.into(); self } @@ -1300,7 +1321,7 @@ impl ClientBuilder { /// `http2_initial_connection_window_size`. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_adaptive_window(mut self, enabled: bool) -> ClientBuilder { + pub fn http2_adaptive_window(mut self, enabled: bool) -> ClientBuilder { self.config.http2_adaptive_window = enabled; self } @@ -1310,7 +1331,7 @@ impl ClientBuilder { /// Default is currently 16,384 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_frame_size(mut self, sz: impl Into>) -> ClientBuilder { + pub fn http2_max_frame_size(mut self, sz: impl Into>) -> ClientBuilder { self.config.http2_max_frame_size = sz.into(); self } @@ -1320,7 +1341,10 @@ impl ClientBuilder { /// Default is currently 16KB, but can change. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_header_list_size(mut self, max_header_size_bytes: u32) -> ClientBuilder { + pub fn http2_max_header_list_size( + mut self, + max_header_size_bytes: u32, + ) -> ClientBuilder { self.config.http2_max_header_list_size = Some(max_header_size_bytes); self } @@ -1334,7 +1358,7 @@ impl ClientBuilder { pub fn http2_keep_alive_interval( mut self, interval: impl Into>, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.http2_keep_alive_interval = interval.into(); self } @@ -1346,7 +1370,7 @@ impl ClientBuilder { /// Default is currently disabled. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_keep_alive_timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn http2_keep_alive_timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.http2_keep_alive_timeout = Some(timeout); self } @@ -1359,7 +1383,7 @@ impl ClientBuilder { /// Default is `false`. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> ClientBuilder { + pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> ClientBuilder { self.config.http2_keep_alive_while_idle = enabled; self } @@ -1369,7 +1393,7 @@ impl ClientBuilder { /// Set whether sockets have `TCP_NODELAY` enabled. /// /// Default is `true`. - pub fn tcp_nodelay(mut self, enabled: bool) -> ClientBuilder { + pub fn tcp_nodelay(mut self, enabled: bool) -> ClientBuilder { self.config.nodelay = enabled; self } @@ -1387,7 +1411,7 @@ impl ClientBuilder { /// .local_address(local_addr) /// .build().unwrap(); /// ``` - pub fn local_address(mut self, addr: T) -> ClientBuilder + pub fn local_address(mut self, addr: T) -> ClientBuilder where T: Into>, { @@ -1408,7 +1432,7 @@ impl ClientBuilder { /// .build().unwrap(); /// ``` #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - pub fn interface(mut self, interface: &str) -> ClientBuilder { + pub fn interface(mut self, interface: &str) -> ClientBuilder { self.config.interface = Some(interface.to_string()); self } @@ -1416,7 +1440,7 @@ impl ClientBuilder { /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration. /// /// If `None`, the option will not be set. - pub fn tcp_keepalive(mut self, val: D) -> ClientBuilder + pub fn tcp_keepalive(mut self, val: D) -> ClientBuilder where D: Into>, { @@ -1444,7 +1468,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn add_root_certificate(mut self, cert: Certificate) -> ClientBuilder { + pub fn add_root_certificate(mut self, cert: Certificate) -> ClientBuilder { self.config.root_certs.push(cert); self } @@ -1457,7 +1481,7 @@ impl ClientBuilder { /// This requires the `rustls-tls(-...)` Cargo feature enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn add_crl(mut self, crl: CertificateRevocationList) -> ClientBuilder { + pub fn add_crl(mut self, crl: CertificateRevocationList) -> ClientBuilder { self.config.crls.push(crl); self } @@ -1473,7 +1497,7 @@ impl ClientBuilder { pub fn add_crls( mut self, crls: impl IntoIterator, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.crls.extend(crls); self } @@ -1504,7 +1528,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_built_in_root_certs(mut self, tls_built_in_root_certs: bool) -> ClientBuilder { + pub fn tls_built_in_root_certs( + mut self, + tls_built_in_root_certs: bool, + ) -> ClientBuilder { self.config.tls_built_in_root_certs = tls_built_in_root_certs; #[cfg(feature = "rustls-tls-webpki-roots-no-provider")] @@ -1525,7 +1552,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-webpki-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-webpki-roots-no-provider")))] - pub fn tls_built_in_webpki_certs(mut self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_webpki_certs(mut self, enabled: bool) -> ClientBuilder { self.config.tls_built_in_certs_webpki = enabled; self } @@ -1535,7 +1562,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-native-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-native-roots-no-provider")))] - pub fn tls_built_in_native_certs(mut self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_native_certs(mut self, enabled: bool) -> ClientBuilder { self.config.tls_built_in_certs_native = enabled; self } @@ -1548,7 +1575,7 @@ impl ClientBuilder { /// enabled. #[cfg(any(feature = "native-tls", feature = "__rustls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn identity(mut self, identity: Identity) -> ClientBuilder { + pub fn identity(mut self, identity: Identity) -> ClientBuilder { self.config.identity = Some(identity); self } @@ -1580,7 +1607,7 @@ impl ClientBuilder { pub fn danger_accept_invalid_hostnames( mut self, accept_invalid_hostname: bool, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.hostname_verification = !accept_invalid_hostname; self } @@ -1610,7 +1637,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn danger_accept_invalid_certs(mut self, accept_invalid_certs: bool) -> ClientBuilder { + pub fn danger_accept_invalid_certs( + mut self, + accept_invalid_certs: bool, + ) -> ClientBuilder { self.config.certs_verification = !accept_invalid_certs; self } @@ -1632,7 +1662,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_sni(mut self, tls_sni: bool) -> ClientBuilder { + pub fn tls_sni(mut self, tls_sni: bool) -> ClientBuilder { self.config.tls_sni = tls_sni; self } @@ -1661,7 +1691,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn min_tls_version(mut self, version: tls::Version) -> ClientBuilder { + pub fn min_tls_version(mut self, version: tls::Version) -> ClientBuilder { self.config.min_tls_version = Some(version); self } @@ -1693,7 +1723,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn max_tls_version(mut self, version: tls::Version) -> ClientBuilder { + pub fn max_tls_version(mut self, version: tls::Version) -> ClientBuilder { self.config.max_tls_version = Some(version); self } @@ -1708,7 +1738,7 @@ impl ClientBuilder { /// This requires the optional `native-tls` feature to be enabled. #[cfg(feature = "native-tls")] #[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] - pub fn use_native_tls(mut self) -> ClientBuilder { + pub fn use_native_tls(mut self) -> ClientBuilder { self.config.tls = TlsBackend::Default; self } @@ -1723,7 +1753,7 @@ impl ClientBuilder { /// This requires the optional `rustls-tls(-...)` feature to be enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn use_rustls_tls(mut self) -> ClientBuilder { + pub fn use_rustls_tls(mut self) -> ClientBuilder { self.config.tls = TlsBackend::Rustls; self } @@ -1748,7 +1778,7 @@ impl ClientBuilder { /// `rustls-tls(-...)` to be enabled. #[cfg(any(feature = "native-tls", feature = "__rustls",))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn use_preconfigured_tls(mut self, tls: impl Any) -> ClientBuilder { + pub fn use_preconfigured_tls(mut self, tls: impl Any) -> ClientBuilder { let mut tls = Some(tls); #[cfg(feature = "native-tls")] { @@ -1791,7 +1821,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_info(mut self, tls_info: bool) -> ClientBuilder { + pub fn tls_info(mut self, tls_info: bool) -> ClientBuilder { self.config.tls_info = tls_info; self } @@ -1799,7 +1829,7 @@ impl ClientBuilder { /// Restrict the Client to be used with HTTPS only requests. /// /// Defaults to false. - pub fn https_only(mut self, enabled: bool) -> ClientBuilder { + pub fn https_only(mut self, enabled: bool) -> ClientBuilder { self.config.https_only = enabled; self } @@ -1808,7 +1838,7 @@ impl ClientBuilder { #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] #[deprecated(note = "use `hickory_dns` instead")] - pub fn trust_dns(mut self, enable: bool) -> ClientBuilder { + pub fn trust_dns(mut self, enable: bool) -> ClientBuilder { self.config.hickory_dns = enable; self } @@ -1828,14 +1858,14 @@ impl ClientBuilder { /// that the default resolver does #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] - pub fn hickory_dns(mut self, enable: bool) -> ClientBuilder { + pub fn hickory_dns(mut self, enable: bool) -> ClientBuilder { self.config.hickory_dns = enable; self } #[doc(hidden)] #[deprecated(note = "use `no_hickory_dns` instead")] - pub fn no_trust_dns(self) -> ClientBuilder { + pub fn no_trust_dns(self) -> ClientBuilder { self.no_hickory_dns() } @@ -1844,7 +1874,7 @@ impl ClientBuilder { /// This method exists even if the optional `hickory-dns` feature is not enabled. /// This can be used to ensure a `Client` doesn't use the hickory-dns async resolver /// even if another dependency were to enable the optional `hickory-dns` feature. - pub fn no_hickory_dns(self) -> ClientBuilder { + pub fn no_hickory_dns(self) -> ClientBuilder { #[cfg(feature = "hickory-dns")] { self.hickory_dns(false) @@ -1860,7 +1890,7 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { + pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { self.resolve_to_addrs(domain, &[addr]) } @@ -1868,7 +1898,11 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve_to_addrs(mut self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { + pub fn resolve_to_addrs( + mut self, + domain: &str, + addrs: &[SocketAddr], + ) -> ClientBuilder { self.config .dns_overrides .insert(domain.to_ascii_lowercase(), addrs.to_vec()); @@ -1880,7 +1914,10 @@ impl ClientBuilder { /// Pass an `Arc` wrapping a trait object implementing `Resolve`. /// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will /// still be applied on top of this resolver. - pub fn dns_resolver(mut self, resolver: Arc) -> ClientBuilder { + pub fn dns_resolver( + mut self, + resolver: Arc, + ) -> ClientBuilder { self.config.dns_resolver = Some(resolver as _); self } @@ -1891,7 +1928,7 @@ impl ClientBuilder { /// The default is false. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn tls_early_data(mut self, enabled: bool) -> ClientBuilder { + pub fn tls_early_data(mut self, enabled: bool) -> ClientBuilder { self.config.tls_enable_early_data = enabled; self } @@ -1903,7 +1940,7 @@ impl ClientBuilder { /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_max_idle_timeout(mut self, value: Duration) -> ClientBuilder { + pub fn http3_max_idle_timeout(mut self, value: Duration) -> ClientBuilder { self.config.quic_max_idle_timeout = Some(value); self } @@ -1920,7 +1957,7 @@ impl ClientBuilder { /// Panics if the value is over 2^62. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_stream_receive_window(mut self, value: u64) -> ClientBuilder { + pub fn http3_stream_receive_window(mut self, value: u64) -> ClientBuilder { self.config.quic_stream_receive_window = Some(value.try_into().unwrap()); self } @@ -1937,7 +1974,7 @@ impl ClientBuilder { /// Panics if the value is over 2^62. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_conn_receive_window(mut self, value: u64) -> ClientBuilder { + pub fn http3_conn_receive_window(mut self, value: u64) -> ClientBuilder { self.config.quic_receive_window = Some(value.try_into().unwrap()); self } @@ -1949,10 +1986,51 @@ impl ClientBuilder { /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_send_window(mut self, value: u64) -> ClientBuilder { + pub fn http3_send_window(mut self, value: u64) -> ClientBuilder { self.config.quic_send_window = Some(value); self } + + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// If configured, the `connect_timeout` will be the outermost layer. + /// + /// Simple example: + /// ``` + /// use std::time::Duration; + /// + /// + /// let client = reqwest::Client::builder() + /// // resolved to outermost layer, so before the semaphore permit is attempted + /// .connect_timeout(Duration::from_millis(100)) + /// // underneath the concurrency check, so only after a semaphore permit is acquired + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + /// + /// For a more complex example involving a custom layer, see `examples/connect_via_lower_priority_tokio_runtime.rs`. + /// Additionally see `` + pub fn connector_layer(self, layer: L) -> ClientBuilder> + where + S: Send + Sync + Clone + 'static, + L: Layer, Service = S> + Send + Sync + Clone + 'static, + { + let connector_layers = ConnectorLayerBuilder { + builder: self.connector_layers.builder.layer(layer), + has_custom_layers: true, + }; + + ClientBuilder::> { + config: self.config, + connector_layers, + } + } } type HyperClient = hyper_util::client::legacy::Client; @@ -1974,14 +2052,16 @@ impl Client { /// Use `Client::builder()` if you wish to handle the failure as an `Error` /// instead of panicking. pub fn new() -> Client { - ClientBuilder::new().build().expect("Client::new()") + ClientBuilder::::new() + .build() + .expect("Client::new()") } /// Creates a `ClientBuilder` to configure a `Client`. /// /// This is the same as `ClientBuilder::new()`. - pub fn builder() -> ClientBuilder { - ClientBuilder::new() + pub fn builder() -> ClientBuilder { + ClientBuilder::::new() } /// Convenience method to make a `GET` request to a URL. @@ -2237,7 +2317,7 @@ impl tower_service::Service for &'_ Client { } } -impl fmt::Debug for ClientBuilder { +impl fmt::Debug for ClientBuilder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct("ClientBuilder"); self.config.fmt_fields(&mut builder); diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 73f25208f..737ca7af2 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -10,13 +10,20 @@ use std::thread; use std::time::Duration; use http::header::HeaderValue; +use http::Uri; use log::{error, trace}; use tokio::sync::{mpsc, oneshot}; +use tower::layer::util::Stack; +use tower::Layer; +use tower::Service; use super::request::{Request, RequestBuilder}; use super::response::Response; use super::wait; +use crate::connect::Conn; +use crate::connect::ConnectorService; use crate::dns::Resolve; +use crate::error::BoxError; #[cfg(feature = "__tls")] use crate::tls; #[cfg(feature = "__rustls")] @@ -69,24 +76,36 @@ pub struct Client { /// # } /// ``` #[must_use] -pub struct ClientBuilder { - inner: async_impl::ClientBuilder, +pub struct ClientBuilder { + inner: async_impl::ClientBuilder, timeout: Timeout, } -impl Default for ClientBuilder { +impl Default for ClientBuilder { fn default() -> Self { Self::new() } } -impl ClientBuilder { +#[allow(private_bounds)] +impl ClientBuilder +where + CL1: Layer + Send + Sync + 'static, + CL2: Layer<>::Service> + Send + Sync + 'static, + >::Service>>::Service: + Service + Clone + Send + Sync + 'static, + <>::Service>>::Service as Service>::Future: + Send + 'static, +{ /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> ClientBuilder { ClientBuilder { - inner: async_impl::ClientBuilder::new(), + inner: async_impl::ClientBuilder::< + tower::layer::util::Identity, + tower::layer::util::Identity, + >::new(), timeout: Timeout::default(), } } @@ -128,7 +147,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn user_agent(self, value: V) -> ClientBuilder + pub fn user_agent(self, value: V) -> ClientBuilder where V: TryInto, V::Error: Into, @@ -160,7 +179,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn default_headers(self, headers: header::HeaderMap) -> ClientBuilder { + pub fn default_headers(self, headers: header::HeaderMap) -> ClientBuilder { self.with_inner(move |inner| inner.default_headers(headers)) } @@ -176,7 +195,7 @@ impl ClientBuilder { /// This requires the optional `cookies` feature to be enabled. #[cfg(feature = "cookies")] #[cfg_attr(docsrs, doc(cfg(feature = "cookies")))] - pub fn cookie_store(self, enable: bool) -> ClientBuilder { + pub fn cookie_store(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.cookie_store(enable)) } @@ -195,7 +214,7 @@ impl ClientBuilder { pub fn cookie_provider( self, cookie_store: Arc, - ) -> ClientBuilder { + ) -> ClientBuilder { self.with_inner(|inner| inner.cookie_provider(cookie_store)) } @@ -217,7 +236,7 @@ impl ClientBuilder { /// This requires the optional `gzip` feature to be enabled #[cfg(feature = "gzip")] #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] - pub fn gzip(self, enable: bool) -> ClientBuilder { + pub fn gzip(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.gzip(enable)) } @@ -239,7 +258,7 @@ impl ClientBuilder { /// This requires the optional `brotli` feature to be enabled #[cfg(feature = "brotli")] #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))] - pub fn brotli(self, enable: bool) -> ClientBuilder { + pub fn brotli(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.brotli(enable)) } @@ -261,7 +280,7 @@ impl ClientBuilder { /// This requires the optional `zstd` feature to be enabled #[cfg(feature = "zstd")] #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] - pub fn zstd(self, enable: bool) -> ClientBuilder { + pub fn zstd(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.zstd(enable)) } @@ -283,7 +302,7 @@ impl ClientBuilder { /// This requires the optional `deflate` feature to be enabled #[cfg(feature = "deflate")] #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))] - pub fn deflate(self, enable: bool) -> ClientBuilder { + pub fn deflate(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.deflate(enable)) } @@ -292,7 +311,7 @@ impl ClientBuilder { /// This method exists even if the optional `gzip` feature is not enabled. /// This can be used to ensure a `Client` doesn't use gzip decompression /// even if another dependency were to enable the optional `gzip` feature. - pub fn no_gzip(self) -> ClientBuilder { + pub fn no_gzip(self) -> ClientBuilder { self.with_inner(|inner| inner.no_gzip()) } @@ -301,7 +320,7 @@ impl ClientBuilder { /// This method exists even if the optional `brotli` feature is not enabled. /// This can be used to ensure a `Client` doesn't use brotli decompression /// even if another dependency were to enable the optional `brotli` feature. - pub fn no_brotli(self) -> ClientBuilder { + pub fn no_brotli(self) -> ClientBuilder { self.with_inner(|inner| inner.no_brotli()) } @@ -310,7 +329,7 @@ impl ClientBuilder { /// This method exists even if the optional `zstd` feature is not enabled. /// This can be used to ensure a `Client` doesn't use zstd decompression /// even if another dependency were to enable the optional `zstd` feature. - pub fn no_zstd(self) -> ClientBuilder { + pub fn no_zstd(self) -> ClientBuilder { self.with_inner(|inner| inner.no_zstd()) } @@ -319,7 +338,7 @@ impl ClientBuilder { /// This method exists even if the optional `deflate` feature is not enabled. /// This can be used to ensure a `Client` doesn't use deflate decompression /// even if another dependency were to enable the optional `deflate` feature. - pub fn no_deflate(self) -> ClientBuilder { + pub fn no_deflate(self) -> ClientBuilder { self.with_inner(|inner| inner.no_deflate()) } @@ -328,14 +347,14 @@ impl ClientBuilder { /// Set a `redirect::Policy` for this client. /// /// Default will follow redirects up to a maximum of 10. - pub fn redirect(self, policy: redirect::Policy) -> ClientBuilder { + pub fn redirect(self, policy: redirect::Policy) -> ClientBuilder { self.with_inner(move |inner| inner.redirect(policy)) } /// Enable or disable automatic setting of the `Referer` header. /// /// Default is `true`. - pub fn referer(self, enable: bool) -> ClientBuilder { + pub fn referer(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.referer(enable)) } @@ -346,7 +365,7 @@ impl ClientBuilder { /// # Note /// /// Adding a proxy will disable the automatic usage of the "system" proxy. - pub fn proxy(self, proxy: Proxy) -> ClientBuilder { + pub fn proxy(self, proxy: Proxy) -> ClientBuilder { self.with_inner(move |inner| inner.proxy(proxy)) } @@ -357,7 +376,7 @@ impl ClientBuilder { /// on all desired proxies instead. /// /// This also disables the automatic usage of the "system" proxy. - pub fn no_proxy(self) -> ClientBuilder { + pub fn no_proxy(self) -> ClientBuilder { self.with_inner(move |inner| inner.no_proxy()) } @@ -368,7 +387,7 @@ impl ClientBuilder { /// Default is 30 seconds. /// /// Pass `None` to disable timeout. - pub fn timeout(mut self, timeout: T) -> ClientBuilder + pub fn timeout(mut self, timeout: T) -> ClientBuilder where T: Into>, { @@ -379,7 +398,7 @@ impl ClientBuilder { /// Set a timeout for only the connect phase of a `Client`. /// /// Default is `None`. - pub fn connect_timeout(self, timeout: T) -> ClientBuilder + pub fn connect_timeout(self, timeout: T) -> ClientBuilder where T: Into>, { @@ -397,7 +416,7 @@ impl ClientBuilder { /// for read and write operations on connections. /// /// [log]: https://crates.io/crates/log - pub fn connection_verbose(self, verbose: bool) -> ClientBuilder { + pub fn connection_verbose(self, verbose: bool) -> ClientBuilder { self.with_inner(move |inner| inner.connection_verbose(verbose)) } @@ -408,7 +427,7 @@ impl ClientBuilder { /// Pass `None` to disable timeout. /// /// Default is 90 seconds. - pub fn pool_idle_timeout(self, val: D) -> ClientBuilder + pub fn pool_idle_timeout(self, val: D) -> ClientBuilder where D: Into>, { @@ -416,12 +435,12 @@ impl ClientBuilder { } /// Sets the maximum idle connection per host allowed in the pool. - pub fn pool_max_idle_per_host(self, max: usize) -> ClientBuilder { + pub fn pool_max_idle_per_host(self, max: usize) -> ClientBuilder { self.with_inner(move |inner| inner.pool_max_idle_per_host(max)) } /// Send headers as title case instead of lowercase. - pub fn http1_title_case_headers(self) -> ClientBuilder { + pub fn http1_title_case_headers(self) -> ClientBuilder { self.with_inner(|inner| inner.http1_title_case_headers()) } @@ -430,12 +449,15 @@ impl ClientBuilder { /// /// Newline codepoints (`\r` and `\n`) will be transformed to spaces when /// parsing. - pub fn http1_allow_obsolete_multiline_headers_in_responses(self, value: bool) -> ClientBuilder { + pub fn http1_allow_obsolete_multiline_headers_in_responses( + self, + value: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.http1_allow_obsolete_multiline_headers_in_responses(value)) } /// Sets whether invalid header lines should be silently ignored in HTTP/1 responses. - pub fn http1_ignore_invalid_headers_in_responses(self, value: bool) -> ClientBuilder { + pub fn http1_ignore_invalid_headers_in_responses(self, value: bool) -> ClientBuilder { self.with_inner(|inner| inner.http1_ignore_invalid_headers_in_responses(value)) } @@ -444,24 +466,27 @@ impl ClientBuilder { /// /// Newline codepoints (\r and \n) will be transformed to spaces when /// parsing. - pub fn http1_allow_spaces_after_header_name_in_responses(self, value: bool) -> ClientBuilder { + pub fn http1_allow_spaces_after_header_name_in_responses( + self, + value: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.http1_allow_spaces_after_header_name_in_responses(value)) } /// Only use HTTP/1. - pub fn http1_only(self) -> ClientBuilder { + pub fn http1_only(self) -> ClientBuilder { self.with_inner(|inner| inner.http1_only()) } /// Allow HTTP/0.9 responses - pub fn http09_responses(self) -> ClientBuilder { + pub fn http09_responses(self) -> ClientBuilder { self.with_inner(|inner| inner.http09_responses()) } /// Only use HTTP/2. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_prior_knowledge(self) -> ClientBuilder { + pub fn http2_prior_knowledge(self) -> ClientBuilder { self.with_inner(|inner| inner.http2_prior_knowledge()) } @@ -470,7 +495,10 @@ impl ClientBuilder { /// Default is currently 65,535 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_initial_stream_window_size(self, sz: impl Into>) -> ClientBuilder { + pub fn http2_initial_stream_window_size( + self, + sz: impl Into>, + ) -> ClientBuilder { self.with_inner(|inner| inner.http2_initial_stream_window_size(sz)) } @@ -479,7 +507,10 @@ impl ClientBuilder { /// Default is currently 65,535 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_initial_connection_window_size(self, sz: impl Into>) -> ClientBuilder { + pub fn http2_initial_connection_window_size( + self, + sz: impl Into>, + ) -> ClientBuilder { self.with_inner(|inner| inner.http2_initial_connection_window_size(sz)) } @@ -489,7 +520,7 @@ impl ClientBuilder { /// `http2_initial_connection_window_size`. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_adaptive_window(self, enabled: bool) -> ClientBuilder { + pub fn http2_adaptive_window(self, enabled: bool) -> ClientBuilder { self.with_inner(|inner| inner.http2_adaptive_window(enabled)) } @@ -498,7 +529,7 @@ impl ClientBuilder { /// Default is currently 16,384 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_frame_size(self, sz: impl Into>) -> ClientBuilder { + pub fn http2_max_frame_size(self, sz: impl Into>) -> ClientBuilder { self.with_inner(|inner| inner.http2_max_frame_size(sz)) } @@ -507,7 +538,7 @@ impl ClientBuilder { /// Default is currently 16KB, but can change. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_header_list_size(self, max_header_size_bytes: u32) -> ClientBuilder { + pub fn http2_max_header_list_size(self, max_header_size_bytes: u32) -> ClientBuilder { self.with_inner(|inner| inner.http2_max_header_list_size(max_header_size_bytes)) } @@ -515,7 +546,7 @@ impl ClientBuilder { /// enabled. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(feature = "http3")))] - pub fn http3_prior_knowledge(self) -> ClientBuilder { + pub fn http3_prior_knowledge(self) -> ClientBuilder { self.with_inner(|inner| inner.http3_prior_knowledge()) } @@ -524,7 +555,7 @@ impl ClientBuilder { /// Set whether sockets have `TCP_NODELAY` enabled. /// /// Default is `true`. - pub fn tcp_nodelay(self, enabled: bool) -> ClientBuilder { + pub fn tcp_nodelay(self, enabled: bool) -> ClientBuilder { self.with_inner(move |inner| inner.tcp_nodelay(enabled)) } @@ -539,7 +570,7 @@ impl ClientBuilder { /// .local_address(local_addr) /// .build().unwrap(); /// ``` - pub fn local_address(self, addr: T) -> ClientBuilder + pub fn local_address(self, addr: T) -> ClientBuilder where T: Into>, { @@ -557,14 +588,14 @@ impl ClientBuilder { /// .build().unwrap(); /// ``` #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - pub fn interface(self, interface: &str) -> ClientBuilder { + pub fn interface(self, interface: &str) -> ClientBuilder { self.with_inner(move |inner| inner.interface(interface)) } /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration. /// /// If `None`, the option will not be set. - pub fn tcp_keepalive(self, val: D) -> ClientBuilder + pub fn tcp_keepalive(self, val: D) -> ClientBuilder where D: Into>, { @@ -613,7 +644,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn add_root_certificate(self, cert: Certificate) -> ClientBuilder { + pub fn add_root_certificate(self, cert: Certificate) -> ClientBuilder { self.with_inner(move |inner| inner.add_root_certificate(cert)) } @@ -625,7 +656,7 @@ impl ClientBuilder { /// This requires the `rustls-tls(-...)` Cargo feature enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn add_crl(self, crl: CertificateRevocationList) -> ClientBuilder { + pub fn add_crl(self, crl: CertificateRevocationList) -> ClientBuilder { self.with_inner(move |inner| inner.add_crl(crl)) } @@ -640,7 +671,7 @@ impl ClientBuilder { pub fn add_crls( self, crls: impl IntoIterator, - ) -> ClientBuilder { + ) -> ClientBuilder { self.with_inner(move |inner| inner.add_crls(crls)) } @@ -661,7 +692,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_built_in_root_certs(self, tls_built_in_root_certs: bool) -> ClientBuilder { + pub fn tls_built_in_root_certs(self, tls_built_in_root_certs: bool) -> ClientBuilder { self.with_inner(move |inner| inner.tls_built_in_root_certs(tls_built_in_root_certs)) } @@ -670,7 +701,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-webpki-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-webpki-roots-no-provider")))] - pub fn tls_built_in_webpki_certs(self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_webpki_certs(self, enabled: bool) -> ClientBuilder { self.with_inner(move |inner| inner.tls_built_in_webpki_certs(enabled)) } @@ -679,7 +710,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-native-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-native-roots-no-provider")))] - pub fn tls_built_in_native_certs(self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_native_certs(self, enabled: bool) -> ClientBuilder { self.with_inner(move |inner| inner.tls_built_in_native_certs(enabled)) } @@ -691,7 +722,7 @@ impl ClientBuilder { /// enabled. #[cfg(any(feature = "native-tls", feature = "__rustls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn identity(self, identity: Identity) -> ClientBuilder { + pub fn identity(self, identity: Identity) -> ClientBuilder { self.with_inner(move |inner| inner.identity(identity)) } @@ -719,7 +750,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn danger_accept_invalid_hostnames(self, accept_invalid_hostname: bool) -> ClientBuilder { + pub fn danger_accept_invalid_hostnames( + self, + accept_invalid_hostname: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.danger_accept_invalid_hostnames(accept_invalid_hostname)) } @@ -743,7 +777,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn danger_accept_invalid_certs(self, accept_invalid_certs: bool) -> ClientBuilder { + pub fn danger_accept_invalid_certs( + self, + accept_invalid_certs: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.danger_accept_invalid_certs(accept_invalid_certs)) } @@ -759,7 +796,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_sni(self, tls_sni: bool) -> ClientBuilder { + pub fn tls_sni(self, tls_sni: bool) -> ClientBuilder { self.with_inner(|inner| inner.tls_sni(tls_sni)) } @@ -787,7 +824,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn min_tls_version(self, version: tls::Version) -> ClientBuilder { + pub fn min_tls_version(self, version: tls::Version) -> ClientBuilder { self.with_inner(|inner| inner.min_tls_version(version)) } @@ -815,7 +852,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn max_tls_version(self, version: tls::Version) -> ClientBuilder { + pub fn max_tls_version(self, version: tls::Version) -> ClientBuilder { self.with_inner(|inner| inner.max_tls_version(version)) } @@ -829,7 +866,7 @@ impl ClientBuilder { /// This requires the optional `native-tls` feature to be enabled. #[cfg(feature = "native-tls")] #[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] - pub fn use_native_tls(self) -> ClientBuilder { + pub fn use_native_tls(self) -> ClientBuilder { self.with_inner(move |inner| inner.use_native_tls()) } @@ -843,7 +880,7 @@ impl ClientBuilder { /// This requires the optional `rustls-tls(-...)` feature to be enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn use_rustls_tls(self) -> ClientBuilder { + pub fn use_rustls_tls(self) -> ClientBuilder { self.with_inner(move |inner| inner.use_rustls_tls()) } @@ -862,7 +899,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_info(self, tls_info: bool) -> ClientBuilder { + pub fn tls_info(self, tls_info: bool) -> ClientBuilder { self.with_inner(|inner| inner.tls_info(tls_info)) } @@ -886,7 +923,7 @@ impl ClientBuilder { /// `rustls-tls(-...)` to be enabled. #[cfg(any(feature = "native-tls", feature = "__rustls",))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn use_preconfigured_tls(self, tls: impl Any) -> ClientBuilder { + pub fn use_preconfigured_tls(self, tls: impl Any) -> ClientBuilder { self.with_inner(move |inner| inner.use_preconfigured_tls(tls)) } @@ -900,7 +937,7 @@ impl ClientBuilder { #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] #[deprecated(note = "use `hickory_dns` instead", since = "0.12.0")] - pub fn trust_dns(self, enable: bool) -> ClientBuilder { + pub fn trust_dns(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.hickory_dns(enable)) } @@ -913,7 +950,7 @@ impl ClientBuilder { /// This requires the optional `hickory-dns` feature to be enabled #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] - pub fn hickory_dns(self, enable: bool) -> ClientBuilder { + pub fn hickory_dns(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.hickory_dns(enable)) } @@ -923,7 +960,7 @@ impl ClientBuilder { /// This can be used to ensure a `Client` doesn't use the hickory-dns async resolver /// even if another dependency were to enable the optional `hickory-dns` feature. #[deprecated(note = "use `no_hickory_dns` instead", since = "0.12.0")] - pub fn no_trust_dns(self) -> ClientBuilder { + pub fn no_trust_dns(self) -> ClientBuilder { self.with_inner(|inner| inner.no_hickory_dns()) } @@ -932,14 +969,14 @@ impl ClientBuilder { /// This method exists even if the optional `hickory-dns` feature is not enabled. /// This can be used to ensure a `Client` doesn't use the hickory-dns async resolver /// even if another dependency were to enable the optional `hickory-dns` feature. - pub fn no_hickory_dns(self) -> ClientBuilder { + pub fn no_hickory_dns(self) -> ClientBuilder { self.with_inner(|inner| inner.no_hickory_dns()) } /// Restrict the Client to be used with HTTPS only requests. /// /// Defaults to false. - pub fn https_only(self, enabled: bool) -> ClientBuilder { + pub fn https_only(self, enabled: bool) -> ClientBuilder { self.with_inner(|inner| inner.https_only(enabled)) } @@ -947,7 +984,7 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { + pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { self.resolve_to_addrs(domain, &[addr]) } @@ -955,7 +992,7 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve_to_addrs(self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { + pub fn resolve_to_addrs(self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { self.with_inner(|inner| inner.resolve_to_addrs(domain, addrs)) } @@ -964,23 +1001,56 @@ impl ClientBuilder { /// Pass an `Arc` wrapping a trait object implementing `Resolve`. /// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will /// still be applied on top of this resolver. - pub fn dns_resolver(self, resolver: Arc) -> ClientBuilder { + pub fn dns_resolver(self, resolver: Arc) -> ClientBuilder { self.with_inner(|inner| inner.dns_resolver(resolver)) } + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// Simple example: + /// ``` + /// use std::time::Duration; + /// + /// let client = reqwest::blocking::Client::builder() + /// // resolved to outermost layer, so before the semaphore permit is attempted + /// .connect_timeout(Duration::from_millis(100)) + /// // underneath the concurrency check, so only after a semaphore permit is acquired + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + pub fn connector_layer(self, layer: L) -> ClientBuilder> + where + S: Send + Sync + Clone + 'static, + L: Layer, Service = S> + Send + Sync + Clone + 'static, + { + // skipping using `with_inner` here because we need to cast the generic type + let inner = self.inner.connector_layer(layer); + + ClientBuilder::> { + inner, + timeout: self.timeout, + } + } + // private - fn with_inner(mut self, func: F) -> ClientBuilder + fn with_inner(mut self, func: F) -> ClientBuilder where - F: FnOnce(async_impl::ClientBuilder) -> async_impl::ClientBuilder, + F: FnOnce(async_impl::ClientBuilder) -> async_impl::ClientBuilder, { self.inner = func(self.inner); self } } -impl From for ClientBuilder { - fn from(builder: async_impl::ClientBuilder) -> Self { +impl From> for ClientBuilder { + fn from(builder: async_impl::ClientBuilder) -> Self { Self { inner: builder, timeout: Timeout::default(), @@ -1008,14 +1078,16 @@ impl Client { /// This method also panics if called from within an async runtime. See docs /// on [`reqwest::blocking`][crate::blocking] for details. pub fn new() -> Client { - ClientBuilder::new().build().expect("Client::new()") + ClientBuilder::::new() + .build() + .expect("Client::new()") } /// Creates a `ClientBuilder` to configure a `Client`. /// /// This is the same as `ClientBuilder::new()`. - pub fn builder() -> ClientBuilder { - ClientBuilder::new() + pub fn builder() -> ClientBuilder { + ClientBuilder::::new() } /// Convenience method to make a `GET` request to a URL. @@ -1112,7 +1184,7 @@ impl fmt::Debug for Client { } } -impl fmt::Debug for ClientBuilder { +impl fmt::Debug for ClientBuilder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.inner.fmt(f) } @@ -1149,7 +1221,11 @@ impl Drop for InnerClientHandle { } impl ClientHandle { - fn new(builder: ClientBuilder) -> crate::Result { + fn new (builder: ClientBuilder ) -> crate::Result + where CL1: Layer + Send + Sync + 'static, + CL2: Layer<>::Service> + Send + Sync + 'static, + >::Service>>::Service: Service + Clone + Send + Sync + 'static, + <>::Service>>::Service as Service>::Future: Send + 'static{ let timeout = builder.timeout; let builder = builder.inner; let (tx, rx) = mpsc::unbounded_channel::<(async_impl::Request, OneshotResponse)>(); diff --git a/src/connect.rs b/src/connect.rs index ff86ba3c9..282ac7e91 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -8,6 +8,9 @@ use hyper_util::client::legacy::connect::{Connected, Connection}; use hyper_util::rt::TokioIo; #[cfg(feature = "default-tls")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; +use tower::{ + layer::util::Stack, timeout::TimeoutLayer, util::BoxCloneSyncService, Layer, ServiceBuilder, +}; use tower_service::Service; use pin_project_lite::pin_project; @@ -24,13 +27,50 @@ use self::native_tls_conn::NativeTlsConn; #[cfg(feature = "__rustls")] use self::rustls_tls_conn::RustlsTlsConn; use crate::dns::DynResolver; -use crate::error::BoxError; +use crate::error::{cast_to_internal_error, BoxError}; use crate::proxy::{Proxy, ProxyScheme}; pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector; #[derive(Clone)] -pub(crate) struct Connector { +pub(crate) enum Connector { + // base service, with or without an embedded timeout + Simple(ConnectorService), + // at least one custom layer along with maybe an outer timeout layer + // from `builder.connect_timeout()` + WithLayers(BoxCloneSyncService), +} + +impl Service for Connector { + type Response = Conn; + type Error = BoxError; + type Future = Connecting; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + Connector::Simple(service) => service.poll_ready(cx), + Connector::WithLayers(service) => service.poll_ready(cx), + } + } + + fn call(&mut self, dst: Uri) -> Self::Future { + match self { + Connector::Simple(service) => service.call(dst), + Connector::WithLayers(service) => service.call(dst), + } + } +} + +pub(crate) struct ConnectorLayerBuilder { + pub(crate) builder: ServiceBuilder>, + // It's not trivial to identify whether the builder stack is `Stack` or not + // so we simply add a boolean flag to track. + // + // Knowing allows us reduce indirection in certain cases. + pub(crate) has_custom_layers: bool, +} + +pub(crate) struct ConnectorBuilder { inner: Inner, proxies: Arc>, verbose: verbose::Wrapper, @@ -43,21 +83,65 @@ pub(crate) struct Connector { user_agent: Option, } -#[derive(Clone)] -enum Inner { - #[cfg(not(feature = "__tls"))] - Http(HttpConnector), - #[cfg(feature = "default-tls")] - DefaultTls(HttpConnector, TlsConnector), - #[cfg(feature = "__rustls")] - RustlsTls { - http: HttpConnector, - tls: Arc, - tls_proxy: Arc, - }, -} +impl ConnectorBuilder { + pub(crate) fn build( + self, + layer: ConnectorLayerBuilder, + ) -> Connector + where + CL1: Layer, + CL2: Layer<>::Service>, + >::Service>>::Service: Service + Clone + Send + Sync + 'static, + <>::Service>>::Service as Service>::Future: Send + 'static + { + // construct the inner tower service + let mut base_service = ConnectorService { + inner: self.inner, + proxies: self.proxies, + verbose: self.verbose, + #[cfg(feature = "__tls")] + nodelay: self.nodelay, + #[cfg(feature = "__tls")] + tls_info: self.tls_info, + #[cfg(feature = "__tls")] + user_agent: self.user_agent, + simple_timeout: None, + }; + + // no user-provider layers so we can throw away our generic input layer stack + // and compose with named layers only + if !layer.has_custom_layers { + // if we know we have no other layers, we can embed the timeout directly inside + // our base service call which saves us a Box::pin + base_service.simple_timeout = self.timeout; + return Connector::Simple(base_service); + } + + // we have user-provided generic layer stack + let service = layer.builder.service(base_service); + + if let Some(timeout) = self.timeout { + // add in named timeout layer on the outside of the stack + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(timeout)) + .service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + return Connector::WithLayers(service); + } + + // no named timeout layer but we still map errors since + // we might have user-provided timeout layer + let service = ServiceBuilder::new().service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + return Connector::WithLayers(service); + } -impl Connector { #[cfg(not(feature = "__tls"))] pub(crate) fn new( mut http: HttpConnector, @@ -66,7 +150,7 @@ impl Connector { #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] interface: Option<&str>, nodelay: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -77,10 +161,10 @@ impl Connector { } http.set_nodelay(nodelay); - Connector { + ConnectorBuilder { inner: Inner::Http(http), - verbose: verbose::OFF, proxies, + verbose: verbose::OFF, timeout: None, } } @@ -96,7 +180,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> crate::Result + ) -> crate::Result where T: Into>, { @@ -125,7 +209,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -137,14 +221,14 @@ impl Connector { http.set_nodelay(nodelay); http.enforce_http(false); - Connector { + ConnectorBuilder { inner: Inner::DefaultTls(http, tls), proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -159,7 +243,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -180,7 +264,7 @@ impl Connector { (Arc::new(tls), Arc::new(tls_proxy)) }; - Connector { + ConnectorBuilder { inner: Inner::RustlsTls { http, tls, @@ -188,10 +272,10 @@ impl Connector { }, proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -203,6 +287,51 @@ impl Connector { self.verbose.0 = enabled; } + pub(crate) fn set_keepalive(&mut self, dur: Option) { + match &mut self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), + #[cfg(feature = "__rustls")] + Inner::RustlsTls { http, .. } => http.set_keepalive(dur), + #[cfg(not(feature = "__tls"))] + Inner::Http(http) => http.set_keepalive(dur), + } + } +} + +#[derive(Clone)] +pub(crate) struct ConnectorService { + inner: Inner, + proxies: Arc>, + verbose: verbose::Wrapper, + /// When there is a single timeout layer and no other layers, + /// we embed it directly inside our base Service::call(). + /// This lets us avoid an extra `Box::pin` indirection layer + /// since `tokio::time::Timeout` is `Unpin` + simple_timeout: Option, + #[cfg(feature = "__tls")] + nodelay: bool, + #[cfg(feature = "__tls")] + tls_info: bool, + #[cfg(feature = "__tls")] + user_agent: Option, +} + +#[derive(Clone)] +enum Inner { + #[cfg(not(feature = "__tls"))] + Http(HttpConnector), + #[cfg(feature = "default-tls")] + DefaultTls(HttpConnector, TlsConnector), + #[cfg(feature = "__rustls")] + RustlsTls { + http: HttpConnector, + tls: Arc, + tls_proxy: Arc, + }, +} + +impl ConnectorService { #[cfg(feature = "socks")] async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result { let dns = match proxy { @@ -449,17 +578,6 @@ impl Connector { self.connect_with_maybe_proxy(proxy_dst, true).await } - - pub fn set_keepalive(&mut self, dur: Option) { - match &mut self.inner { - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), - #[cfg(feature = "__rustls")] - Inner::RustlsTls { http, .. } => http.set_keepalive(dur), - #[cfg(not(feature = "__tls"))] - Inner::Http(http) => http.set_keepalive(dur), - } - } } fn into_uri(scheme: Scheme, host: Authority) -> Uri { @@ -487,7 +605,7 @@ where } } -impl Service for Connector { +impl Service for ConnectorService { type Response = Conn; type Error = BoxError; type Future = Connecting; @@ -498,7 +616,7 @@ impl Service for Connector { fn call(&mut self, dst: Uri) -> Self::Future { log::debug!("starting new connection: {dst:?}"); - let timeout = self.timeout; + let timeout = self.simple_timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { return Box::pin(with_timeout( diff --git a/src/error.rs b/src/error.rs index ca7413fd6..6a9f07e51 100644 --- a/src/error.rs +++ b/src/error.rs @@ -165,6 +165,18 @@ impl Error { } } +/// Converts from external types to reqwest's +/// internal equivalents. +/// +/// Currently only is used for `tower::timeout::error::Elapsed`. +pub(crate) fn cast_to_internal_error(error: BoxError) -> BoxError { + if error.is::() { + Box::new(crate::error::TimedOut) as BoxError + } else { + error + } +} + impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct("reqwest::Error"); diff --git a/tests/connector_layers.rs b/tests/connector_layers.rs new file mode 100644 index 000000000..e8e1a4503 --- /dev/null +++ b/tests/connector_layers.rs @@ -0,0 +1,344 @@ +#![cfg(not(target_arch = "wasm32"))] +#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))] +mod support; + +use std::time::Duration; + +use futures_util::future::join_all; +use tower::layer::util::Identity; +use tower::limit::ConcurrencyLimitLayer; +use tower::timeout::TimeoutLayer; + +use support::{delay_layer::DelayLayer, server}; + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer_with_timeout() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_never_returning() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_slow() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_under_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(300))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(500))) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_over_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connect_timeout(Duration::from_millis(50)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + let timed_out = all_res + .into_iter() + .any(|res| res.is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(110))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(500)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + for res in all_res.into_iter() { + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(feature = "blocking")] +#[test] +fn non_op_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(Identity::new()) + .build() + .unwrap(); + + let res = client.get(url).send(); + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn timeout_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send(); + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + let timed_out = join_handles + .into_iter() + .any(|handle| handle.join().unwrap().is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(110))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(500)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + for handle in join_handles { + let res = handle.join().unwrap(); + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} diff --git a/tests/support/delay_layer.rs b/tests/support/delay_layer.rs new file mode 100644 index 000000000..9dbe2d663 --- /dev/null +++ b/tests/support/delay_layer.rs @@ -0,0 +1,120 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use pin_project_lite::pin_project; +use tokio::time::Sleep; +use tower::{BoxError, Layer, Service}; + +/// This tower layer injects an arbitrary delay before calling downstream layers. +#[derive(Clone)] +pub struct DelayLayer { + delay: Duration, +} + +impl DelayLayer { + pub const fn new(delay: Duration) -> Self { + DelayLayer { delay } + } +} + +impl Layer for DelayLayer { + type Service = Delay; + fn layer(&self, service: S) -> Self::Service { + Delay::new(service, self.delay) + } +} + +impl std::fmt::Debug for DelayLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("DelayLayer") + .field("delay", &self.delay) + .finish() + } +} + +/// This tower service injects an arbitrary delay before calling downstream layers. +#[derive(Debug, Clone)] +pub struct Delay { + inner: S, + delay: Duration, +} +impl Delay { + pub fn new(inner: S, delay: Duration) -> Self { + Delay { inner, delay } + } +} + +impl Service for Delay +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + + type Error = BoxError; + + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + let sleep = tokio::time::sleep(self.delay); + + ResponseFuture::new(response, sleep) + } +} + +// `Delay` response future +pin_project! { + #[derive(Debug)] + pub struct ResponseFuture { + #[pin] + response: S, + #[pin] + sleep: Sleep, + } +} + +impl ResponseFuture { + pub(crate) fn new(response: S, sleep: Sleep) -> Self { + ResponseFuture { response, sleep } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + println!("Poll"); + let this = self.project(); + + // First poll the sleep until complete + match this.sleep.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(_) => {} + } + + // Then poll the inner future + match this.response.poll(cx) { + Poll::Ready(v) => Poll::Ready(v.map_err(Into::into)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index c796956d8..9d4ce7b9b 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,3 +1,4 @@ +pub mod delay_layer; pub mod delay_server; pub mod server; diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 79a6fbb4d..71dc0ce66 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -337,6 +337,24 @@ fn timeout_blocking_request() { assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } +#[cfg(feature = "blocking")] +#[test] +fn connect_timeout_blocking_request() { + let _ = env_logger::try_init(); + + let client = reqwest::blocking::Client::builder() + .connect_timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let err = client.get(url).send().unwrap_err(); + + assert!(err.is_timeout()); +} + #[cfg(feature = "blocking")] #[cfg(feature = "stream")] #[test]