Skip to content

Commit

Permalink
Support customizing rejections for #[derive(TypedPath)] (#1012)
Browse files Browse the repository at this point in the history
* Support customizing rejections for `#[derive(TypedPath)]`

* changelog

* clean up
  • Loading branch information
davidpdrsn authored May 17, 2022
1 parent b215a87 commit 5948cde
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 23 deletions.
3 changes: 3 additions & 0 deletions axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ and this project adheres to [Semantic Versioning].

- **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001])
- **fixed:** Support wildcards in typed paths ([#1003])
- **added:** Support using a custom rejection type for `#[derive(TypedPath)]`
instead of `PathRejection` ([#1012])

[#1001]: https://github.com/tokio-rs/axum/pull/1001
[#1003]: https://github.com/tokio-rs/axum/pull/1003
[#1012]: https://github.com/tokio-rs/axum/pull/1012

# 0.3.0 (27. April, 2022)

Expand Down
65 changes: 65 additions & 0 deletions axum-extra/src/routing/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,77 @@ use http::Uri;
/// );
/// ```
///
/// ## Customizing the rejection
///
/// By default the rejection used in the [`FromRequest`] implemetation will be [`PathRejection`].
///
/// That can be customized using `#[typed_path("...", rejection(YourType))]`:
///
/// ```
/// use serde::Deserialize;
/// use axum_extra::routing::TypedPath;
/// use axum::{
/// response::{IntoResponse, Response},
/// extract::rejection::PathRejection,
/// };
///
/// #[derive(TypedPath, Deserialize)]
/// #[typed_path("/users/:id", rejection(UsersMemberRejection))]
/// struct UsersMember {
/// id: String,
/// }
///
/// struct UsersMemberRejection;
///
/// // Your rejection type must implement `From<PathRejection>`.
/// //
/// // Here you can grab whatever details from the inner rejection
/// // that you need.
/// impl From<PathRejection> for UsersMemberRejection {
/// fn from(rejection: PathRejection) -> Self {
/// # UsersMemberRejection
/// // ...
/// }
/// }
///
/// // Your rejection must implement `IntoResponse`, like all rejections.
/// impl IntoResponse for UsersMemberRejection {
/// fn into_response(self) -> Response {
/// # ().into_response()
/// // ...
/// }
/// }
/// ```
///
/// The `From<PathRejection>` requirement only applies if your typed path is a struct with named
/// fields or a tuple struct. For unit structs your rejection type must implement `Default`:
///
/// ```
/// use axum_extra::routing::TypedPath;
/// use axum::response::{IntoResponse, Response};
///
/// #[derive(TypedPath)]
/// #[typed_path("/users", rejection(UsersCollectionRejection))]
/// struct UsersCollection;
///
/// #[derive(Default)]
/// struct UsersCollectionRejection;
///
/// impl IntoResponse for UsersCollectionRejection {
/// fn into_response(self) -> Response {
/// # ().into_response()
/// // ...
/// }
/// }
/// ```
///
/// [`FromRequest`]: axum::extract::FromRequest
/// [`RouterExt::typed_get`]: super::RouterExt::typed_get
/// [`RouterExt::typed_post`]: super::RouterExt::typed_post
/// [`Path`]: axum::extract::Path
/// [`Display`]: std::fmt::Display
/// [`Deserialize`]: serde::Deserialize
/// [`PathRejection`]: axum::extract::rejection::PathRejection
pub trait TypedPath: std::fmt::Display {
/// The path with optional captures such as `/users/:id`.
const PATH: &'static str;
Expand Down
3 changes: 3 additions & 0 deletions axum-macros/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001])
- **fixed:** Support wildcards in typed paths ([#1003])
- **added:** Support `#[derive(FromRequest)]` on enums using `#[from_request(via(OtherExtractor))]` ([#1009])
- **added:** Support using a custom rejection type for `#[derive(TypedPath)]`
instead of `PathRejection` ([#1012])

[#1001]: https://github.com/tokio-rs/axum/pull/1001
[#1003]: https://github.com/tokio-rs/axum/pull/1003
[#1009]: https://github.com/tokio-rs/axum/pull/1009
[#1012]: https://github.com/tokio-rs/axum/pull/1012

# 0.2.0 (31. March, 2022)

Expand Down
132 changes: 109 additions & 23 deletions axum-macros/src/typed_path.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{ItemStruct, LitStr};
use syn::{parse::Parse, ItemStruct, LitStr, Token};

pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
let ItemStruct {
Expand All @@ -18,52 +18,79 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
));
}

let Attrs { path } = parse_attrs(attrs)?;
let Attrs { path, rejection } = parse_attrs(attrs)?;

match fields {
syn::Fields::Named(_) => {
let segments = parse_path(&path)?;
Ok(expand_named_fields(ident, path, &segments))
Ok(expand_named_fields(ident, path, &segments, rejection))
}
syn::Fields::Unnamed(fields) => {
let segments = parse_path(&path)?;
expand_unnamed_fields(fields, ident, path, &segments)
expand_unnamed_fields(fields, ident, path, &segments, rejection)
}
syn::Fields::Unit => expand_unit_fields(ident, path),
syn::Fields::Unit => expand_unit_fields(ident, path, rejection),
}
}

mod kw {
syn::custom_keyword!(rejection);
}

struct Attrs {
path: LitStr,
rejection: Option<syn::Path>,
}

impl Parse for Attrs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let path = input.parse()?;

let rejection = if input.is_empty() {
None
} else {
let _: Token![,] = input.parse()?;
let _: kw::rejection = input.parse()?;

let content;
syn::parenthesized!(content in input);
Some(content.parse()?)
};

Ok(Self { path, rejection })
}
}

fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result<Attrs> {
let mut path = None;
let mut out = None;

for attr in attrs {
if attr.path.is_ident("typed_path") {
if path.is_some() {
if out.is_some() {
return Err(syn::Error::new_spanned(
attr,
"`typed_path` specified more than once",
));
} else {
path = Some(attr.parse_args()?);
out = Some(attr.parse_args()?);
}
}
}

Ok(Attrs {
path: path.ok_or_else(|| {
syn::Error::new(
Span::call_site(),
"missing `#[typed_path(\"...\")]` attribute",
)
})?,
out.ok_or_else(|| {
syn::Error::new(
Span::call_site(),
"missing `#[typed_path(\"...\")]` attribute",
)
})
}

fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) -> TokenStream {
fn expand_named_fields(
ident: &syn::Ident,
path: LitStr,
segments: &[Segment],
rejection: Option<syn::Path>,
) -> TokenStream {
let format_str = format_str_from_path(segments);
let captures = captures_from_path(segments);

Expand All @@ -88,17 +115,23 @@ fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) -
}
};

let rejection_assoc_type = rejection_assoc_type(&rejection);
let map_err_rejection = map_err_rejection(&rejection);

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
where
B: Send,
{
type Rejection = <::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection;
type Rejection = #rejection_assoc_type;

async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req).await.map(|path| path.0)
::axum::extract::Path::from_request(req)
.await
.map(|path| path.0)
#map_err_rejection
}
}
};
Expand All @@ -115,6 +148,7 @@ fn expand_unnamed_fields(
ident: &syn::Ident,
path: LitStr,
segments: &[Segment],
rejection: Option<syn::Path>,
) -> syn::Result<TokenStream> {
let num_captures = segments
.iter()
Expand Down Expand Up @@ -177,17 +211,23 @@ fn expand_unnamed_fields(
}
};

let rejection_assoc_type = rejection_assoc_type(&rejection);
let map_err_rejection = map_err_rejection(&rejection);

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
where
B: Send,
{
type Rejection = <::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection;
type Rejection = #rejection_assoc_type;

async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req).await.map(|path| path.0)
::axum::extract::Path::from_request(req)
.await
.map(|path| path.0)
#map_err_rejection
}
}
};
Expand All @@ -207,7 +247,11 @@ fn simple_pluralize(count: usize, word: &str) -> String {
}
}

fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result<TokenStream> {
fn expand_unit_fields(
ident: &syn::Ident,
path: LitStr,
rejection: Option<syn::Path>,
) -> syn::Result<TokenStream> {
for segment in parse_path(&path)? {
match segment {
Segment::Capture(_, span) => {
Expand Down Expand Up @@ -236,20 +280,35 @@ fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result<TokenStre
}
};

let rejection_assoc_type = if let Some(rejection) = &rejection {
quote! { #rejection }
} else {
quote! { ::axum::http::StatusCode }
};
let create_rejection = if let Some(rejection) = &rejection {
quote! {
Err(<#rejection as ::std::default::Default>::default())
}
} else {
quote! {
Err(::axum::http::StatusCode::NOT_FOUND)
}
};

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
where
B: Send,
{
type Rejection = ::axum::http::StatusCode;
type Rejection = #rejection_assoc_type;

async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
if req.uri().path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
Ok(Self)
} else {
Err(::axum::http::StatusCode::NOT_FOUND)
#create_rejection
}
}
}
Expand Down Expand Up @@ -314,6 +373,33 @@ enum Segment {
Static(String),
}

fn path_rejection() -> TokenStream {
quote! {
<::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection
}
}

fn rejection_assoc_type(rejection: &Option<syn::Path>) -> TokenStream {
match rejection {
Some(rejection) => quote! { #rejection },
None => path_rejection(),
}
}

fn map_err_rejection(rejection: &Option<syn::Path>) -> TokenStream {
rejection
.as_ref()
.map(|rejection| {
let path_rejection = path_rejection();
quote! {
.map_err(|rejection| {
<#rejection as ::std::convert::From<#path_rejection>>::from(rejection)
})
}
})
.unwrap_or_default()
}

#[test]
fn ui() {
#[rustversion::stable]
Expand Down
Loading

0 comments on commit 5948cde

Please sign in to comment.