Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

axum-extra: WithRejection #1262

Merged
merged 18 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub mod cookie;
#[cfg(feature = "query")]
mod query;

mod with_rejection;

pub use self::cached::Cached;

#[cfg(feature = "cookie")]
Expand All @@ -31,3 +33,5 @@ pub use self::query::Query;
#[cfg(feature = "json-lines")]
#[doc(no_inline)]
pub use crate::json_lines::JsonLines;

pub use self::with_rejection::WithRejection;
138 changes: 138 additions & 0 deletions axum-extra/src/extract/with_rejection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use axum::async_trait;
use axum::extract::{FromRequest, RequestParts};
use axum::response::IntoResponse;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

/// Extractor for customizing extractor rejections
///
/// `WithRejection` wraps another extractor and gives you the result. If the
/// extraction fails, the `Rejection` is transformed into `R` and returned as a
/// response
///
/// `E` is expected to implement [`FromRequest`]
///
/// `R` is expected to implement [`IntoResponse`] and [`From<E::Rejection>`]
///
///
/// # Example
///
/// ```rust
/// use axum::extract::rejection::JsonRejection;
/// use axum::response::{Response, IntoResponse};
/// use axum::Json;
/// use axum_extra::extract::WithRejection;
/// use serde::Deserialize;
///
/// struct MyRejection { /* ... */ }
///
/// impl From<JsonRejection> for MyRejection {
/// fn from(_:JsonRejection) -> MyRejection {
Altair-Bueno marked this conversation as resolved.
Show resolved Hide resolved
/// // ...
/// # todo!()
/// }
/// }
///
/// impl IntoResponse for MyRejection {
/// fn into_response(self) -> Response {
/// // ...
/// # todo!()
/// }
/// }
/// #[derive(Debug, Deserialize)]
/// struct Person { /* ... */ }
///
/// async fn handler(
/// // If the `Json` extractor ever fails, `MyRejection` will be sent to the
/// // client using the `IntoResponse` impl
/// WithRejection(Json(Person), _): WithRejection<Json<Person>, MyRejection>
/// ) { /* ... */ }
/// ```
Altair-Bueno marked this conversation as resolved.
Show resolved Hide resolved
///
/// For a full example see the [customize-extractor-error] example
///
/// [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs
/// [`FromRequest`]: axum::extract::FromRequest
/// [`IntoResponse`]: axum::response::IntoResponse
/// [`From<E::Rejection>`]: std::convert::From
#[derive(Debug, Clone, Copy, Default)]
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
pub struct WithRejection<E, R>(pub E, pub PhantomData<R>);
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved

impl<E, R> WithRejection<E, R> {
/// Returns the wrapped extractor
fn into_inner(self) -> E {
self.0
}
}

impl<E, R> Deref for WithRejection<E, R> {
type Target = E;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<E, R> DerefMut for WithRejection<E, R> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

#[async_trait]
impl<B, E, R> FromRequest<B> for WithRejection<E, R>
where
B: Send,
E: FromRequest<B>,
R: From<E::Rejection> + IntoResponse,
{
type Rejection = R;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let extractor = req.extract::<E>().await?;
Ok(WithRejection(extractor, PhantomData))
}
}

#[cfg(test)]
mod tests {
use axum::http::Request;
use axum::response::Response;

use super::*;

#[tokio::test]
async fn extractor_rejection_is_transformed() {
struct TestExtractor;
struct TestRejection;

#[async_trait]
impl<B: Send> FromRequest<B> for TestExtractor {
type Rejection = ();

async fn from_request(_: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
Err(())
}
}

impl IntoResponse for TestRejection {
fn into_response(self) -> Response {
().into_response()
}
}

impl From<()> for TestRejection {
fn from(_: ()) -> Self {
TestRejection
}
}

let mut req = RequestParts::new(Request::new(()));

let result = req
.extract::<WithRejection<TestExtractor, TestRejection>>()
.await;

assert!(matches!(result, Err(TestRejection)))
}
}
1 change: 1 addition & 0 deletions examples/customize-extractor-error/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ publish = false

[dependencies]
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }
Expand Down
95 changes: 48 additions & 47 deletions examples/customize-extractor-error/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
//! ```

use axum::{
async_trait,
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
extract::{rejection::JsonRejection, FromRequest, RequestParts},
http::StatusCode,
routing::post,
BoxError, Router,
extract::rejection::JsonRejection, http::StatusCode, response::IntoResponse, routing::post,
Json, Router,
};
use serde::{de::DeserializeOwned, Deserialize};
use serde_json::{json, Value};
use std::{borrow::Cow, net::SocketAddr};
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use serde_json::json;
use std::net::SocketAddr;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
Expand All @@ -38,7 +36,12 @@ async fn main() {
.unwrap();
}

async fn handler(Json(user): Json<User>) {
async fn handler(
// `WithRejection` will extract `Json<User>` from the request. If the
// extraction fails, a `MyRejection` will be created from `JsonResponse` and
// returned to the client
WithRejection(Json(user), _): WithRejection<Json<User>, MyRejection>,
) {
dbg!(&user);
}

Expand All @@ -49,46 +52,44 @@ struct User {
username: String,
}

// We define our own `Json` extractor that customizes the error from `axum::Json`
struct Json<T>(T);

#[async_trait]
impl<B, T> FromRequest<B> for Json<T>
where
// these trait bounds are copied from `impl FromRequest for axum::Json`
T: DeserializeOwned,
B: axum::body::HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Rejection = (StatusCode, axum::Json<Value>);
// Define your own custom rejection
#[derive(Debug)]
struct MyRejection {
body: String,
status: StatusCode,
}

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
match axum::Json::<T>::from_request(req).await {
Ok(value) => Ok(Self(value.0)),
Err(rejection) => {
// convert the error from `axum::Json` into whatever we want
let (status, body): (_, Cow<'_, str>) = match rejection {
JsonRejection::JsonDataError(err) => (
StatusCode::BAD_REQUEST,
format!("Invalid JSON request: {}", err).into(),
),
JsonRejection::MissingJsonContentType(err) => {
(StatusCode::BAD_REQUEST, err.to_string().into())
}
err => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unknown internal error: {}", err).into(),
),
};
// `IntoResponse` is required for your custom rejection type
impl IntoResponse for MyRejection {
fn into_response(self) -> axum::response::Response {
let Self { body, status } = self;
(
status,
// we use `axum::Json` here to generate a JSON response
// body but you can use whatever response you want
axum::Json(json!({ "error": body })),
)
.into_response()
}
}

Err((
status,
// we use `axum::Json` here to generate a JSON response
// body but you can use whatever response you want
axum::Json(json!({ "error": body })),
))
// Implement `From` for any Rejection type you want
impl From<JsonRejection> for MyRejection {
fn from(rejection: JsonRejection) -> Self {
// convert the error from `axum::Json` into whatever we want
let (status, body) = match rejection {
JsonRejection::JsonDataError(err) => (
StatusCode::BAD_REQUEST,
format!("Invalid JSON request: {}", err),
),
JsonRejection::MissingJsonContentType(err) => {
(StatusCode::BAD_REQUEST, err.to_string())
}
}
err => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unknown internal error: {}", err),
),
};
Self { body, status }
}
}