Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change HeaderMap extractor to clone the headers #698

Merged
merged 9 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 8 additions & 27 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub struct RequestParts<B> {
method: Method,
uri: Uri,
version: Version,
headers: Option<HeaderMap>,
headers: HeaderMap,
extensions: Option<Extensions>,
body: Option<B>,
}
Expand Down Expand Up @@ -107,7 +107,7 @@ impl<B> RequestParts<B> {
method,
uri,
version,
headers: Some(headers),
headers,
extensions: Some(extensions),
body: Some(body),
}
Expand All @@ -117,22 +117,19 @@ impl<B> RequestParts<B> {
///
/// Fails if
///
/// - The full [`HeaderMap`] has been extracted, that is [`take_headers`]
/// have been called.
/// - The full [`Extensions`] has been extracted, that is
/// [`take_extensions`] have been called.
/// - The request body has been extracted, that is [`take_body`] have been
/// called.
///
/// [`take_headers`]: RequestParts::take_headers
/// [`take_extensions`]: RequestParts::take_extensions
/// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, RequestAlreadyExtracted> {
let Self {
method,
uri,
version,
mut headers,
headers,
mut extensions,
mut body,
} = self;
Expand All @@ -148,14 +145,7 @@ impl<B> RequestParts<B> {
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;

if let Some(headers) = headers.take() {
*req.headers_mut() = headers;
} else {
return Err(RequestAlreadyExtracted::HeadersAlreadyExtracted(
HeadersAlreadyExtracted,
));
}
*req.headers_mut() = headers;

if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
Expand Down Expand Up @@ -199,22 +189,13 @@ impl<B> RequestParts<B> {
}

/// Gets a reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers(&self) -> Option<&HeaderMap> {
self.headers.as_ref()
pub fn headers(&self) -> &HeaderMap {
&self.headers
}

/// Gets a mutable reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> {
self.headers.as_mut()
}

/// Takes the headers out of the request, leaving a `None` in its place.
pub fn take_headers(&mut self) -> Option<HeaderMap> {
self.headers.take()
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}

/// Gets a reference to the request extensions.
Expand Down
9 changes: 0 additions & 9 deletions axum-core/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ define_rejection! {
pub struct BodyAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Headers taken by other extractor"]
/// Rejection used if the headers has been taken by another extractor.
pub struct HeadersAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
Expand Down Expand Up @@ -47,7 +40,6 @@ composite_rejection! {
/// [`Request<_>`]: http::Request
pub enum RequestAlreadyExtracted {
BodyAlreadyExtracted,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
Expand Down Expand Up @@ -79,7 +71,6 @@ composite_rejection! {
///
/// Contains one variant for each way the [`http::request::Parts`] extractor can fail.
pub enum RequestPartsAlreadyExtracted {
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
11 changes: 7 additions & 4 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: None,
headers: HeaderMap::new(),
extensions: None,
body: None,
},
Expand Down Expand Up @@ -70,10 +70,10 @@ impl<B> FromRequest<B> for HeaderMap
where
B: Send,
{
type Rejection = HeadersAlreadyExtracted;
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_headers().ok_or(HeadersAlreadyExtracted)
Ok(req.headers().clone())
}
}

Expand Down Expand Up @@ -143,7 +143,10 @@ where
let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
let headers = HeaderMap::from_request(req).await?;
let headers = match HeaderMap::from_request(req).await {
Ok(headers) => headers,
Err(err) => match err {},
};
let extensions = Extensions::from_request(req).await?;

let mut temp_request = Request::new(());
Expand Down
6 changes: 1 addition & 5 deletions axum/src/docs/extract.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,6 @@ async fn handler(result: Result<Json<Value>, JsonRejection>) -> impl IntoRespons
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to buffer request body".to_string(),
)),
JsonRejection::HeadersAlreadyExtracted(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Headers already extracted".to_string(),
)),
// we must provide a catch-all case since `JsonRejection` is marked
// `#[non_exhaustive]`
_ => Err((
Expand Down Expand Up @@ -377,7 +373,7 @@ where
type Rejection = (StatusCode, &'static str);

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let user_agent = req.headers().and_then(|headers| headers.get(USER_AGENT));
let user_agent = req.headers().get(USER_AGENT);
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved

if let Some(user_agent) = user_agent {
Ok(ExtractUserAgent(user_agent.clone()))
Expand Down
9 changes: 1 addition & 8 deletions axum/src/extract/content_length_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,7 @@ where
type Rejection = ContentLengthLimitRejection<T::Rejection>;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let content_length = req
.headers()
.ok_or_else(|| {
ContentLengthLimitRejection::HeadersAlreadyExtracted(
HeadersAlreadyExtracted::default(),
)
})?
.get(http::header::CONTENT_LENGTH);
let content_length = req.headers().get(http::header::CONTENT_LENGTH);

let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
Expand Down
3 changes: 1 addition & 2 deletions axum/src/extract/extractor_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use tower_service::Service;
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req
/// .headers()
/// .and_then(|headers| headers.get(http::header::AUTHORIZATION))
/// .get(http::header::AUTHORIZATION)
/// .and_then(|value| value.to_str().ok());
///
/// match auth_header {
Expand Down Expand Up @@ -291,7 +291,6 @@ mod tests {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req
.headers()
.expect("headers already extracted")
.get("authorization")
.and_then(|v| v.to_str().ok())
{
Expand Down
2 changes: 1 addition & 1 deletion axum/src/extract/form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ where
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Form(value))
} else {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED)? {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType.into());
}

Expand Down
14 changes: 5 additions & 9 deletions axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,20 @@ pub use self::typed_header::TypedHeader;
pub(crate) fn has_content_type<B>(
req: &RequestParts<B>,
expected_content_type: &mime::Mime,
) -> Result<bool, HeadersAlreadyExtracted> {
let content_type = if let Some(content_type) = req
.headers()
.ok_or_else(HeadersAlreadyExtracted::default)?
.get(header::CONTENT_TYPE)
{
) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {
return Ok(false);
return false;
};

let content_type = if let Ok(content_type) = content_type.to_str() {
content_type
} else {
return Ok(false);
return false;
};

Ok(content_type.starts_with(expected_content_type.as_ref()))
content_type.starts_with(expected_content_type.as_ref())
}

pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> {
Expand Down
3 changes: 1 addition & 2 deletions axum/src/extract/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ where

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?;
let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?;
let headers = req.headers();
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
Expand Down Expand Up @@ -179,7 +179,6 @@ composite_rejection! {
pub enum MultipartRejection {
BodyAlreadyExtracted,
InvalidBoundary,
HeadersAlreadyExtracted,
}
}

Expand Down
7 changes: 0 additions & 7 deletions axum/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ composite_rejection! {
InvalidFormContentType,
FailedToDeserializeQueryString,
BytesRejection,
HeadersAlreadyExtracted,
}
}

Expand All @@ -139,7 +138,6 @@ composite_rejection! {
InvalidJsonBody,
MissingJsonContentType,
BytesRejection,
HeadersAlreadyExtracted,
}
}

Expand Down Expand Up @@ -195,8 +193,6 @@ pub enum ContentLengthLimitRejection<T> {
#[allow(missing_docs)]
LengthRequired(LengthRequired),
#[allow(missing_docs)]
HeadersAlreadyExtracted(HeadersAlreadyExtracted),
#[allow(missing_docs)]
Inner(T),
}

Expand All @@ -208,7 +204,6 @@ where
match self {
Self::PayloadTooLarge(inner) => inner.into_response(),
Self::LengthRequired(inner) => inner.into_response(),
Self::HeadersAlreadyExtracted(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(),
}
}
Expand All @@ -222,7 +217,6 @@ where
match self {
Self::PayloadTooLarge(inner) => inner.fmt(f),
Self::LengthRequired(inner) => inner.fmt(f),
Self::HeadersAlreadyExtracted(inner) => inner.fmt(f),
Self::Inner(inner) => inner.fmt(f),
}
}
Expand All @@ -236,7 +230,6 @@ where
match self {
Self::PayloadTooLarge(inner) => Some(inner),
Self::LengthRequired(inner) => Some(inner),
Self::HeadersAlreadyExtracted(inner) => Some(inner),
Self::Inner(inner) => Some(inner),
}
}
Expand Down
11 changes: 1 addition & 10 deletions axum/src/extract/typed_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,7 @@ where
type Rejection = TypedHeaderRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let headers = if let Some(headers) = req.headers() {
headers
} else {
return Err(TypedHeaderRejection {
name: T::name(),
reason: TypedHeaderRejectionReason::Missing,
});
};

match headers.typed_try_get::<T>() {
match req.headers().typed_try_get::<T>() {
Ok(Some(value)) => Ok(Self(value)),
Ok(None) => Err(TypedHeaderRejection {
name: T::name(),
Expand Down
Loading