diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 938ea05698..c9435318c5 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -67,6 +67,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.71" tokio = { version = "1.14", features = ["full"] } tower = { version = "0.4", features = ["util"] } +tower-http = { version = "0.3", features = ["map-response-body", "timeout"] } [package.metadata.docs.rs] all-features = true diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index 2e3af8b8ea..73c3412b2b 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -93,12 +93,16 @@ //! [`BytesRejection`]: axum::extract::rejection::BytesRejection //! [`IntoResponse::into_response`]: https://docs.rs/axum/0.5/axum/response/index.html#returning-different-response-types +use std::task::{Context, Poll}; + use axum::{ async_trait, extract::FromRequestParts, response::{IntoResponse, Response}, }; use http::request::Parts; +use tower_layer::Layer; +use tower_service::Service; /// Combines two extractors or responses into a single type. /// @@ -267,3 +271,42 @@ impl_traits_for_either!(Either5 => [E1, E2, E3, E4], E5); impl_traits_for_either!(Either6 => [E1, E2, E3, E4, E5], E6); impl_traits_for_either!(Either7 => [E1, E2, E3, E4, E5, E6], E7); impl_traits_for_either!(Either8 => [E1, E2, E3, E4, E5, E6, E7], E8); + +impl Layer for Either +where + E1: Layer, + E2: Layer, +{ + type Service = Either; + + fn layer(&self, inner: S) -> Self::Service { + match self { + Either::E1(layer) => Either::E1(layer.layer(inner)), + Either::E2(layer) => Either::E2(layer.layer(inner)), + } + } +} + +impl Service for Either +where + E1: Service, + E2: Service, +{ + type Response = E1::Response; + type Error = E1::Error; + type Future = futures_util::future::Either; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + Either::E1(inner) => inner.poll_ready(cx), + Either::E2(inner) => inner.poll_ready(cx), + } + } + + fn call(&mut self, req: R) -> Self::Future { + match self { + Either::E1(inner) => futures_util::future::Either::Left(inner.call(req)), + Either::E2(inner) => futures_util::future::Either::Right(inner.call(req)), + } + } +} diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index 2bbb07209e..8e547b8660 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -69,6 +69,7 @@ pub mod body; pub mod either; pub mod extract; pub mod handler; +pub mod middleware; pub mod response; pub mod routing; diff --git a/axum-extra/src/middleware.rs b/axum-extra/src/middleware.rs new file mode 100644 index 0000000000..51bb42338c --- /dev/null +++ b/axum-extra/src/middleware.rs @@ -0,0 +1,44 @@ +//! Additional middleware utilities. + +use crate::either::Either; +use tower_layer::Identity; + +/// Convert an `Option` into a [`Layer`]. +/// +/// If the layer is a `Some` it'll be applied, otherwise not. +/// +/// # Example +/// +/// ``` +/// use axum_extra::either::option_layer; +/// use axum::{Router, routing::get}; +/// use std::time::Duration; +/// use tower_http::timeout::TimeoutLayer; +/// +/// # let option_timeout = Some(Duration::new(10, 0)); +/// let timeout_layer = option_timeout.map(TimeoutLayer::new); +/// +/// let app = Router::new() +/// .route("/", get(|| async {})) +/// .layer(option_layer(timeout_layer)); +/// # let _: Router = app; +/// ``` +/// +/// # Difference between this and [`tower::util::option_layer`] +/// +/// [`tower::util::option_layer`] always changes the error type to [`BoxError`] which requires +/// using [`HandleErrorLayer`] when used with axum, even if the layer you're applying uses +/// [`Infallible`]. +/// +/// `axum_extra::middleware::option_layer` on the other hand doesn't change the error type so can +/// be applied directly. +/// +/// [`Layer`]: tower_layer::Layer +/// [`BoxError`]: tower::BoxError +/// [`HandleErrorLayer`]: axum::error_handling::HandleErrorLayer +/// [`Infallible`]: std::convert::Infallible +pub fn option_layer(layer: Option) -> Either { + layer + .map(Either::E1) + .unwrap_or_else(|| Either::E2(Identity::new())) +}