Skip to content

Commit

Permalink
Implement IntoResponse for MultipartError (#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored Mar 21, 2023
1 parent 8e1eb89 commit 03e8bc7
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 5 deletions.
2 changes: 1 addition & 1 deletion axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning].

# Unreleased

- None.
- **added:** Implement `IntoResponse` for `MultipartError` ([#1861])

# 0.7.1 (13. March, 2023)

Expand Down
1 change: 1 addition & 0 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ axum = { path = "../axum", version = "0.6.9", default-features = false }
bytes = "1.1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2"
http-body = "0.4.4"
mime = "0.3"
pin-project-lite = "0.2"
tokio = "1.19"
Expand Down
82 changes: 80 additions & 2 deletions axum-extra/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use axum::{
use futures_util::stream::Stream;
use http::{
header::{HeaderMap, CONTENT_TYPE},
Request,
Request, StatusCode,
};
use std::{
error::Error,
fmt,
pin::Pin,
task::{Context, Poll},
Expand Down Expand Up @@ -246,6 +247,57 @@ impl MultipartError {
fn from_multer(multer: multer::Error) -> Self {
Self { source: multer }
}

/// Get the response body text used for this rejection.
pub fn body_text(&self) -> String {
self.source.to_string()
}

/// Get the status code used for this rejection.
pub fn status(&self) -> http::StatusCode {
status_code_from_multer_error(&self.source)
}
}

fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
match err {
multer::Error::UnknownField { .. }
| multer::Error::IncompleteFieldData { .. }
| multer::Error::IncompleteHeaders
| multer::Error::ReadHeaderFailed(..)
| multer::Error::DecodeHeaderName { .. }
| multer::Error::DecodeContentType(..)
| multer::Error::NoBoundary
| multer::Error::DecodeHeaderValue { .. }
| multer::Error::NoMultipart
| multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
StatusCode::PAYLOAD_TOO_LARGE
}
multer::Error::StreamReadFailed(err) => {
if let Some(err) = err.downcast_ref::<multer::Error>() {
return status_code_from_multer_error(err);
}

if err
.downcast_ref::<axum::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;
}

StatusCode::INTERNAL_SERVER_ERROR
}
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}

impl IntoResponse for MultipartError {
fn into_response(self) -> Response {
(self.status(), self.body_text()).into_response()
}
}

impl fmt::Display for MultipartError {
Expand Down Expand Up @@ -357,7 +409,9 @@ impl std::error::Error for InvalidBoundary {}
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{body::Body, response::IntoResponse, routing::post, Router};
use axum::{
body::Body, extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router,
};

#[tokio::test]
async fn content_type_with_encoding() {
Expand Down Expand Up @@ -395,4 +449,28 @@ mod tests {
async fn handler(_: Multipart) {}
let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
}

#[tokio::test]
async fn body_too_large() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Ok(())
}

let app = Router::new()
.route("/", post(handle))
.layer(DefaultBodyLimit::max(BYTES.len() - 1));

let client = TestClient::new(app);

let form =
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

let res = client.post("/").multipart(form).send().await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}
4 changes: 3 additions & 1 deletion axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **added:** Implement `IntoResponse` for `MultipartError` ([#1861])

[#1861]: https://github.com/tokio-rs/axum/pull/1861

# 0.6.11 (13. March, 2023)

Expand Down
81 changes: 80 additions & 1 deletion axum/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use super::{BodyStream, FromRequest};
use crate::body::{Bytes, HttpBody};
use crate::BoxError;
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use axum_core::RequestExt;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use http::Request;
use http::{Request, StatusCode};
use std::error::Error;
use std::{
fmt,
pin::Pin,
Expand Down Expand Up @@ -209,6 +211,51 @@ impl MultipartError {
fn from_multer(multer: multer::Error) -> Self {
Self { source: multer }
}

/// Get the response body text used for this rejection.
pub fn body_text(&self) -> String {
self.source.to_string()
}

/// Get the status code used for this rejection.
pub fn status(&self) -> http::StatusCode {
status_code_from_multer_error(&self.source)
}
}

fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
match err {
multer::Error::UnknownField { .. }
| multer::Error::IncompleteFieldData { .. }
| multer::Error::IncompleteHeaders
| multer::Error::ReadHeaderFailed(..)
| multer::Error::DecodeHeaderName { .. }
| multer::Error::DecodeContentType(..)
| multer::Error::NoBoundary
| multer::Error::DecodeHeaderValue { .. }
| multer::Error::NoMultipart
| multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
StatusCode::PAYLOAD_TOO_LARGE
}
multer::Error::StreamReadFailed(err) => {
if let Some(err) = err.downcast_ref::<multer::Error>() {
return status_code_from_multer_error(err);
}

if err
.downcast_ref::<crate::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;
}

StatusCode::INTERNAL_SERVER_ERROR
}
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}

impl fmt::Display for MultipartError {
Expand All @@ -223,6 +270,12 @@ impl std::error::Error for MultipartError {
}
}

impl IntoResponse for MultipartError {
fn into_response(self) -> Response {
(self.status(), self.body_text()).into_response()
}
}

fn parse_boundary(headers: &HeaderMap) -> Option<String> {
let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
multer::parse_boundary(content_type).ok()
Expand All @@ -247,6 +300,8 @@ define_rejection! {

#[cfg(test)]
mod tests {
use axum_core::extract::DefaultBodyLimit;

use super::*;
use crate::{body::Body, response::IntoResponse, routing::post, test_helpers::*, Router};

Expand Down Expand Up @@ -286,4 +341,28 @@ mod tests {
async fn handler(_: Multipart) {}
let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
}

#[crate::test]
async fn body_too_large() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Ok(())
}

let app = Router::new()
.route("/", post(handle))
.layer(DefaultBodyLimit::max(BYTES.len() - 1));

let client = TestClient::new(app);

let form =
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

let res = client.post("/").multipart(form).send().await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}

0 comments on commit 03e8bc7

Please sign in to comment.