Skip to content

Commit

Permalink
Support chaining handlers with one.or(two).or(three) (#1170)
Browse files Browse the repository at this point in the history
* WIP: Handler fallbacks

* Docs

* box futures a bit less

* changelog
  • Loading branch information
davidpdrsn authored Aug 10, 2022
1 parent 7cbb7cf commit e79820d
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 0 deletions.
2 changes: 2 additions & 0 deletions axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ and this project adheres to [Semantic Versioning].
- **changed:** For methods that accept some `S: Service`, the bounds have been
relaxed so the response type must implement `IntoResponse` rather than being a
literal `Response`
- **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170])
- **change:** axum-extra's MSRV is now 1.60 ([#1239])

[#1119]: https://github.com/tokio-rs/axum/pull/1119
[#1170]: https://github.com/tokio-rs/axum/pull/1170
[#1239]: https://github.com/tokio-rs/axum/pull/1239

# 0.3.5 (27. June, 2022)
Expand Down
192 changes: 192 additions & 0 deletions axum-extra/src/handler/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
//! Additional handler utilities.
use axum::{
extract::{FromRequest, RequestParts},
handler::Handler,
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, FutureExt, Map};
use std::{future::Future, marker::PhantomData};

mod or;

pub use self::or::Or;

/// Trait for async functions that can be used to handle requests.
///
/// This trait is similar to [`Handler`] but rather than taking the request it takes the extracted
/// inputs.
///
/// The drawbacks of this trait is that you cannot apply middleware to individual handlers like you
/// can with [`Handler::layer`].
pub trait HandlerCallWithExtractors<T, B>: Sized {
/// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static;

/// Call the handler with the extracted inputs.
fn call(self, extractors: T) -> <Self as HandlerCallWithExtractors<T, B>>::Future;

/// Conver this `HandlerCallWithExtractors` into [`Handler`].
fn into_handler(self) -> IntoHandler<Self, T, B> {
IntoHandler {
handler: self,
_marker: PhantomData,
}
}

/// Chain two handlers together, running the second one if the first one rejects.
///
/// Note that this only moves to the next handler if an extractor fails. The response from
/// handlers are not considered.
///
/// # Example
///
/// ```
/// use axum_extra::handler::HandlerCallWithExtractors;
/// use axum::{
/// Router,
/// async_trait,
/// routing::get,
/// extract::FromRequest,
/// };
///
/// // handlers for varying levels of access
/// async fn admin(admin: AdminPermissions) {
/// // request came from an admin
/// }
///
/// async fn user(user: User) {
/// // we have a `User`
/// }
///
/// async fn guest() {
/// // `AdminPermissions` and `User` failed, so we're just a guest
/// }
///
/// // extractors for checking permissions
/// struct AdminPermissions {}
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for AdminPermissions {
/// // check for admin permissions...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
///
/// struct User {}
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for User {
/// // check for a logged in user...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
///
/// let app = Router::new().route(
/// "/users/:id",
/// get(
/// // first try `admin`, if that rejects run `user`, finally falling back
/// // to `guest`
/// admin.or(user).or(guest)
/// )
/// );
/// # let _: Router = app;
/// ```
fn or<R, Rt>(self, rhs: R) -> Or<Self, R, T, Rt, B>
where
R: HandlerCallWithExtractors<Rt, B>,
{
Or {
lhs: self,
rhs,
_marker: PhantomData,
}
}
}

macro_rules! impl_handler_call_with {
( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<F, Fut, B, $($ty,)*> HandlerCallWithExtractors<($($ty,)*), B> for F
where
F: FnOnce($($ty,)*) -> Fut,
Fut: Future + Send + 'static,
Fut::Output: IntoResponse,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = Map<Fut, fn(Fut::Output) -> Response>;

fn call(
self,
($($ty,)*): ($($ty,)*),
) -> <Self as HandlerCallWithExtractors<($($ty,)*), B>>::Future {
self($($ty,)*).map(IntoResponse::into_response)
}
}
};
}

impl_handler_call_with!();
impl_handler_call_with!(T1);
impl_handler_call_with!(T1, T2);
impl_handler_call_with!(T1, T2, T3);
impl_handler_call_with!(T1, T2, T3, T4);
impl_handler_call_with!(T1, T2, T3, T4, T5);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);

/// A [`Handler`] created from a [`HandlerCallWithExtractors`].
///
/// Created with [`HandlerCallWithExtractors::into_handler`].
#[allow(missing_debug_implementations)]
pub struct IntoHandler<H, T, B> {
handler: H,
_marker: PhantomData<fn() -> (T, B)>,
}

impl<H, T, B> Handler<T, B> for IntoHandler<H, T, B>
where
H: HandlerCallWithExtractors<T, B> + Clone + Send + 'static,
T: FromRequest<B> + Send + 'static,
T::Rejection: Send,
B: Send + 'static,
{
type Future = BoxFuture<'static, Response>;

fn call(self, req: http::Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::new(req);
match req.extract::<T>().await {
Ok(t) => self.handler.call(t).await,
Err(rejection) => rejection.into_response(),
}
})
}
}

impl<H, T, B> Copy for IntoHandler<H, T, B> where H: Copy {}

impl<H, T, B> Clone for IntoHandler<H, T, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
_marker: self._marker,
}
}
}
151 changes: 151 additions & 0 deletions axum-extra/src/handler/or.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use super::HandlerCallWithExtractors;
use crate::Either;
use axum::{
extract::{FromRequest, RequestParts},
handler::Handler,
http::Request,
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
use http::StatusCode;
use std::{future::Future, marker::PhantomData};

/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another
/// [`Handler`].
///
/// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or).
#[allow(missing_debug_implementations)]
pub struct Or<L, R, Lt, Rt, B> {
pub(super) lhs: L,
pub(super) rhs: R,
pub(super) _marker: PhantomData<fn() -> (Lt, Rt, B)>,
}

impl<B, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, B> for Or<L, R, Lt, Rt, B>
where
L: HandlerCallWithExtractors<Lt, B> + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Send + 'static,
Rt: Send + 'static,
Lt: Send + 'static,
B: Send + 'static,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = EitherFuture<
Map<L::Future, fn(<L::Future as Future>::Output) -> Response>,
Map<R::Future, fn(<R::Future as Future>::Output) -> Response>,
>;

fn call(
self,
extractors: Either<Lt, Rt>,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, B>>::Future {
match extractors {
Either::Left(lt) => self
.lhs
.call(lt)
.map(IntoResponse::into_response as _)
.left_future(),
Either::Right(rt) => self
.rhs
.call(rt)
.map(IntoResponse::into_response as _)
.right_future(),
}
}
}

impl<B, L, R, Lt, Rt> Handler<(Lt, Rt), B> for Or<L, R, Lt, Rt, B>
where
L: HandlerCallWithExtractors<Lt, B> + Clone + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Clone + Send + 'static,
Lt: FromRequest<B> + Send + 'static,
Rt: FromRequest<B> + Send + 'static,
Lt::Rejection: Send,
Rt::Rejection: Send,
B: Send + 'static,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>;

fn call(self, req: Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::new(req);

if let Ok(lt) = req.extract::<Lt>().await {
return self.lhs.call(lt).await;
}

if let Ok(rt) = req.extract::<Rt>().await {
return self.rhs.call(rt).await;
}

StatusCode::NOT_FOUND.into_response()
})
}
}

impl<L, R, Lt, Rt, B> Copy for Or<L, R, Lt, Rt, B>
where
L: Copy,
R: Copy,
{
}

impl<L, R, Lt, Rt, B> Clone for Or<L, R, Lt, Rt, B>
where
L: Clone,
R: Clone,
{
fn clone(&self) -> Self {
Self {
lhs: self.lhs.clone(),
rhs: self.rhs.clone(),
_marker: self._marker,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{
extract::{Path, Query},
routing::get,
Router,
};
use serde::Deserialize;

#[tokio::test]
async fn works() {
#[derive(Deserialize)]
struct Params {
a: String,
}

async fn one(Path(id): Path<u32>) -> String {
id.to_string()
}

async fn two(Query(params): Query<Params>) -> String {
params.a
}

async fn three() -> &'static str {
"fallback"
}

let app = Router::new().route("/:id", get(one.or(two).or(three)));

let client = TestClient::new(app);

let res = client.get("/123").send().await;
assert_eq!(res.text().await, "123");

let res = client.get("/foo?a=bar").send().await;
assert_eq!(res.text().await, "bar");

let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "fallback");
}
}
Loading

0 comments on commit e79820d

Please sign in to comment.