Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support chaining handlers with one.or(two).or(three) #1170

Merged
merged 6 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ and this project adheres to [Semantic Versioning].
- **changed:** For methods that accept some `S: Service`, the bounds have been
relaxed so the response type must implement `IntoResponse` rather than being a
literal `Response`
- **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170])
- **change:** axum-extra's MSRV is now 1.60 ([#1239])

[#1119]: https://github.com/tokio-rs/axum/pull/1119
[#1170]: https://github.com/tokio-rs/axum/pull/1170
[#1239]: https://github.com/tokio-rs/axum/pull/1239

# 0.3.5 (27. June, 2022)
Expand Down
192 changes: 192 additions & 0 deletions axum-extra/src/handler/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
//! Additional handler utilities.

use axum::{
extract::{FromRequest, RequestParts},
handler::Handler,
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, FutureExt, Map};
use std::{future::Future, marker::PhantomData};

mod or;

pub use self::or::Or;

/// Trait for async functions that can be used to handle requests.
///
/// This trait is similar to [`Handler`] but rather than taking the request it takes the extracted
/// inputs.
///
/// The drawbacks of this trait is that you cannot apply middleware to individual handlers like you
/// can with [`Handler::layer`].
pub trait HandlerCallWithExtractors<T, B>: Sized {
/// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static;

/// Call the handler with the extracted inputs.
fn call(self, extractors: T) -> <Self as HandlerCallWithExtractors<T, B>>::Future;

/// Conver this `HandlerCallWithExtractors` into [`Handler`].
fn into_handler(self) -> IntoHandler<Self, T, B> {
IntoHandler {
handler: self,
_marker: PhantomData,
}
}

/// Chain two handlers together, running the second one if the first one rejects.
///
/// Note that this only moves to the next handler if an extractor fails. The response from
/// handlers are not considered.
///
/// # Example
///
/// ```
/// use axum_extra::handler::HandlerCallWithExtractors;
/// use axum::{
/// Router,
/// async_trait,
/// routing::get,
/// extract::FromRequest,
/// };
///
/// // handlers for varying levels of access
/// async fn admin(admin: AdminPermissions) {
/// // request came from an admin
/// }
///
/// async fn user(user: User) {
/// // we have a `User`
/// }
///
/// async fn guest() {
/// // `AdminPermissions` and `User` failed, so we're just a guest
/// }
///
/// // extractors for checking permissions
/// struct AdminPermissions {}
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for AdminPermissions {
/// // check for admin permissions...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
///
/// struct User {}
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for User {
/// // check for a logged in user...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
///
/// let app = Router::new().route(
/// "/users/:id",
/// get(
/// // first try `admin`, if that rejects run `user`, finally falling back
/// // to `guest`
/// admin.or(user).or(guest)
/// )
/// );
/// # let _: Router = app;
/// ```
fn or<R, Rt>(self, rhs: R) -> Or<Self, R, T, Rt, B>
where
R: HandlerCallWithExtractors<Rt, B>,
{
Or {
lhs: self,
rhs,
_marker: PhantomData,
}
}
}

macro_rules! impl_handler_call_with {
( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<F, Fut, B, $($ty,)*> HandlerCallWithExtractors<($($ty,)*), B> for F
where
F: FnOnce($($ty,)*) -> Fut,
Fut: Future + Send + 'static,
Fut::Output: IntoResponse,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = Map<Fut, fn(Fut::Output) -> Response>;

fn call(
self,
($($ty,)*): ($($ty,)*),
) -> <Self as HandlerCallWithExtractors<($($ty,)*), B>>::Future {
self($($ty,)*).map(IntoResponse::into_response)
}
}
};
}

impl_handler_call_with!();
impl_handler_call_with!(T1);
impl_handler_call_with!(T1, T2);
impl_handler_call_with!(T1, T2, T3);
impl_handler_call_with!(T1, T2, T3, T4);
impl_handler_call_with!(T1, T2, T3, T4, T5);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);

/// A [`Handler`] created from a [`HandlerCallWithExtractors`].
///
/// Created with [`HandlerCallWithExtractors::into_handler`].
#[allow(missing_debug_implementations)]
pub struct IntoHandler<H, T, B> {
handler: H,
_marker: PhantomData<fn() -> (T, B)>,
}

impl<H, T, B> Handler<T, B> for IntoHandler<H, T, B>
where
H: HandlerCallWithExtractors<T, B> + Clone + Send + 'static,
T: FromRequest<B> + Send + 'static,
T::Rejection: Send,
B: Send + 'static,
{
type Future = BoxFuture<'static, Response>;

fn call(self, req: http::Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::new(req);
match req.extract::<T>().await {
Ok(t) => self.handler.call(t).await,
Err(rejection) => rejection.into_response(),
}
})
}
}

impl<H, T, B> Copy for IntoHandler<H, T, B> where H: Copy {}

impl<H, T, B> Clone for IntoHandler<H, T, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
_marker: self._marker,
}
}
}
151 changes: 151 additions & 0 deletions axum-extra/src/handler/or.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use super::HandlerCallWithExtractors;
use crate::Either;
use axum::{
extract::{FromRequest, RequestParts},
handler::Handler,
http::Request,
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
use http::StatusCode;
use std::{future::Future, marker::PhantomData};

/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another
/// [`Handler`].
///
/// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or).
#[allow(missing_debug_implementations)]
pub struct Or<L, R, Lt, Rt, B> {
pub(super) lhs: L,
pub(super) rhs: R,
pub(super) _marker: PhantomData<fn() -> (Lt, Rt, B)>,
}

impl<B, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, B> for Or<L, R, Lt, Rt, B>
where
L: HandlerCallWithExtractors<Lt, B> + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Send + 'static,
Rt: Send + 'static,
Lt: Send + 'static,
B: Send + 'static,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = EitherFuture<
Map<L::Future, fn(<L::Future as Future>::Output) -> Response>,
Map<R::Future, fn(<R::Future as Future>::Output) -> Response>,
>;

fn call(
self,
extractors: Either<Lt, Rt>,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, B>>::Future {
match extractors {
Either::Left(lt) => self
.lhs
.call(lt)
.map(IntoResponse::into_response as _)
.left_future(),
Either::Right(rt) => self
.rhs
.call(rt)
.map(IntoResponse::into_response as _)
.right_future(),
}
}
}

impl<B, L, R, Lt, Rt> Handler<(Lt, Rt), B> for Or<L, R, Lt, Rt, B>
where
L: HandlerCallWithExtractors<Lt, B> + Clone + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Clone + Send + 'static,
Lt: FromRequest<B> + Send + 'static,
Rt: FromRequest<B> + Send + 'static,
Lt::Rejection: Send,
Rt::Rejection: Send,
B: Send + 'static,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>;

fn call(self, req: Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::new(req);

if let Ok(lt) = req.extract::<Lt>().await {
return self.lhs.call(lt).await;
}

if let Ok(rt) = req.extract::<Rt>().await {
return self.rhs.call(rt).await;
}

StatusCode::NOT_FOUND.into_response()
})
}
}

impl<L, R, Lt, Rt, B> Copy for Or<L, R, Lt, Rt, B>
where
L: Copy,
R: Copy,
{
}

impl<L, R, Lt, Rt, B> Clone for Or<L, R, Lt, Rt, B>
where
L: Clone,
R: Clone,
{
fn clone(&self) -> Self {
Self {
lhs: self.lhs.clone(),
rhs: self.rhs.clone(),
_marker: self._marker,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{
extract::{Path, Query},
routing::get,
Router,
};
use serde::Deserialize;

#[tokio::test]
async fn works() {
#[derive(Deserialize)]
struct Params {
a: String,
}

async fn one(Path(id): Path<u32>) -> String {
id.to_string()
}

async fn two(Query(params): Query<Params>) -> String {
params.a
}

async fn three() -> &'static str {
"fallback"
}

let app = Router::new().route("/:id", get(one.or(two).or(three)));

let client = TestClient::new(app);

let res = client.get("/123").send().await;
assert_eq!(res.text().await, "123");

let res = client.get("/foo?a=bar").send().await;
assert_eq!(res.text().await, "bar");

let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "fallback");
}
}
Loading