diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index 9d8e745644..24d5962193 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,7 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus + they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead + `HeaderMap` will clone the headers. You should prefer to use `TypedHeader` to extract only the + headers you need ([#698]) + + This includes these breaking changes: + - `RequestParts::take_headers` has been removed. + - `RequestParts::headers` returns `&HeaderMap`. + - `RequestParts::headers_mut` returns `&mut HeaderMap`. + - `HeadersAlreadyExtracted` has been removed. + - The `HeadersAlreadyExtracted` variant has been removed from these rejections: + - `RequestAlreadyExtracted` + - `RequestPartsAlreadyExtracted` + - `>::Error` has been changed to `std::convert::Infallible`. + +[#698]: https://github.com/tokio-rs/axum/pull/698 # 0.1.1 (06. December, 2021) diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 6a4d2662f1..5d7271fc10 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -77,7 +77,7 @@ pub struct RequestParts { method: Method, uri: Uri, version: Version, - headers: Option, + headers: HeaderMap, extensions: Option, body: Option, } @@ -107,7 +107,7 @@ impl RequestParts { method, uri, version, - headers: Some(headers), + headers, extensions: Some(extensions), body: Some(body), } @@ -117,14 +117,11 @@ impl RequestParts { /// /// Fails if /// - /// - The full [`HeaderMap`] has been extracted, that is [`take_headers`] - /// have been called. /// - The full [`Extensions`] has been extracted, that is /// [`take_extensions`] have been called. /// - The request body has been extracted, that is [`take_body`] have been /// called. /// - /// [`take_headers`]: RequestParts::take_headers /// [`take_extensions`]: RequestParts::take_extensions /// [`take_body`]: RequestParts::take_body pub fn try_into_request(self) -> Result, RequestAlreadyExtracted> { @@ -132,7 +129,7 @@ impl RequestParts { method, uri, version, - mut headers, + headers, mut extensions, mut body, } = self; @@ -148,14 +145,7 @@ impl RequestParts { *req.method_mut() = method; *req.uri_mut() = uri; *req.version_mut() = version; - - if let Some(headers) = headers.take() { - *req.headers_mut() = headers; - } else { - return Err(RequestAlreadyExtracted::HeadersAlreadyExtracted( - HeadersAlreadyExtracted, - )); - } + *req.headers_mut() = headers; if let Some(extensions) = extensions.take() { *req.extensions_mut() = extensions; @@ -199,22 +189,13 @@ impl RequestParts { } /// Gets a reference to the request headers. - /// - /// Returns `None` if the headers has been taken by another extractor. - pub fn headers(&self) -> Option<&HeaderMap> { - self.headers.as_ref() + pub fn headers(&self) -> &HeaderMap { + &self.headers } /// Gets a mutable reference to the request headers. - /// - /// Returns `None` if the headers has been taken by another extractor. - pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> { - self.headers.as_mut() - } - - /// Takes the headers out of the request, leaving a `None` in its place. - pub fn take_headers(&mut self) -> Option { - self.headers.take() + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers } /// Gets a reference to the request extensions. diff --git a/axum-core/src/extract/rejection.rs b/axum-core/src/extract/rejection.rs index ad4f0eebed..b04aa9f9c5 100644 --- a/axum-core/src/extract/rejection.rs +++ b/axum-core/src/extract/rejection.rs @@ -8,13 +8,6 @@ define_rejection! { pub struct BodyAlreadyExtracted; } -define_rejection! { - #[status = INTERNAL_SERVER_ERROR] - #[body = "Headers taken by other extractor"] - /// Rejection used if the headers has been taken by another extractor. - pub struct HeadersAlreadyExtracted; -} - define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "Extensions taken by other extractor"] @@ -47,7 +40,6 @@ composite_rejection! { /// [`Request<_>`]: http::Request pub enum RequestAlreadyExtracted { BodyAlreadyExtracted, - HeadersAlreadyExtracted, ExtensionsAlreadyExtracted, } } @@ -79,7 +71,6 @@ composite_rejection! { /// /// Contains one variant for each way the [`http::request::Parts`] extractor can fail. pub enum RequestPartsAlreadyExtracted { - HeadersAlreadyExtracted, ExtensionsAlreadyExtracted, } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 8116dc716f..9fc4bac95f 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -19,7 +19,7 @@ where method: req.method.clone(), version: req.version, uri: req.uri.clone(), - headers: None, + headers: HeaderMap::new(), extensions: None, body: None, }, @@ -65,15 +65,20 @@ where } } +/// Clone the headers from the request. +/// +/// Prefer using [`TypedHeader`] to extract only the headers you need. +/// +/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html #[async_trait] impl FromRequest for HeaderMap where B: Send, { - type Rejection = HeadersAlreadyExtracted; + type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - req.take_headers().ok_or(HeadersAlreadyExtracted) + Ok(req.headers().clone()) } } @@ -143,7 +148,10 @@ where let method = unwrap_infallible(Method::from_request(req).await); let uri = unwrap_infallible(Uri::from_request(req).await); let version = unwrap_infallible(Version::from_request(req).await); - let headers = HeaderMap::from_request(req).await?; + let headers = match HeaderMap::from_request(req).await { + Ok(headers) => headers, + Err(err) => match err {}, + }; let extensions = Extensions::from_request(req).await?; let mut temp_request = Request::new(()); diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1fbd8876f0..4c29e8b14c 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -13,9 +13,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 overwriting old values. - **breaking:** Require `Output = ()` on `WebSocketStream::on_upgrade` ([#644]) - **breaking:** Make `TypedHeaderRejectionReason` `#[non_exhaustive]` ([#665]) +- **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus + they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead + `HeaderMap` will clone the headers. You should prefer to use `TypedHeader` to extract only the + headers you need ([#698]) + + This includes these breaking changes: + - `RequestParts::take_headers` has been removed. + - `RequestParts::headers` returns `&HeaderMap`. + - `RequestParts::headers_mut` returns `&mut HeaderMap`. + - `HeadersAlreadyExtracted` has been removed. + - The `HeadersAlreadyExtracted` removed variant has been removed from these rejections: + - `RequestAlreadyExtracted` + - `RequestPartsAlreadyExtracted` + - `JsonRejection` + - `FormRejection` + - `ContentLengthLimitRejection` + - `WebSocketUpgradeRejection` + - `>::Error` has been changed to `std::convert::Infallible`. [#644]: https://github.com/tokio-rs/axum/pull/644 [#665]: https://github.com/tokio-rs/axum/pull/665 +[#698]: https://github.com/tokio-rs/axum/pull/698 # 0.4.3 (21. December, 2021) diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index feef409f27..3045397912 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -320,10 +320,6 @@ async fn handler(result: Result, JsonRejection>) -> impl IntoRespons StatusCode::INTERNAL_SERVER_ERROR, "Failed to buffer request body".to_string(), )), - JsonRejection::HeadersAlreadyExtracted(_) => Err(( - StatusCode::INTERNAL_SERVER_ERROR, - "Headers already extracted".to_string(), - )), // we must provide a catch-all case since `JsonRejection` is marked // `#[non_exhaustive]` _ => Err(( @@ -377,9 +373,7 @@ where type Rejection = (StatusCode, &'static str); async fn from_request(req: &mut RequestParts) -> Result { - let user_agent = req.headers().and_then(|headers| headers.get(USER_AGENT)); - - if let Some(user_agent) = user_agent { + if let Some(user_agent) = req.headers().get(USER_AGENT) { Ok(ExtractUserAgent(user_agent.clone())) } else { Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing")) diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index c1f2143c94..7a41d2543f 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -39,14 +39,7 @@ where type Rejection = ContentLengthLimitRejection; async fn from_request(req: &mut RequestParts) -> Result { - let content_length = req - .headers() - .ok_or_else(|| { - ContentLengthLimitRejection::HeadersAlreadyExtracted( - HeadersAlreadyExtracted::default(), - ) - })? - .get(http::header::CONTENT_LENGTH); + let content_length = req.headers().get(http::header::CONTENT_LENGTH); let content_length = content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); diff --git a/axum/src/extract/extractor_middleware.rs b/axum/src/extract/extractor_middleware.rs index bede119761..aca7b6f397 100644 --- a/axum/src/extract/extractor_middleware.rs +++ b/axum/src/extract/extractor_middleware.rs @@ -59,7 +59,7 @@ use tower_service::Service; /// async fn from_request(req: &mut RequestParts) -> Result { /// let auth_header = req /// .headers() -/// .and_then(|headers| headers.get(http::header::AUTHORIZATION)) +/// .get(http::header::AUTHORIZATION) /// .and_then(|value| value.to_str().ok()); /// /// match auth_header { @@ -291,7 +291,6 @@ mod tests { async fn from_request(req: &mut RequestParts) -> Result { if let Some(auth) = req .headers() - .expect("headers already extracted") .get("authorization") .and_then(|v| v.to_str().ok()) { diff --git a/axum/src/extract/form.rs b/axum/src/extract/form.rs index 84d517df80..e53d722ad2 100644 --- a/axum/src/extract/form.rs +++ b/axum/src/extract/form.rs @@ -60,7 +60,7 @@ where .map_err(FailedToDeserializeQueryString::new::)?; Ok(Form(value)) } else { - if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED)? { + if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) { return Err(InvalidFormContentType.into()); } diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 68ba94e443..8e4c8e99c6 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -78,24 +78,20 @@ pub use self::typed_header::TypedHeader; pub(crate) fn has_content_type( req: &RequestParts, expected_content_type: &mime::Mime, -) -> Result { - let content_type = if let Some(content_type) = req - .headers() - .ok_or_else(HeadersAlreadyExtracted::default)? - .get(header::CONTENT_TYPE) - { +) -> bool { + let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { - return Ok(false); + return false; }; let content_type = if let Ok(content_type) = content_type.to_str() { content_type } else { - return Ok(false); + return false; }; - Ok(content_type.starts_with(expected_content_type.as_ref())) + content_type.starts_with(expected_content_type.as_ref()) } pub(crate) fn take_body(req: &mut RequestParts) -> Result { diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index d9bad33ed1..26138149ca 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -58,7 +58,7 @@ where async fn from_request(req: &mut RequestParts) -> Result { let stream = BodyStream::from_request(req).await?; - let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?; + let headers = req.headers(); let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; let multipart = multer::Multipart::new(stream, boundary); Ok(Self { inner: multipart }) @@ -179,7 +179,6 @@ composite_rejection! { pub enum MultipartRejection { BodyAlreadyExtracted, InvalidBoundary, - HeadersAlreadyExtracted, } } diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index 2777a64fca..627244e342 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -124,7 +124,6 @@ composite_rejection! { InvalidFormContentType, FailedToDeserializeQueryString, BytesRejection, - HeadersAlreadyExtracted, } } @@ -139,7 +138,6 @@ composite_rejection! { InvalidJsonBody, MissingJsonContentType, BytesRejection, - HeadersAlreadyExtracted, } } @@ -195,8 +193,6 @@ pub enum ContentLengthLimitRejection { #[allow(missing_docs)] LengthRequired(LengthRequired), #[allow(missing_docs)] - HeadersAlreadyExtracted(HeadersAlreadyExtracted), - #[allow(missing_docs)] Inner(T), } @@ -208,7 +204,6 @@ where match self { Self::PayloadTooLarge(inner) => inner.into_response(), Self::LengthRequired(inner) => inner.into_response(), - Self::HeadersAlreadyExtracted(inner) => inner.into_response(), Self::Inner(inner) => inner.into_response(), } } @@ -222,7 +217,6 @@ where match self { Self::PayloadTooLarge(inner) => inner.fmt(f), Self::LengthRequired(inner) => inner.fmt(f), - Self::HeadersAlreadyExtracted(inner) => inner.fmt(f), Self::Inner(inner) => inner.fmt(f), } } @@ -236,7 +230,6 @@ where match self { Self::PayloadTooLarge(inner) => Some(inner), Self::LengthRequired(inner) => Some(inner), - Self::HeadersAlreadyExtracted(inner) => Some(inner), Self::Inner(inner) => Some(inner), } } diff --git a/axum/src/extract/typed_header.rs b/axum/src/extract/typed_header.rs index e5a6f99f8e..7bb52c5324 100644 --- a/axum/src/extract/typed_header.rs +++ b/axum/src/extract/typed_header.rs @@ -44,16 +44,7 @@ where type Rejection = TypedHeaderRejection; async fn from_request(req: &mut RequestParts) -> Result { - let headers = if let Some(headers) = req.headers() { - headers - } else { - return Err(TypedHeaderRejection { - name: T::name(), - reason: TypedHeaderRejectionReason::Missing, - }); - }; - - match headers.typed_try_get::() { + match req.headers().typed_try_get::() { Ok(Some(value)) => Ok(Self(value)), Ok(None) => Err(TypedHeaderRejection { name: T::name(), diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 235fe6d65e..e07af70aa7 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -249,27 +249,24 @@ where return Err(MethodNotGet.into()); } - if !header_contains(req, header::CONNECTION, "upgrade")? { + if !header_contains(req, header::CONNECTION, "upgrade") { return Err(InvalidConnectionHeader.into()); } - if !header_eq(req, header::UPGRADE, "websocket")? { + if !header_eq(req, header::UPGRADE, "websocket") { return Err(InvalidUpgradeHeader.into()); } - if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13")? { + if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13") { return Err(InvalidWebSocketVersionHeader.into()); } - let sec_websocket_key = if let Some(key) = req - .headers_mut() - .ok_or_else(HeadersAlreadyExtracted::default)? - .remove(header::SEC_WEBSOCKET_KEY) - { - key - } else { - return Err(WebSocketKeyHeaderMissing.into()); - }; + let sec_websocket_key = + if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) { + key + } else { + return Err(WebSocketKeyHeaderMissing.into()); + }; let on_upgrade = req .extensions_mut() @@ -277,11 +274,7 @@ where .remove::() .unwrap(); - let sec_websocket_protocol = req - .headers() - .ok_or_else(HeadersAlreadyExtracted::default)? - .get(header::SEC_WEBSOCKET_PROTOCOL) - .cloned(); + let sec_websocket_protocol = req.headers().get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); Ok(Self { config: Default::default(), @@ -293,41 +286,25 @@ where } } -fn header_eq( - req: &RequestParts, - key: HeaderName, - value: &'static str, -) -> Result { - if let Some(header) = req - .headers() - .ok_or_else(HeadersAlreadyExtracted::default)? - .get(&key) - { - Ok(header.as_bytes().eq_ignore_ascii_case(value.as_bytes())) +fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { + if let Some(header) = req.headers().get(&key) { + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } else { - Ok(false) + false } } -fn header_contains( - req: &RequestParts, - key: HeaderName, - value: &'static str, -) -> Result { - let header = if let Some(header) = req - .headers() - .ok_or_else(HeadersAlreadyExtracted::default)? - .get(&key) - { +fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { + let header = if let Some(header) = req.headers().get(&key) { header } else { - return Ok(false); + return false; }; if let Ok(header) = std::str::from_utf8(header.as_bytes()) { - Ok(header.to_ascii_lowercase().contains(value)) + header.to_ascii_lowercase().contains(value) } else { - Ok(false) + false } } @@ -585,7 +562,6 @@ pub mod rejection { InvalidUpgradeHeader, InvalidWebSocketVersionHeader, WebSocketKeyHeaderMissing, - HeadersAlreadyExtracted, ExtensionsAlreadyExtracted, } } diff --git a/axum/src/json.rs b/axum/src/json.rs index 103e2997ba..3b347250c6 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -96,7 +96,7 @@ where type Rejection = JsonRejection; async fn from_request(req: &mut RequestParts) -> Result { - if json_content_type(req)? { + if json_content_type(req) { let bytes = Bytes::from_request(req).await?; let value = serde_json::from_slice(&bytes).map_err(InvalidJsonBody::from_err)?; @@ -108,33 +108,29 @@ where } } -fn json_content_type(req: &RequestParts) -> Result { - let content_type = if let Some(content_type) = req - .headers() - .ok_or_else(HeadersAlreadyExtracted::default)? - .get(header::CONTENT_TYPE) - { +fn json_content_type(req: &RequestParts) -> bool { + let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { - return Ok(false); + return false; }; let content_type = if let Ok(content_type) = content_type.to_str() { content_type } else { - return Ok(false); + return false; }; let mime = if let Ok(mime) = content_type.parse::() { mime } else { - return Ok(false); + return false; }; let is_json_content_type = mime.type_() == "application" && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); - Ok(is_json_content_type) + is_json_content_type } impl Deref for Json { diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index 4e951001d7..2afba46ffa 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -73,9 +73,6 @@ where JsonRejection::MissingJsonContentType(err) => { (StatusCode::BAD_REQUEST, err.to_string().into()) } - JsonRejection::HeadersAlreadyExtracted(err) => { - (StatusCode::INTERNAL_SERVER_ERROR, err.to_string().into()) - } err => ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unknown internal error: {}", err).into(),