Skip to content

Commit

Permalink
Support running extractors from middleware::from_fn (#1088)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored Jun 15, 2022
1 parent 9ec32a7 commit 53cce05
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 61 deletions.
4 changes: 3 additions & 1 deletion axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- **added:** Support resolving host name via `Forwarded` header in `Host`
extractor ([#1078])
- **added:** Support running extractors from `middleware::from_fn` functions ([#1088])
- **breaking:** Allow `Error: Into<Infallible>` for `Route::{layer, route_layer}` ([#924])
- **breaking:** Remove `extractor_middleware` which was previously deprecated.
Use `axum::middleware::from_extractor` instead ([#1077])

[#924]: https://github.com/tokio-rs/axum/pull/924
[#1078]: https://github.com/tokio-rs/axum/pull/1078
[#1077]: https://github.com/tokio-rs/axum/pull/1077
[#1078]: https://github.com/tokio-rs/axum/pull/1078
[#1088]: https://github.com/tokio-rs/axum/pull/1088

# 0.5.7 (08. June, 2022)

Expand Down
211 changes: 151 additions & 60 deletions axum/src/middleware/from_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ use crate::{
response::{IntoResponse, Response},
BoxError,
};
use axum_core::extract::{FromRequest, RequestParts};
use futures_util::future::BoxFuture;
use http::Request;
use pin_project_lite::pin_project;
use std::{
any::type_name,
convert::Infallible,
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
Expand All @@ -23,8 +25,8 @@ use tower_service::Service;
/// `from_fn` requires the function given to
///
/// 1. Be an `async fn`.
/// 2. Take [`Request<B>`](http::Request) as the first argument.
/// 3. Take [`Next<B>`](Next) as the second argument.
/// 2. Take one or more [extractors] as the first arguments.
/// 3. Take [`Next<B>`](Next) as the final argument.
/// 4. Return something that implements [`IntoResponse`].
///
/// # Example
Expand Down Expand Up @@ -62,6 +64,37 @@ use tower_service::Service;
/// # let app: Router = app;
/// ```
///
/// # Running extractors
///
/// ```rust
/// use axum::{
/// Router,
/// extract::{TypedHeader, Query},
/// headers::authorization::{Authorization, Bearer},
/// http::Request,
/// middleware::{self, Next},
/// response::Response,
/// routing::get,
/// };
/// use std::collections::HashMap;
///
/// async fn my_middleware<B>(
/// TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
/// Query(query_params): Query<HashMap<String, String>>,
/// req: Request<B>,
/// next: Next<B>,
/// ) -> Response {
/// // do something with `auth` and `query_params`...
///
/// next.run(req).await
/// }
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(middleware::from_fn(my_middleware));
/// # let app: Router = app;
/// ```
///
/// # Passing state
///
/// State can be passed to the function like so:
Expand Down Expand Up @@ -114,11 +147,10 @@ use tower_service::Service;
/// struct State { /* ... */ }
///
/// async fn my_middleware<B>(
/// Extension(state): Extension<State>,
/// req: Request<B>,
/// next: Next<B>,
/// ) -> Response {
/// let state: &State = req.extensions().get().unwrap();
///
/// // ...
/// # ().into_response()
/// }
Expand All @@ -134,35 +166,55 @@ use tower_service::Service;
/// );
/// # let app: Router = app;
/// ```
pub fn from_fn<F>(f: F) -> FromFnLayer<F> {
FromFnLayer { f }
///
/// [extractors]: crate::extract::FromRequest
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
FromFnLayer {
f,
_extractor: PhantomData,
}
}

/// A [`tower::Layer`] from an async function.
///
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
///
/// Created with [`from_fn`]. See that function for more details.
#[derive(Clone, Copy)]
pub struct FromFnLayer<F> {
pub struct FromFnLayer<F, T> {
f: F,
_extractor: PhantomData<fn() -> T>,
}

impl<S, F> Layer<S> for FromFnLayer<F>
impl<F, T> Clone for FromFnLayer<F, T>
where
F: Clone,
{
type Service = FromFn<F, S>;
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
_extractor: self._extractor,
}
}
}

impl<F, T> Copy for FromFnLayer<F, T> where F: Copy {}

impl<S, F, T> Layer<S> for FromFnLayer<F, T>
where
F: Clone,
{
type Service = FromFn<F, S, T>;

fn layer(&self, inner: S) -> Self::Service {
FromFn {
f: self.f.clone(),
inner,
_extractor: PhantomData,
}
}
}

impl<F> fmt::Debug for FromFnLayer<F> {
impl<F, T> fmt::Debug for FromFnLayer<F, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromFnLayer")
// Write out the type name, without quoting it as `&type_name::<F>()` would
Expand All @@ -174,50 +226,94 @@ impl<F> fmt::Debug for FromFnLayer<F> {
/// A middleware created from an async function.
///
/// Created with [`from_fn`]. See that function for more details.
#[derive(Clone, Copy)]
pub struct FromFn<F, S> {
pub struct FromFn<F, S, T> {
f: F,
inner: S,
_extractor: PhantomData<fn() -> T>,
}

impl<F, Fut, Out, S, ReqBody, ResBody> Service<Request<ReqBody>> for FromFn<F, S>
impl<F, S, T> Clone for FromFn<F, S, T>
where
F: FnMut(Request<ReqBody>, Next<ReqBody>) -> Fut,
Fut: Future<Output = Out>,
Out: IntoResponse,
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
F: Clone,
S: Clone,
{
type Response = Response;
type Error = Infallible;
type Future = ResponseFuture<Fut>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
_extractor: self._extractor,
}
}
}

impl<F, S, T> Copy for FromFn<F, S, T>
where
F: Copy,
S: Copy,
{
}

macro_rules! impl_service {
( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<F, Fut, Out, S, ReqBody, ResBody, $($ty,)*> Service<Request<ReqBody>> for FromFn<F, S, ($($ty,)*)>
where
F: FnMut($($ty),*, Next<ReqBody>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequest<ReqBody> + Send, )*
Fut: Future<Output = Out> + Send + 'static,
Out: IntoResponse + 'static,
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
ReqBody: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response;
type Error = Infallible;
type Future = ResponseFuture;

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let not_ready_inner = self.inner.clone();
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

let inner = ServiceBuilder::new()
.boxed_clone()
.map_response_body(body::boxed)
.service(ready_inner);
let next = Next { inner };
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let not_ready_inner = self.inner.clone();
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

ResponseFuture {
inner: (self.f)(req, next),
let mut f = self.f.clone();

let future = Box::pin(async move {
let mut parts = RequestParts::new(req);
$(
let $ty = match $ty::from_request(&mut parts).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*

let inner = ServiceBuilder::new()
.boxed_clone()
.map_response_body(body::boxed)
.service(ready_inner);
let next = Next { inner };

f($($ty),*, next).await.into_response()
});

ResponseFuture {
inner: future
}
}
}
}
};
}

impl<F, S> fmt::Debug for FromFn<F, S>
all_the_tuples!(impl_service);

impl<F, S, T> fmt::Debug for FromFn<F, S, T>
where
S: fmt::Debug,
{
Expand Down Expand Up @@ -252,27 +348,22 @@ impl<ReqBody> fmt::Debug for Next<ReqBody> {
}
}

pin_project! {
/// Response future for [`FromFn`].
pub struct ResponseFuture<F> {
#[pin]
inner: F,
}
/// Response future for [`FromFn`].
pub struct ResponseFuture {
inner: BoxFuture<'static, Response>,
}

impl<F, Out> Future for ResponseFuture<F>
where
F: Future<Output = Out>,
Out: IntoResponse,
{
impl Future for ResponseFuture {
type Output = Result<Response, Infallible>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project()
.inner
.poll(cx)
.map(IntoResponse::into_response)
.map(Ok)
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner.as_mut().poll(cx).map(Ok)
}
}

impl fmt::Debug for ResponseFuture {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseFuture").finish()
}
}

Expand Down

0 comments on commit 53cce05

Please sign in to comment.