diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 6614a7f767..4d8512d0e9 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -9,9 +9,12 @@ and this project adheres to [Semantic Versioning]. - **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001]) - **fixed:** Support wildcards in typed paths ([#1003]) +- **added:** Support using a custom rejection type for `#[derive(TypedPath)]` + instead of `PathRejection` ([#1012]) [#1001]: https://github.com/tokio-rs/axum/pull/1001 [#1003]: https://github.com/tokio-rs/axum/pull/1003 +[#1012]: https://github.com/tokio-rs/axum/pull/1012 # 0.3.0 (27. April, 2022) diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index 745ed950c5..683d68431c 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -140,12 +140,77 @@ use http::Uri; /// ); /// ``` /// +/// ## Customizing the rejection +/// +/// By default the rejection used in the [`FromRequest`] implemetation will be [`PathRejection`]. +/// +/// That can be customized using `#[typed_path("...", rejection(YourType))]`: +/// +/// ``` +/// use serde::Deserialize; +/// use axum_extra::routing::TypedPath; +/// use axum::{ +/// response::{IntoResponse, Response}, +/// extract::rejection::PathRejection, +/// }; +/// +/// #[derive(TypedPath, Deserialize)] +/// #[typed_path("/users/:id", rejection(UsersMemberRejection))] +/// struct UsersMember { +/// id: String, +/// } +/// +/// struct UsersMemberRejection; +/// +/// // Your rejection type must implement `From`. +/// // +/// // Here you can grab whatever details from the inner rejection +/// // that you need. +/// impl From for UsersMemberRejection { +/// fn from(rejection: PathRejection) -> Self { +/// # UsersMemberRejection +/// // ... +/// } +/// } +/// +/// // Your rejection must implement `IntoResponse`, like all rejections. +/// impl IntoResponse for UsersMemberRejection { +/// fn into_response(self) -> Response { +/// # ().into_response() +/// // ... +/// } +/// } +/// ``` +/// +/// The `From` requirement only applies if your typed path is a struct with named +/// fields or a tuple struct. For unit structs your rejection type must implement `Default`: +/// +/// ``` +/// use axum_extra::routing::TypedPath; +/// use axum::response::{IntoResponse, Response}; +/// +/// #[derive(TypedPath)] +/// #[typed_path("/users", rejection(UsersCollectionRejection))] +/// struct UsersCollection; +/// +/// #[derive(Default)] +/// struct UsersCollectionRejection; +/// +/// impl IntoResponse for UsersCollectionRejection { +/// fn into_response(self) -> Response { +/// # ().into_response() +/// // ... +/// } +/// } +/// ``` +/// /// [`FromRequest`]: axum::extract::FromRequest /// [`RouterExt::typed_get`]: super::RouterExt::typed_get /// [`RouterExt::typed_post`]: super::RouterExt::typed_post /// [`Path`]: axum::extract::Path /// [`Display`]: std::fmt::Display /// [`Deserialize`]: serde::Deserialize +/// [`PathRejection`]: axum::extract::rejection::PathRejection pub trait TypedPath: std::fmt::Display { /// The path with optional captures such as `/users/:id`. const PATH: &'static str; diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index a983d92b8a..e952982816 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -10,10 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001]) - **fixed:** Support wildcards in typed paths ([#1003]) - **added:** Support `#[derive(FromRequest)]` on enums using `#[from_request(via(OtherExtractor))]` ([#1009]) +- **added:** Support using a custom rejection type for `#[derive(TypedPath)]` + instead of `PathRejection` ([#1012]) [#1001]: https://github.com/tokio-rs/axum/pull/1001 [#1003]: https://github.com/tokio-rs/axum/pull/1003 [#1009]: https://github.com/tokio-rs/axum/pull/1009 +[#1012]: https://github.com/tokio-rs/axum/pull/1012 # 0.2.0 (31. March, 2022) diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 7d91ac6880..3caeb0bd2c 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -1,6 +1,6 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; -use syn::{ItemStruct, LitStr}; +use syn::{parse::Parse, ItemStruct, LitStr, Token}; pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { let ItemStruct { @@ -18,52 +18,79 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { )); } - let Attrs { path } = parse_attrs(attrs)?; + let Attrs { path, rejection } = parse_attrs(attrs)?; match fields { syn::Fields::Named(_) => { let segments = parse_path(&path)?; - Ok(expand_named_fields(ident, path, &segments)) + Ok(expand_named_fields(ident, path, &segments, rejection)) } syn::Fields::Unnamed(fields) => { let segments = parse_path(&path)?; - expand_unnamed_fields(fields, ident, path, &segments) + expand_unnamed_fields(fields, ident, path, &segments, rejection) } - syn::Fields::Unit => expand_unit_fields(ident, path), + syn::Fields::Unit => expand_unit_fields(ident, path, rejection), } } +mod kw { + syn::custom_keyword!(rejection); +} + struct Attrs { path: LitStr, + rejection: Option, +} + +impl Parse for Attrs { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let path = input.parse()?; + + let rejection = if input.is_empty() { + None + } else { + let _: Token![,] = input.parse()?; + let _: kw::rejection = input.parse()?; + + let content; + syn::parenthesized!(content in input); + Some(content.parse()?) + }; + + Ok(Self { path, rejection }) + } } fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result { - let mut path = None; + let mut out = None; for attr in attrs { if attr.path.is_ident("typed_path") { - if path.is_some() { + if out.is_some() { return Err(syn::Error::new_spanned( attr, "`typed_path` specified more than once", )); } else { - path = Some(attr.parse_args()?); + out = Some(attr.parse_args()?); } } } - Ok(Attrs { - path: path.ok_or_else(|| { - syn::Error::new( - Span::call_site(), - "missing `#[typed_path(\"...\")]` attribute", - ) - })?, + out.ok_or_else(|| { + syn::Error::new( + Span::call_site(), + "missing `#[typed_path(\"...\")]` attribute", + ) }) } -fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) -> TokenStream { +fn expand_named_fields( + ident: &syn::Ident, + path: LitStr, + segments: &[Segment], + rejection: Option, +) -> TokenStream { let format_str = format_str_from_path(segments); let captures = captures_from_path(segments); @@ -88,6 +115,9 @@ fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) - } }; + let rejection_assoc_type = rejection_assoc_type(&rejection); + let map_err_rejection = map_err_rejection(&rejection); + let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] @@ -95,10 +125,13 @@ fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) - where B: Send, { - type Rejection = <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection; + type Rejection = #rejection_assoc_type; async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { - ::axum::extract::Path::from_request(req).await.map(|path| path.0) + ::axum::extract::Path::from_request(req) + .await + .map(|path| path.0) + #map_err_rejection } } }; @@ -115,6 +148,7 @@ fn expand_unnamed_fields( ident: &syn::Ident, path: LitStr, segments: &[Segment], + rejection: Option, ) -> syn::Result { let num_captures = segments .iter() @@ -177,6 +211,9 @@ fn expand_unnamed_fields( } }; + let rejection_assoc_type = rejection_assoc_type(&rejection); + let map_err_rejection = map_err_rejection(&rejection); + let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] @@ -184,10 +221,13 @@ fn expand_unnamed_fields( where B: Send, { - type Rejection = <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection; + type Rejection = #rejection_assoc_type; async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { - ::axum::extract::Path::from_request(req).await.map(|path| path.0) + ::axum::extract::Path::from_request(req) + .await + .map(|path| path.0) + #map_err_rejection } } }; @@ -207,7 +247,11 @@ fn simple_pluralize(count: usize, word: &str) -> String { } } -fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result { +fn expand_unit_fields( + ident: &syn::Ident, + path: LitStr, + rejection: Option, +) -> syn::Result { for segment in parse_path(&path)? { match segment { Segment::Capture(_, span) => { @@ -236,6 +280,21 @@ fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result::default()) + } + } else { + quote! { + Err(::axum::http::StatusCode::NOT_FOUND) + } + }; + let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] @@ -243,13 +302,13 @@ fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result) -> ::std::result::Result { if req.uri().path() == ::PATH { Ok(Self) } else { - Err(::axum::http::StatusCode::NOT_FOUND) + #create_rejection } } } @@ -314,6 +373,33 @@ enum Segment { Static(String), } +fn path_rejection() -> TokenStream { + quote! { + <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection + } +} + +fn rejection_assoc_type(rejection: &Option) -> TokenStream { + match rejection { + Some(rejection) => quote! { #rejection }, + None => path_rejection(), + } +} + +fn map_err_rejection(rejection: &Option) -> TokenStream { + rejection + .as_ref() + .map(|rejection| { + let path_rejection = path_rejection(); + quote! { + .map_err(|rejection| { + <#rejection as ::std::convert::From<#path_rejection>>::from(rejection) + }) + } + }) + .unwrap_or_default() +} + #[test] fn ui() { #[rustversion::stable] diff --git a/axum-macros/tests/typed_path/pass/customize_rejection.rs b/axum-macros/tests/typed_path/pass/customize_rejection.rs new file mode 100644 index 0000000000..41aa7e614e --- /dev/null +++ b/axum-macros/tests/typed_path/pass/customize_rejection.rs @@ -0,0 +1,47 @@ +use axum::{ + extract::rejection::PathRejection, + response::{IntoResponse, Response}, +}; +use axum_extra::routing::{RouterExt, TypedPath}; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/:foo", rejection(MyRejection))] +struct MyPathNamed { + foo: String, +} + +#[derive(TypedPath, Deserialize)] +#[typed_path("/", rejection(MyRejection))] +struct MyPathUnit; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/:foo", rejection(MyRejection))] +struct MyPathUnnamed(String); + +struct MyRejection; + +impl IntoResponse for MyRejection { + fn into_response(self) -> Response { + ().into_response() + } +} + +impl From for MyRejection { + fn from(_: PathRejection) -> Self { + Self + } +} + +impl Default for MyRejection { + fn default() -> Self { + Self + } +} + +fn main() { + axum::Router::::new() + .typed_get(|_: Result| async {}) + .typed_post(|_: Result| async {}) + .typed_put(|_: Result| async {}); +}