Skip to content

Commit

Permalink
Add OptionalPath extractor (#1889)
Browse files Browse the repository at this point in the history
Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
  • Loading branch information
jplatte and davidpdrsn authored Apr 9, 2023
1 parent 946d8c3 commit 43b2d52
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 13 deletions.
4 changes: 3 additions & 1 deletion axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning].

# Unreleased

- None.
- **added:** Add `OptionalPath` extractor ([#1889])

[#1889]: https://github.com/tokio-rs/axum/pull/1889

# 0.7.2 (22. March, 2023)

Expand Down
12 changes: 6 additions & 6 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@ cookie = ["dep:cookie"]
cookie-private = ["cookie", "cookie?/private"]
cookie-signed = ["cookie", "cookie?/signed"]
cookie-key-expansion = ["cookie", "cookie?/key-expansion"]
erased-json = ["dep:serde_json", "dep:serde"]
form = ["dep:serde", "dep:serde_html_form"]
erased-json = ["dep:serde_json"]
form = ["dep:serde_html_form"]
json-lines = [
"dep:serde_json",
"dep:serde",
"dep:tokio-util",
"dep:tokio-stream",
"tokio-util?/io",
"tokio-stream?/io-util"
]
multipart = ["dep:multer"]
protobuf = ["dep:prost"]
query = ["dep:serde", "dep:serde_html_form"]
typed-routing = ["dep:axum-macros", "dep:serde", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]
query = ["dep:serde_html_form"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]

[dependencies]
axum = { path = "../axum", version = "0.6.9", default-features = false }
Expand All @@ -42,6 +41,7 @@ http = "0.2"
http-body = "0.4.4"
mime = "0.3"
pin-project-lite = "0.2"
serde = "1.0"
tokio = "1.19"
tower = { version = "0.4", default_features = false, features = ["util"] }
tower-http = { version = "0.4", features = ["map-response-body"] }
Expand All @@ -55,14 +55,14 @@ form_urlencoded = { version = "1.1.0", optional = true }
multer = { version = "2.0.0", optional = true }
percent-encoding = { version = "2.1", optional = true }
prost = { version = "0.11", optional = true }
serde = { version = "1.0", optional = true }
serde_html_form = { version = "0.2.0", optional = true }
serde_json = { version = "1.0.71", optional = true }
tokio-stream = { version = "0.1.9", optional = true }
tokio-util = { version = "0.7", optional = true }

[dev-dependencies]
axum = { path = "../axum", version = "0.6.0", features = ["headers"] }
axum-macros = { path = "../axum-macros", version = "0.3.7", features = ["__private"] }
http-body = "0.4.4"
hyper = "0.14"
reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] }
Expand Down
8 changes: 3 additions & 5 deletions axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Additional extractors.
mod cached;
mod optional_path;
mod with_rejection;

#[cfg(feature = "form")]
mod form;
Expand All @@ -14,9 +16,7 @@ mod query;
#[cfg(feature = "multipart")]
pub mod multipart;

mod with_rejection;

pub use self::cached::Cached;
pub use self::{cached::Cached, optional_path::OptionalPath, with_rejection::WithRejection};

#[cfg(feature = "cookie")]
pub use self::cookie::CookieJar;
Expand All @@ -39,5 +39,3 @@ pub use self::multipart::Multipart;
#[cfg(feature = "json-lines")]
#[doc(no_inline)]
pub use crate::json_lines::JsonLines;

pub use self::with_rejection::WithRejection;
102 changes: 102 additions & 0 deletions axum-extra/src/extract/optional_path.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use axum::{
async_trait,
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
RequestPartsExt,
};
use serde::de::DeserializeOwned;

/// Extractor that extracts path arguments the same way as [`Path`], except if there aren't any.
///
/// This extractor can be used in place of `Path` when you have two routes that you want to handle
/// in mostly the same way, where one has a path parameter and the other one doesn't.
///
/// # Example
///
/// ```
/// use std::num::NonZeroU32;
/// use axum::{
/// response::IntoResponse,
/// routing::get,
/// Router,
/// };
/// use axum_extra::extract::OptionalPath;
///
/// async fn render_blog(OptionalPath(page): OptionalPath<NonZeroU32>) -> impl IntoResponse {
/// // Convert to u32, default to page 1 if not specified
/// let page = page.map_or(1, |param| param.get());
/// // ...
/// }
///
/// let app = Router::new()
/// .route("/blog", get(render_blog))
/// .route("/blog/:page", get(render_blog));
/// # let app: Router = app;
/// ```
#[derive(Debug)]
pub struct OptionalPath<T>(pub Option<T>);

#[async_trait]
impl<T, S> FromRequestParts<S> for OptionalPath<T>
where
T: DeserializeOwned + Send + 'static,
S: Send + Sync,
{
type Rejection = PathRejection;

async fn from_request_parts(
parts: &mut http::request::Parts,
_: &S,
) -> Result<Self, Self::Rejection> {
match parts.extract::<Path<T>>().await {
Ok(Path(params)) => Ok(Self(Some(params))),
Err(PathRejection::FailedToDeserializePathParams(e))
if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
{
Ok(Self(None))
}
Err(e) => Err(e),
}
}
}

#[cfg(test)]
mod tests {
use std::num::NonZeroU32;

use axum::{routing::get, Router};

use super::OptionalPath;
use crate::test_helpers::TestClient;

#[crate::test]
async fn supports_128_bit_numbers() {
async fn handle(OptionalPath(param): OptionalPath<NonZeroU32>) -> String {
let num = param.map_or(0, |p| p.get());
format!("Success: {num}")
}

let app = Router::new()
.route("/", get(handle))
.route("/:num", get(handle));

let client = TestClient::new(app);

let res = client.get("/").send().await;
assert_eq!(res.text().await, "Success: 0");

let res = client.get("/1").send().await;
assert_eq!(res.text().await, "Success: 1");

let res = client.get("/0").send().await;
assert_eq!(
res.text().await,
"Invalid URL: invalid value: integer `0`, expected a nonzero u32"
);

let res = client.get("/NaN").send().await;
assert_eq!(
res.text().await,
"Invalid URL: Cannot parse `\"NaN\"` to a `u32`"
);
}
}
3 changes: 3 additions & 0 deletions axum-extra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ pub mod __private {
pub const PATH_SEGMENT: &AsciiSet = &PATH.add(b'/').add(b'%');
}

#[cfg(test)]
use axum_macros::__private_axum_test as test;

#[cfg(test)]
pub(crate) mod test_helpers {
#![allow(unused_imports)]
Expand Down
8 changes: 7 additions & 1 deletion axum/src/extract/path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ impl std::error::Error for PathDeserializationError {}

/// The kinds of errors that can happen we deserializing into a [`Path`].
///
/// This type is obtained through [`FailedToDeserializePathParams::into_kind`] and is useful for building
/// This type is obtained through [`FailedToDeserializePathParams::kind`] or
/// [`FailedToDeserializePathParams::into_kind`] and is useful for building
/// more precise error messages.
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
Expand Down Expand Up @@ -380,6 +381,11 @@ impl fmt::Display for ErrorKind {
pub struct FailedToDeserializePathParams(PathDeserializationError);

impl FailedToDeserializePathParams {
/// Get a reference to the underlying error kind.
pub fn kind(&self) -> &ErrorKind {
&self.0.kind
}

/// Convert this error into the underlying error kind.
pub fn into_kind(self) -> ErrorKind {
self.0.kind
Expand Down

0 comments on commit 43b2d52

Please sign in to comment.