Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove RequestParts::take_extensions #699

Merged
merged 8 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 12 additions & 35 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub struct RequestParts<B> {
uri: Uri,
version: Version,
headers: HeaderMap,
extensions: Option<Extensions>,
extensions: Extensions,
body: Option<B>,
}

Expand Down Expand Up @@ -108,52 +108,38 @@ impl<B> RequestParts<B> {
uri,
version,
headers,
extensions: Some(extensions),
extensions,
body: Some(body),
}
}

/// Convert this `RequestParts` back into a [`Request`].
///
/// Fails if
/// Fails if The request body has been extracted, that is [`take_body`] have
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
/// been called.
///
/// - The full [`Extensions`] has been extracted, that is
/// [`take_extensions`] have been called.
/// - The request body has been extracted, that is [`take_body`] have been
/// called.
///
/// [`take_extensions`]: RequestParts::take_extensions
/// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, RequestAlreadyExtracted> {
pub fn try_into_request(self) -> Result<Request<B>, BodyAlreadyExtracted> {
let Self {
method,
uri,
version,
headers,
mut extensions,
extensions,
mut body,
} = self;

let mut req = if let Some(body) = body.take() {
Request::new(body)
} else {
return Err(RequestAlreadyExtracted::BodyAlreadyExtracted(
BodyAlreadyExtracted,
));
return Err(BodyAlreadyExtracted);
};

*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;
*req.headers_mut() = headers;

if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
} else {
return Err(RequestAlreadyExtracted::ExtensionsAlreadyExtracted(
ExtensionsAlreadyExtracted,
));
}
*req.extensions_mut() = extensions;

Ok(req)
}
Expand Down Expand Up @@ -199,22 +185,13 @@ impl<B> RequestParts<B> {
}

/// Gets a reference to the request extensions.
///
/// Returns `None` if the extensions has been taken by another extractor.
pub fn extensions(&self) -> Option<&Extensions> {
self.extensions.as_ref()
pub fn extensions(&self) -> &Extensions {
&self.extensions
}

/// Gets a mutable reference to the request extensions.
///
/// Returns `None` if the extensions has been taken by another extractor.
pub fn extensions_mut(&mut self) -> Option<&mut Extensions> {
self.extensions.as_mut()
}

/// Takes the extensions out of the request, leaving a `None` in its place.
pub fn take_extensions(&mut self) -> Option<Extensions> {
self.extensions.take()
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}

/// Gets a reference to the request body.
Expand Down
30 changes: 1 addition & 29 deletions axum-core/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,10 @@ define_rejection! {
#[body = "Cannot have two request body extractors for a single handler"]
/// Rejection type used if you try and extract the request body more than
/// once.
#[derive(Default)]
pub struct BodyAlreadyExtracted;
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
/// Rejection used if the request extension has been taken by another
/// extractor.
pub struct ExtensionsAlreadyExtracted;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "Failed to buffer the request body"]
Expand All @@ -32,18 +25,6 @@ define_rejection! {
pub struct InvalidUtf8(Error);
}

composite_rejection! {
/// Rejection used for [`Request<_>`].
///
/// Contains one variant for each way the [`Request<_>`] extractor can fail.
///
/// [`Request<_>`]: http::Request
pub enum RequestAlreadyExtracted {
BodyAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}

composite_rejection! {
/// Rejection used for [`Bytes`](bytes::Bytes).
///
Expand All @@ -65,12 +46,3 @@ composite_rejection! {
InvalidUtf8,
}
}

composite_rejection! {
/// Rejection used for [`http::request::Parts`].
///
/// Contains one variant for each way the [`http::request::Parts`] extractor can fail.
pub enum RequestPartsAlreadyExtracted {
ExtensionsAlreadyExtracted,
}
}
25 changes: 5 additions & 20 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ impl<B> FromRequest<B> for Request<B>
where
B: Send,
{
type Rejection = RequestAlreadyExtracted;
type Rejection = BodyAlreadyExtracted;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let req = std::mem::replace(
Expand All @@ -20,7 +20,7 @@ where
version: req.version,
uri: req.uri.clone(),
headers: HeaderMap::new(),
extensions: None,
extensions: Extensions::default(),
body: None,
},
);
Expand Down Expand Up @@ -82,18 +82,6 @@ where
}
}

#[async_trait]
impl<B> FromRequest<B> for Extensions
where
B: Send,
{
type Rejection = ExtensionsAlreadyExtracted;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_extensions().ok_or(ExtensionsAlreadyExtracted)
}
}

