Skip to content

Commit

Permalink
Fix bugs around merging routers with nested fallbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Jul 15, 2023
1 parent b34715f commit 9c1d781
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 5 deletions.
12 changes: 9 additions & 3 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl<S> fmt::Debug for Router<S> {
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback";

impl<S> Router<S>
where
Expand Down Expand Up @@ -187,7 +188,7 @@ where
{
let Router {
path_router,
fallback_router: other_fallback,
fallback_router: mut other_fallback,
default_fallback,
catch_all_fallback,
} = other.into();
Expand All @@ -198,16 +199,21 @@ where
// both have the default fallback
// use the one from other
(true, true) => {
let fallback_router = std::mem::take(&mut self.fallback_router);
panic_on_err!(other_fallback.merge(fallback_router));
self.fallback_router = other_fallback;
}
// self has default fallback, other has a custom fallback
(true, false) => {
let fallback_router = std::mem::take(&mut self.fallback_router);
panic_on_err!(other_fallback.merge(fallback_router));
self.fallback_router = other_fallback;
self.default_fallback = false;
}
// self has a custom fallback, other has a default
// nothing to do
(false, true) => {}
(false, true) => {
panic_on_err!(self.fallback_router.merge(other_fallback));
}
// both have a custom fallback, not allowed
(false, false) => {
panic!("Cannot merge two `Router`s that both have a fallback")
Expand Down
11 changes: 9 additions & 2 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tower_service::Service;

use super::{
future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
MethodRouter, Route, RouteId, FALLBACK_PARAM, NEST_TAIL_PARAM,
MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
};

pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
Expand All @@ -28,7 +28,7 @@ where

pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
self.replace_endpoint("/", endpoint.clone());
self.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint);
self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
}
}

Expand Down Expand Up @@ -136,6 +136,13 @@ where
.route_id_to_path
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");

if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
// `Router::merge` makes sure that `/` and `/*__private__axum_fallback` on `other`
// will always be default fallbacks and thus we should ignore them
continue;
}

match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
Expand Down
119 changes: 119 additions & 0 deletions axum/src/routing/tests/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,122 @@ async fn doesnt_panic_if_used_with_nested_router() {
let res = client.get("/foobar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn issue_2072() {
let nested_routes = Router::new().fallback(inner_fallback);

let app = Router::new()
.nest("/nested", nested_routes)
.merge(Router::new());

let client = TestClient::new(app);

let res = client.get("/nested/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner");

let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "");
}

#[crate::test]
async fn issue_2072_outer_fallback_before_merge() {
let nested_routes = Router::new().fallback(inner_fallback);

let app = Router::new()
.nest("/nested", nested_routes)
.fallback(outer_fallback)
.merge(Router::new());

let client = TestClient::new(app);

let res = client.get("/nested/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner");

let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}

#[crate::test]
async fn issue_2072_outer_fallback_after_merge() {
let nested_routes = Router::new().fallback(inner_fallback);

let app = Router::new()
.nest("/nested", nested_routes)
.merge(Router::new())
.fallback(outer_fallback);

let client = TestClient::new(app);

let res = client.get("/nested/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner");

let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}

#[crate::test]
async fn merge_router_with_fallback_into_nested_router_with_fallback() {
let nested_routes = Router::new().fallback(inner_fallback);

let app = Router::new()
.nest("/nested", nested_routes)
.merge(Router::new().fallback(outer_fallback));

let client = TestClient::new(app);

let res = client.get("/nested/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner");

let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}

#[crate::test]
async fn merging_nested_router_with_fallback_into_router_with_fallback() {
let nested_routes = Router::new().fallback(inner_fallback);

let app = Router::new()
.fallback(outer_fallback)
.merge(Router::new().nest("/nested", nested_routes));

let client = TestClient::new(app);

let res = client.get("/nested/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner");

let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}

#[crate::test]
async fn merge_empty_into_router_with_fallback() {
let app = Router::new().fallback(outer_fallback).merge(Router::new());

let client = TestClient::new(app);

let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}

#[crate::test]
async fn merge_router_with_fallback_into_empty() {
let app = Router::new().merge(Router::new().fallback(outer_fallback));

let client = TestClient::new(app);

let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}

0 comments on commit 9c1d781

Please sign in to comment.