Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support customizing rejections for #[derive(TypedPath)] #1012

Merged
merged 8 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ and this project adheres to [Semantic Versioning].

- **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001])
- **fixed:** Support wildcards in typed paths ([#1003])
- **added:** Support using a custom rejection type for `#[derive(TypedPath)]`
instead of `PathRejection` ([#1012])

[#1001]: https://github.com/tokio-rs/axum/pull/1001
[#1003]: https://github.com/tokio-rs/axum/pull/1003
[#1012]: https://github.com/tokio-rs/axum/pull/1012

# 0.3.0 (27. April, 2022)

Expand Down
65 changes: 65 additions & 0 deletions axum-extra/src/routing/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,77 @@ use http::Uri;
/// );
/// ```
///
/// ## Customizing the rejection
///
/// By default the rejection used in the [`FromRequest`] implemetation will be [`PathRejection`].
///
/// That can be customized using `#[typed_path("...", rejection(YourType))]`:
///
/// ```
/// use serde::Deserialize;
/// use axum_extra::routing::TypedPath;
/// use axum::{
/// response::{IntoResponse, Response},
/// extract::rejection::PathRejection,
/// };
///
/// #[derive(TypedPath, Deserialize)]
/// #[typed_path("/users/:id", rejection(UsersMemberRejection))]
/// struct UsersMember {
/// id: String,
/// }
///
/// struct UsersMemberRejection;
///
/// // Your rejection type must implement `From<PathRejection>`.
/// //
/// // Here you can grab whatever details from the inner rejection
/// // that you need.
/// impl From<PathRejection> for UsersMemberRejection {
/// fn from(rejection: PathRejection) -> Self {
/// # UsersMemberRejection
/// // ...
/// }
/// }
///
/// // Your rejection must implement `IntoResponse`, like all rejections.
/// impl IntoResponse for UsersMemberRejection {
/// fn into_response(self) -> Response {
/// # ().into_response()
/// // ...
/// }
/// }
/// ```
///
/// The `From<PathRejection>` requirement only applies if your typed path is a struct with named
/// fields or a tuple struct. For unit structs your rejection type must implement `Default`:
///
/// ```
/// use axum_extra::routing::TypedPath;
/// use axum::response::{IntoResponse, Response};
///
/// #[derive(TypedPath)]
/// #[typed_path("/users", rejection(UsersCollectionRejection))]
/// struct UsersCollection;
///
/// #[derive(Default)]
/// struct UsersCollectionRejection;
///
/// impl IntoResponse for UsersCollectionRejection {
/// fn into_response(self) -> Response {
/// # ().into_response()
/// // ...
/// }
/// }
/// ```
///
/// [`FromRequest`]: axum::extract::FromRequest
/// [`RouterExt::typed_get`]: super::RouterExt::typed_get
/// [`RouterExt::typed_post`]: super::RouterExt::typed_post
/// [`Path`]: axum::extract::Path
/// [`Display`]: std::fmt::Display
/// [`Deserialize`]: serde::Deserialize
/// [`PathRejection`]: axum::extract::rejection::PathRejection
pub trait TypedPath: std::fmt::Display {
/// The path with optional captures such as `/users/:id`.
const PATH: &'static str;
Expand Down
3 changes: 3 additions & 0 deletions axum-macros/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001])
- **fixed:** Support wildcards in typed paths ([#1003])
- **added:** Support `#[derive(FromRequest)]` on enums using `#[from_request(via(OtherExtractor))]` ([#1009])
- **added:** Support using a custom rejection type for `#[derive(TypedPath)]`
instead of `PathRejection` ([#1012])

[#1001]: https://github.com/tokio-rs/axum/pull/1001
[#1003]: https://github.com/tokio-rs/axum/pull/1003
[#1009]: https://github.com/tokio-rs/axum/pull/1009
[#1012]: https://github.com/tokio-rs/axum/pull/1012

# 0.2.0 (31. March, 2022)

Expand Down
135 changes: 112 additions & 23 deletions axum-macros/src/typed_path.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{ItemStruct, LitStr};
use syn::{parse::Parse, ItemStruct, LitStr, Token};

pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
let ItemStruct {
Expand All @@ -18,52 +18,82 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
));
}

let Attrs { path } = parse_attrs(attrs)?;
let Attrs { path, rejection } = parse_attrs(attrs)?;

match fields {
syn::Fields::Named(_) => {
let segments = parse_path(&path)?;
Ok(expand_named_fields(ident, path, &segments))
Ok(expand_named_fields(ident, path, &segments, rejection))
}
syn::Fields::Unnamed(fields) => {
let segments = parse_path(&path)?;
expand_unnamed_fields(fields, ident, path, &segments)
expand_unnamed_fields(fields, ident, path, &segments, rejection)
}
syn::Fields::Unit => expand_unit_fields(ident, path),
syn::Fields::Unit => expand_unit_fields(ident, path, rejection),
}
}

mod kw {
syn::custom_keyword!(rejection);
}

struct Attrs {
path: LitStr,
rejection: Option<syn::Path>,
}

impl Parse for Attrs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let path = input.parse()?;

let _ = input.parse::<Token![,]>();

let lh = input.lookahead1();
let rejection = if lh.peek(kw::rejection) {
input.parse::<kw::rejection>()?;
let content;
syn::parenthesized!(content in input);
Some(content.parse()?)
} else if lh.is_empty() {
None
} else {
return Err(lh.error());
};
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved

Ok(Self { path, rejection })
}
}

fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result<Attrs> {
let mut path = None;
let mut out = None;

for attr in attrs {
if attr.path.is_ident("typed_path") {
if path.is_some() {
if out.is_some() {
return Err(syn::Error::new_spanned(
attr,
"`typed_path` specified more than once",
));
} else {
path = Some(attr.parse_args()?);
out = Some(attr.parse_args()?);
}
}
}

Ok(Attrs {
path: path.ok_or_else(|| {
syn::Error::new(
Span::call_site(),
"missing `#[typed_path(\"...\")]` attribute",
)
})?,
out.ok_or_else(|| {
syn::Error::new(
Span::call_site(),
"missing `#[typed_path(\"...\")]` attribute",
)
})
}

fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) -> TokenStream {
fn expand_named_fields(
ident: &syn::Ident,
path: LitStr,
segments: &[Segment],
rejection: Option<syn::Path>,
) -> TokenStream {
let format_str = format_str_from_path(segments);
let captures = captures_from_path(segments);

Expand All @@ -88,17 +118,23 @@ fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) -
}
};

let rejection_assoc_type = rejection_assoc_type(&rejection);
let map_err_rejection = map_err_rejection(&rejection);

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
where
B: Send,
{
type Rejection = <::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection;
type Rejection = #rejection_assoc_type;

async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req).await.map(|path| path.0)
::axum::extract::Path::from_request(req)
.await
.map(|path| path.0)
#map_err_rejection
}
}
};
Expand All @@ -115,6 +151,7 @@ fn expand_unnamed_fields(
ident: &syn::Ident,
path: LitStr,
segments: &[Segment],
rejection: Option<syn::Path>,
) -> syn::Result<TokenStream> {
let num_captures = segments
.iter()
Expand Down Expand Up @@ -177,17 +214,23 @@ fn expand_unnamed_fields(
}
};

let rejection_assoc_type = rejection_assoc_type(&rejection);
let map_err_rejection = map_err_rejection(&rejection);

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
where
B: Send,
{
type Rejection = <::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection;
type Rejection = #rejection_assoc_type;

async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req).await.map(|path| path.0)
::axum::extract::Path::from_request(req)
.await
.map(|path| path.0)
#map_err_rejection
}
}
};
Expand All @@ -207,7 +250,11 @@ fn simple_pluralize(count: usize, word: &str) -> String {
}
}

fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result<TokenStream> {
fn expand_unit_fields(
ident: &syn::Ident,
path: LitStr,
rejection: Option<syn::Path>,
) -> syn::Result<TokenStream> {
for segment in parse_path(&path)? {
match segment {
Segment::Capture(_, span) => {
Expand Down Expand Up @@ -236,20 +283,35 @@ fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result<TokenStre
}
};

let rejection_assoc_type = if let Some(rejection) = &rejection {
quote! { #rejection }
} else {
quote! { ::axum::http::StatusCode }
};
let create_rejection = if let Some(rejection) = &rejection {
quote! {
Err(<#rejection as ::std::default::Default>::default())
}
} else {
quote! {
Err(::axum::http::StatusCode::NOT_FOUND)
}
};

let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
where
B: Send,
{
type Rejection = ::axum::http::StatusCode;
type Rejection = #rejection_assoc_type;

async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
if req.uri().path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
Ok(Self)
} else {
Err(::axum::http::StatusCode::NOT_FOUND)
#create_rejection
}
}
}
Expand Down Expand Up @@ -314,6 +376,33 @@ enum Segment {
Static(String),
}

fn path_rejection() -> TokenStream {
quote! {
<::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection
}
}

fn rejection_assoc_type(rejection: &Option<syn::Path>) -> TokenStream {
match rejection {
Some(rejection) => quote! { #rejection },
None => path_rejection(),
}
}

fn map_err_rejection(rejection: &Option<syn::Path>) -> TokenStream {
rejection
.as_ref()
.map(|rejection| {
let path_rejection = path_rejection();
quote! {
.map_err(|rejection| {
<#rejection as ::std::convert::From<#path_rejection>>::from(rejection)
})
}
})
.unwrap_or_default()
}

#[test]
fn ui() {
#[rustversion::stable]
Expand Down
Loading