#[async_trait]
impl<B> FromRequest<B> for Bytes
where
Expand Down Expand Up @@ -142,17 +130,14 @@ impl<B> FromRequest<B> for http::request::Parts
where
B: Send,
{
type Rejection = RequestPartsAlreadyExtracted;
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
let headers = match HeaderMap::from_request(req).await {
Ok(headers) => headers,
Err(err) => match err {},
};
let extensions = Extensions::from_request(req).await?;
let headers = unwrap_infallible(HeaderMap::from_request(req).await);
let extensions = std::mem::take(req.extensions_mut());

let mut temp_request = Request::new(());
*temp_request.method_mut() = method;
Expand Down
6 changes: 0 additions & 6 deletions axum-core/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ macro_rules! define_rejection {
}

impl std::error::Error for $name {}

impl Default for $name {
fn default() -> Self {
Self
}
}
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
};

(
Expand Down
76 changes: 5 additions & 71 deletions axum-extra/src/extract/cached.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
use axum::{
async_trait,
extract::{
rejection::{ExtensionRejection, ExtensionsAlreadyExtracted},
Extension, FromRequest, RequestParts,
},
response::{IntoResponse, Response},
};
use std::{
fmt,
ops::{Deref, DerefMut},
extract::{Extension, FromRequest, RequestParts},
};
use std::ops::{Deref, DerefMut};

/// Cache results of other extractors.
///
Expand Down Expand Up @@ -100,25 +93,14 @@ where
B: Send,
T: FromRequest<B> + Clone + Send + Sync + 'static,
{
type Rejection = CachedRejection<T::Rejection>;
type Rejection = T::Rejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request(req).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(ExtensionRejection::ExtensionsAlreadyExtracted(err)) => {
Err(CachedRejection::ExtensionsAlreadyExtracted(err))
}
Err(_) => {
let value = T::from_request(req).await.map_err(CachedRejection::Inner)?;

req.extensions_mut()
.ok_or_else(|| {
CachedRejection::ExtensionsAlreadyExtracted(
ExtensionsAlreadyExtracted::default(),
)
})?
.insert(CachedEntry(value.clone()));

let value = T::from_request(req).await?;
req.extensions_mut().insert(CachedEntry(value.clone()));
Ok(Self(value))
}
}
Expand All @@ -139,54 +121,6 @@ impl<T> DerefMut for Cached<T> {
}
}

/// Rejection used for [`Cached`].
///
/// Contains one variant for each way the [`Cached`] extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
pub enum CachedRejection<R> {
#[allow(missing_docs)]
ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted),
#[allow(missing_docs)]
Inner(R),
}

impl<R> IntoResponse for CachedRejection<R>
where
R: IntoResponse,
{
fn into_response(self) -> Response {
match self {
Self::ExtensionsAlreadyExtracted(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(),
}
}
}

impl<R> fmt::Display for CachedRejection<R>
where
R: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ExtensionsAlreadyExtracted(inner) => write!(f, "{}", inner),
Self::Inner(inner) => write!(f, "{}", inner),
}
}
}

impl<R> std::error::Error for CachedRejection<R>
where
R: std::error::Error + 'static,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::ExtensionsAlreadyExtracted(inner) => Some(inner),
Self::Inner(inner) => Some(inner),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
6 changes: 0 additions & 6 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,3 @@
mod cached;

pub use self::cached::Cached;

pub mod rejection {
//! Rejection response types.

pub use super::cached::CachedRejection;
}
1 change: 0 additions & 1 deletion axum/src/extract/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ where
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let value = req
.extensions()
.ok_or_else(ExtensionsAlreadyExtracted::default)?
.get::<T>()
.ok_or_else(|| {
MissingExtension::from_err(format!(
Expand Down
7 changes: 2 additions & 5 deletions axum/src/extract/matched_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,8 @@ where
type Rejection = MatchedPathRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let extensions = req.extensions().ok_or_else(|| {
MatchedPathRejection::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted::default())
})?;

let matched_path = extensions
let matched_path = req
.extensions()
.get::<Self>()
.ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))?
.clone();
Expand Down
12 changes: 5 additions & 7 deletions axum/src/extract/path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

mod de;

use super::rejection::ExtensionsAlreadyExtracted;
use crate::{
body::{boxed, Full},
extract::{rejection::*, FromRequest, RequestParts},
Expand Down Expand Up @@ -164,11 +163,7 @@ where
type Rejection = PathRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let ext = req
.extensions_mut()
.ok_or_else::<Self::Rejection, _>(|| ExtensionsAlreadyExtracted::default().into())?;

let params = match ext.get::<Option<UrlParams>>() {
let params = match req.extensions_mut().get::<Option<UrlParams>>() {
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
let err = PathDeserializationError {
Expand Down Expand Up @@ -519,6 +514,9 @@ mod tests {

let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(res.text().await, "Extensions taken by other extractor");
assert_eq!(
res.text().await,
"No paths parameters found for matched route. Are you also extracting `Request<_>`?"
);
}
}
Loading