Skip to content

Commit

Permalink
Bug fixes for RouterExt:{route_with_tsr, route_service_with_tsr} (#…
Browse files Browse the repository at this point in the history
…1608)

* Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}`

* changelog link
  • Loading branch information
davidpdrsn authored Dec 2, 2022
1 parent 56d0dd9 commit 7386e5d
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 23 deletions.
7 changes: 6 additions & 1 deletion axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ and this project adheres to [Semantic Versioning].

# Unreleased

- None.
- **fixed:** Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}` ([#1608]):
- Redirects to the correct URI if the route contains path parameters
- Keeps query parameters when redirecting
- Better improved error message if adding route for `/`

[#1608]: https://github.com/tokio-rs/axum/pull/1608

# 0.4.1 (29. November, 2022)

Expand Down
129 changes: 107 additions & 22 deletions axum-extra/src/routing/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//! Additional types for defining routes.
use axum::{
handler::HandlerWithoutStateExt,
http::Request,
response::{IntoResponse, Redirect},
response::{IntoResponse, Redirect, Response},
routing::{any, MethodRouter},
Router,
};
use std::{convert::Infallible, future::ready, sync::Arc};
use http::{uri::PathAndQuery, StatusCode, Uri};
use std::{borrow::Cow, convert::Infallible};
use tower_service::Service;

mod resource;
Expand Down Expand Up @@ -265,18 +265,9 @@ where
where
Self: Sized,
{
validate_tsr_path(path);
self = self.route(path, method_router);

let redirect_service = {
let path: Arc<str> = path.into();
(move || ready(Redirect::permanent(&path))).into_service()
};

if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route_service(path_without_trailing_slash, redirect_service)
} else {
self.route_service(&format!("{}/", path), redirect_service)
}
add_tsr_redirect_route(self, path)
}

#[track_caller]
Expand All @@ -287,19 +278,65 @@ where
T::Future: Send + 'static,
Self: Sized,
{
validate_tsr_path(path);
self = self.route_service(path, service);
add_tsr_redirect_route(self, path)
}
}

let redirect = Redirect::permanent(path);
#[track_caller]
fn validate_tsr_path(path: &str) {
if path == "/" {
panic!("Cannot add a trailing slash redirect route for `/`")
}
}

if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route(
path_without_trailing_slash,
any(move || ready(redirect.clone())),
)
fn add_tsr_redirect_route<S, B>(router: Router<S, B>, path: &str) -> Router<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
async fn redirect_handler(uri: Uri) -> Response {
let new_uri = map_path(uri, |path| {
path.strip_suffix('/')
.map(Cow::Borrowed)
.unwrap_or_else(|| Cow::Owned(format!("{path}/")))
});

if let Some(new_uri) = new_uri {
Redirect::permanent(&new_uri.to_string()).into_response()
} else {
self.route(&format!("{}/", path), any(move || ready(redirect.clone())))
StatusCode::BAD_REQUEST.into_response()
}
}

if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
router.route(path_without_trailing_slash, any(redirect_handler))
} else {
router.route(&format!("{}/", path), any(redirect_handler))
}
}

/// Map the path of a `Uri`.
///
/// Returns `None` if the `Uri` cannot be put back together with the new path.
fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
where
F: FnOnce(&str) -> Cow<'_, str>,
{
let mut parts = original_uri.into_parts();
let path_and_query = parts.path_and_query.as_ref()?;

let new_path = f(path_and_query.path());

let new_path_and_query = if let Some(query) = &path_and_query.query() {
format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
} else {
new_path.parse::<PathAndQuery>().ok()?
};
parts.path_and_query = Some(new_path_and_query);

Uri::from_parts(parts).ok()
}

mod sealed {
Expand All @@ -311,7 +348,7 @@ mod sealed {
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{http::StatusCode, routing::get};
use axum::{extract::Path, http::StatusCode, routing::get};

#[tokio::test]
async fn test_tsr() {
Expand All @@ -335,4 +372,52 @@ mod tests {
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/bar/");
}

#[tokio::test]
async fn tsr_with_params() {
let app = Router::new()
.route_with_tsr(
"/a/:a",
get(|Path(param): Path<String>| async move { param }),
)
.route_with_tsr(
"/b/:b/",
get(|Path(param): Path<String>| async move { param }),
);

let client = TestClient::new(app);

let res = client.get("/a/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "foo");

let res = client.get("/a/foo/").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/a/foo");

let res = client.get("/b/foo/").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "foo");

let res = client.get("/b/foo").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/b/foo/");
}

#[tokio::test]
async fn tsr_maintains_query_params() {
let app = Router::new().route_with_tsr("/foo", get(|| async {}));

let client = TestClient::new(app);

let res = client.get("/foo/?a=a").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/foo?a=a");
}

#[test]
#[should_panic = "Cannot add a trailing slash redirect route for `/`"]
fn tsr_at_root() {
let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
}
}

0 comments on commit 7386e5d

Please sign in to comment.