Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OptionalPath extractor #1889

Merged
merged 4 commits into from
Apr 9, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@ cookie = ["dep:cookie"]
cookie-private = ["cookie", "cookie?/private"]
cookie-signed = ["cookie", "cookie?/signed"]
cookie-key-expansion = ["cookie", "cookie?/key-expansion"]
erased-json = ["dep:serde_json", "dep:serde"]
form = ["dep:serde", "dep:serde_html_form"]
erased-json = ["dep:serde_json"]
form = ["dep:serde_html_form"]
json-lines = [
"dep:serde_json",
"dep:serde",
"dep:tokio-util",
"dep:tokio-stream",
"tokio-util?/io",
"tokio-stream?/io-util"
]
multipart = ["dep:multer"]
protobuf = ["dep:prost"]
query = ["dep:serde", "dep:serde_html_form"]
typed-routing = ["dep:axum-macros", "dep:serde", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]
query = ["dep:serde_html_form"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]

[dependencies]
axum = { path = "../axum", version = "0.6.9", default-features = false }
Expand All @@ -42,6 +41,7 @@ http = "0.2"
http-body = "0.4.4"
mime = "0.3"
pin-project-lite = "0.2"
serde = "1.0"
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
tokio = "1.19"
tower = { version = "0.4", default_features = false, features = ["util"] }
tower-http = { version = "0.4", features = ["map-response-body"] }
Expand All @@ -55,14 +55,14 @@ form_urlencoded = { version = "1.1.0", optional = true }
multer = { version = "2.0.0", optional = true }
percent-encoding = { version = "2.1", optional = true }
prost = { version = "0.11", optional = true }
serde = { version = "1.0", optional = true }
serde_html_form = { version = "0.2.0", optional = true }
serde_json = { version = "1.0.71", optional = true }
tokio-stream = { version = "0.1.9", optional = true }
tokio-util = { version = "0.7", optional = true }

[dev-dependencies]
axum = { path = "../axum", version = "0.6.0", features = ["headers"] }
axum-macros = { path = "../axum-macros", version = "0.3.7", features = ["__private"] }
http-body = "0.4.4"
hyper = "0.14"
reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] }
Expand Down
8 changes: 3 additions & 5 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Additional extractors.

mod cached;
mod optional_path;
mod with_rejection;

#[cfg(feature = "form")]
mod form;
Expand All @@ -14,9 +16,7 @@ mod query;
#[cfg(feature = "multipart")]
pub mod multipart;

mod with_rejection;

pub use self::cached::Cached;
pub use self::{cached::Cached, optional_path::OptionalPath, with_rejection::WithRejection};

#[cfg(feature = "cookie")]
pub use self::cookie::CookieJar;
Expand All @@ -39,5 +39,3 @@ pub use self::multipart::Multipart;
#[cfg(feature = "json-lines")]
#[doc(no_inline)]
pub use crate::json_lines::JsonLines;

pub use self::with_rejection::WithRejection;
102 changes: 102 additions & 0 deletions axum-extra/src/extract/optional_path.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use axum::{
async_trait,
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
RequestPartsExt,
};
use serde::de::DeserializeOwned;

/// Extractor that extracts path arguments the same way as [`Path`], except if there aren't any.
///
/// This extractor can be used in place of `Path` when you have two routes that you want to handle
/// in mostly the same way, where one has a path parameter and the other one doesn't.
///
/// # Example
///
/// ```
/// use std::num::NonZeroU32;
/// use axum::{
/// response::IntoResponse,
/// routing::get,
/// Router,
/// };
/// use axum_extra::extract::OptionalPath;
///
/// async fn render_blog(OptionalPath(page): OptionalPath<NonZeroU32>) -> impl IntoResponse {
/// // Convert to u32, default to page 1 if not specified
/// let page = page.map_or(1, |param| param.get());
/// // ...
/// }
///
/// let app = Router::new()
/// .route("/blog", get(render_blog))
/// .route("/blog/:page", get(render_blog));
/// # let app: Router = app;
/// ```
#[derive(Debug)]
pub struct OptionalPath<T>(pub Option<T>);

#[async_trait]
impl<T, S> FromRequestParts<S> for OptionalPath<T>
where
T: DeserializeOwned + Send + 'static,
S: Send + Sync,
{
type Rejection = PathRejection;

async fn from_request_parts(
parts: &mut http::request::Parts,
_: &S,
) -> Result<Self, Self::Rejection> {
match parts.extract::<Path<T>>().await {
Ok(Path(params)) => Ok(Self(Some(params))),
Err(PathRejection::FailedToDeserializePathParams(e))
if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
{
Ok(Self(None))
}
Err(e) => Err(e),
}
}
}

#[cfg(test)]
mod tests {
use std::num::NonZeroU32;

use axum::{routing::get, Router};

use super::OptionalPath;
use crate::test_helpers::TestClient;

#[crate::test]
async fn supports_128_bit_numbers() {
async fn handle(OptionalPath(param): OptionalPath<NonZeroU32>) -> String {
let num = param.map_or(0, |p| p.get());
format!("Success: {num}")
}

let app = Router::new()
.route("/", get(handle))
.route("/:num", get(handle));

let client = TestClient::new(app);

let res = client.get("/").send().await;
assert_eq!(res.text().await, "Success: 0");

let res = client.get("/1").send().await;
assert_eq!(res.text().await, "Success: 1");

let res = client.get("/0").send().await;
assert_eq!(
res.text().await,
"Invalid URL: invalid value: integer `0`, expected a nonzero u32"
);

let res = client.get("/NaN").send().await;
assert_eq!(
res.text().await,
"Invalid URL: Cannot parse `\"NaN\"` to a `u32`"
);
}
}
3 changes: 3 additions & 0 deletions axum-extra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ pub mod __private {
pub const PATH_SEGMENT: &AsciiSet = &PATH.add(b'/').add(b'%');
}

#[cfg(test)]
use axum_macros::__private_axum_test as test;

#[cfg(test)]
pub(crate) mod test_helpers {
#![allow(unused_imports)]
Expand Down
8 changes: 7 additions & 1 deletion axum/src/extract/path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ impl std::error::Error for PathDeserializationError {}

/// The kinds of errors that can happen we deserializing into a [`Path`].
///
/// This type is obtained through [`FailedToDeserializePathParams::into_kind`] and is useful for building
/// This type is obtained through [`FailedToDeserializePathParams::kind`] or
/// [`FailedToDeserializePathParams::into_kind`] and is useful for building
/// more precise error messages.
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
Expand Down Expand Up @@ -380,6 +381,11 @@ impl fmt::Display for ErrorKind {
pub struct FailedToDeserializePathParams(PathDeserializationError);

impl FailedToDeserializePathParams {
/// Get a reference to the underlying error kind.
pub fn kind(&self) -> &ErrorKind {
&self.0.kind
}

/// Convert this error into the underlying error kind.
pub fn into_kind(self) -> ErrorKind {
self.0.kind
Expand Down