-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support chaining handlers with
one.or(two).or(three)
(#1170)
* WIP: Handler fallbacks * Docs * box futures a bit less * changelog
- Loading branch information
1 parent
7cbb7cf
commit e79820d
Showing
4 changed files
with
392 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
//! Additional handler utilities. | ||
use axum::{ | ||
extract::{FromRequest, RequestParts}, | ||
handler::Handler, | ||
response::{IntoResponse, Response}, | ||
}; | ||
use futures_util::future::{BoxFuture, FutureExt, Map}; | ||
use std::{future::Future, marker::PhantomData}; | ||
|
||
mod or; | ||
|
||
pub use self::or::Or; | ||
|
||
/// Trait for async functions that can be used to handle requests. | ||
/// | ||
/// This trait is similar to [`Handler`] but rather than taking the request it takes the extracted | ||
/// inputs. | ||
/// | ||
/// The drawbacks of this trait is that you cannot apply middleware to individual handlers like you | ||
/// can with [`Handler::layer`]. | ||
pub trait HandlerCallWithExtractors<T, B>: Sized { | ||
/// The type of future calling this handler returns. | ||
type Future: Future<Output = Response> + Send + 'static; | ||
|
||
/// Call the handler with the extracted inputs. | ||
fn call(self, extractors: T) -> <Self as HandlerCallWithExtractors<T, B>>::Future; | ||
|
||
/// Conver this `HandlerCallWithExtractors` into [`Handler`]. | ||
fn into_handler(self) -> IntoHandler<Self, T, B> { | ||
IntoHandler { | ||
handler: self, | ||
_marker: PhantomData, | ||
} | ||
} | ||
|
||
/// Chain two handlers together, running the second one if the first one rejects. | ||
/// | ||
/// Note that this only moves to the next handler if an extractor fails. The response from | ||
/// handlers are not considered. | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use axum_extra::handler::HandlerCallWithExtractors; | ||
/// use axum::{ | ||
/// Router, | ||
/// async_trait, | ||
/// routing::get, | ||
/// extract::FromRequest, | ||
/// }; | ||
/// | ||
/// // handlers for varying levels of access | ||
/// async fn admin(admin: AdminPermissions) { | ||
/// // request came from an admin | ||
/// } | ||
/// | ||
/// async fn user(user: User) { | ||
/// // we have a `User` | ||
/// } | ||
/// | ||
/// async fn guest() { | ||
/// // `AdminPermissions` and `User` failed, so we're just a guest | ||
/// } | ||
/// | ||
/// // extractors for checking permissions | ||
/// struct AdminPermissions {} | ||
/// | ||
/// #[async_trait] | ||
/// impl<B: Send> FromRequest<B> for AdminPermissions { | ||
/// // check for admin permissions... | ||
/// # type Rejection = (); | ||
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> { | ||
/// # todo!() | ||
/// # } | ||
/// } | ||
/// | ||
/// struct User {} | ||
/// | ||
/// #[async_trait] | ||
/// impl<B: Send> FromRequest<B> for User { | ||
/// // check for a logged in user... | ||
/// # type Rejection = (); | ||
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> { | ||
/// # todo!() | ||
/// # } | ||
/// } | ||
/// | ||
/// let app = Router::new().route( | ||
/// "/users/:id", | ||
/// get( | ||
/// // first try `admin`, if that rejects run `user`, finally falling back | ||
/// // to `guest` | ||
/// admin.or(user).or(guest) | ||
/// ) | ||
/// ); | ||
/// # let _: Router = app; | ||
/// ``` | ||
fn or<R, Rt>(self, rhs: R) -> Or<Self, R, T, Rt, B> | ||
where | ||
R: HandlerCallWithExtractors<Rt, B>, | ||
{ | ||
Or { | ||
lhs: self, | ||
rhs, | ||
_marker: PhantomData, | ||
} | ||
} | ||
} | ||
|
||
macro_rules! impl_handler_call_with { | ||
( $($ty:ident),* $(,)? ) => { | ||
#[allow(non_snake_case)] | ||
impl<F, Fut, B, $($ty,)*> HandlerCallWithExtractors<($($ty,)*), B> for F | ||
where | ||
F: FnOnce($($ty,)*) -> Fut, | ||
Fut: Future + Send + 'static, | ||
Fut::Output: IntoResponse, | ||
{ | ||
// this puts `futures_util` in our public API but thats fine in axum-extra | ||
type Future = Map<Fut, fn(Fut::Output) -> Response>; | ||
|
||
fn call( | ||
self, | ||
($($ty,)*): ($($ty,)*), | ||
) -> <Self as HandlerCallWithExtractors<($($ty,)*), B>>::Future { | ||
self($($ty,)*).map(IntoResponse::into_response) | ||
} | ||
} | ||
}; | ||
} | ||
|
||
impl_handler_call_with!(); | ||
impl_handler_call_with!(T1); | ||
impl_handler_call_with!(T1, T2); | ||
impl_handler_call_with!(T1, T2, T3); | ||
impl_handler_call_with!(T1, T2, T3, T4); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); | ||
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); | ||
|
||
/// A [`Handler`] created from a [`HandlerCallWithExtractors`]. | ||
/// | ||
/// Created with [`HandlerCallWithExtractors::into_handler`]. | ||
#[allow(missing_debug_implementations)] | ||
pub struct IntoHandler<H, T, B> { | ||
handler: H, | ||
_marker: PhantomData<fn() -> (T, B)>, | ||
} | ||
|
||
impl<H, T, B> Handler<T, B> for IntoHandler<H, T, B> | ||
where | ||
H: HandlerCallWithExtractors<T, B> + Clone + Send + 'static, | ||
T: FromRequest<B> + Send + 'static, | ||
T::Rejection: Send, | ||
B: Send + 'static, | ||
{ | ||
type Future = BoxFuture<'static, Response>; | ||
|
||
fn call(self, req: http::Request<B>) -> Self::Future { | ||
Box::pin(async move { | ||
let mut req = RequestParts::new(req); | ||
match req.extract::<T>().await { | ||
Ok(t) => self.handler.call(t).await, | ||
Err(rejection) => rejection.into_response(), | ||
} | ||
}) | ||
} | ||
} | ||
|
||
impl<H, T, B> Copy for IntoHandler<H, T, B> where H: Copy {} | ||
|
||
impl<H, T, B> Clone for IntoHandler<H, T, B> | ||
where | ||
H: Clone, | ||
{ | ||
fn clone(&self) -> Self { | ||
Self { | ||
handler: self.handler.clone(), | ||
_marker: self._marker, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
use super::HandlerCallWithExtractors; | ||
use crate::Either; | ||
use axum::{ | ||
extract::{FromRequest, RequestParts}, | ||
handler::Handler, | ||
http::Request, | ||
response::{IntoResponse, Response}, | ||
}; | ||
use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map}; | ||
use http::StatusCode; | ||
use std::{future::Future, marker::PhantomData}; | ||
|
||
/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another | ||
/// [`Handler`]. | ||
/// | ||
/// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or). | ||
#[allow(missing_debug_implementations)] | ||
pub struct Or<L, R, Lt, Rt, B> { | ||
pub(super) lhs: L, | ||
pub(super) rhs: R, | ||
pub(super) _marker: PhantomData<fn() -> (Lt, Rt, B)>, | ||
} | ||
|
||
impl<B, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, B> for Or<L, R, Lt, Rt, B> | ||
where | ||
L: HandlerCallWithExtractors<Lt, B> + Send + 'static, | ||
R: HandlerCallWithExtractors<Rt, B> + Send + 'static, | ||
Rt: Send + 'static, | ||
Lt: Send + 'static, | ||
B: Send + 'static, | ||
{ | ||
// this puts `futures_util` in our public API but thats fine in axum-extra | ||
type Future = EitherFuture< | ||
Map<L::Future, fn(<L::Future as Future>::Output) -> Response>, | ||
Map<R::Future, fn(<R::Future as Future>::Output) -> Response>, | ||
>; | ||
|
||
fn call( | ||
self, | ||
extractors: Either<Lt, Rt>, | ||
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, B>>::Future { | ||
match extractors { | ||
Either::Left(lt) => self | ||
.lhs | ||
.call(lt) | ||
.map(IntoResponse::into_response as _) | ||
.left_future(), | ||
Either::Right(rt) => self | ||
.rhs | ||
.call(rt) | ||
.map(IntoResponse::into_response as _) | ||
.right_future(), | ||
} | ||
} | ||
} | ||
|
||
impl<B, L, R, Lt, Rt> Handler<(Lt, Rt), B> for Or<L, R, Lt, Rt, B> | ||
where | ||
L: HandlerCallWithExtractors<Lt, B> + Clone + Send + 'static, | ||
R: HandlerCallWithExtractors<Rt, B> + Clone + Send + 'static, | ||
Lt: FromRequest<B> + Send + 'static, | ||
Rt: FromRequest<B> + Send + 'static, | ||
Lt::Rejection: Send, | ||
Rt::Rejection: Send, | ||
B: Send + 'static, | ||
{ | ||
// this puts `futures_util` in our public API but thats fine in axum-extra | ||
type Future = BoxFuture<'static, Response>; | ||
|
||
fn call(self, req: Request<B>) -> Self::Future { | ||
Box::pin(async move { | ||
let mut req = RequestParts::new(req); | ||
|
||
if let Ok(lt) = req.extract::<Lt>().await { | ||
return self.lhs.call(lt).await; | ||
} | ||
|
||
if let Ok(rt) = req.extract::<Rt>().await { | ||
return self.rhs.call(rt).await; | ||
} | ||
|
||
StatusCode::NOT_FOUND.into_response() | ||
}) | ||
} | ||
} | ||
|
||
impl<L, R, Lt, Rt, B> Copy for Or<L, R, Lt, Rt, B> | ||
where | ||
L: Copy, | ||
R: Copy, | ||
{ | ||
} | ||
|
||
impl<L, R, Lt, Rt, B> Clone for Or<L, R, Lt, Rt, B> | ||
where | ||
L: Clone, | ||
R: Clone, | ||
{ | ||
fn clone(&self) -> Self { | ||
Self { | ||
lhs: self.lhs.clone(), | ||
rhs: self.rhs.clone(), | ||
_marker: self._marker, | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::test_helpers::*; | ||
use axum::{ | ||
extract::{Path, Query}, | ||
routing::get, | ||
Router, | ||
}; | ||
use serde::Deserialize; | ||
|
||
#[tokio::test] | ||
async fn works() { | ||
#[derive(Deserialize)] | ||
struct Params { | ||
a: String, | ||
} | ||
|
||
async fn one(Path(id): Path<u32>) -> String { | ||
id.to_string() | ||
} | ||
|
||
async fn two(Query(params): Query<Params>) -> String { | ||
params.a | ||
} | ||
|
||
async fn three() -> &'static str { | ||
"fallback" | ||
} | ||
|
||
let app = Router::new().route("/:id", get(one.or(two).or(three))); | ||
|
||
let client = TestClient::new(app); | ||
|
||
let res = client.get("/123").send().await; | ||
assert_eq!(res.text().await, "123"); | ||
|
||
let res = client.get("/foo?a=bar").send().await; | ||
assert_eq!(res.text().await, "bar"); | ||
|
||
let res = client.get("/foo").send().await; | ||
assert_eq!(res.text().await, "fallback"); | ||
} | ||
} |
Oops, something went wrong.