From 17a3c89a6838e6aec3d4205b11e8a1904fb1afe3 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 3 Mar 2023 10:59:04 +0100 Subject: [PATCH] Fix routing issues when loading a `Router` via a dynamic library --- axum/CHANGELOG.md | 1 + axum/src/routing/mod.rs | 35 ++++++++++++++++++++--------------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 194c8b4d41b..521076bd8ca 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801]) +- **fixed:** Fix routing issues when loading a `Router` via a dynamic library [#1801]: https://github.com/tokio-rs/axum/pull/1801 diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index e7b717ea41d..7f9af4540be 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -47,24 +47,12 @@ pub use self::method_routing::{ #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub(crate) struct RouteId(u32); -impl RouteId { - fn next() -> Self { - use std::sync::atomic::{AtomicU32, Ordering}; - // `AtomicU64` isn't supported on all platforms - static ID: AtomicU32 = AtomicU32::new(0); - let id = ID.fetch_add(1, Ordering::Relaxed); - if id == u32::MAX { - panic!("Over `u32::MAX` routes created. If you need this, please file an issue."); - } - Self(id) - } -} - /// The router type for composing handlers and services. pub struct Router { routes: HashMap>, node: Arc, fallback: Fallback, + prev_route_id: Option, } impl Clone for Router { @@ -73,6 +61,7 @@ impl Clone for Router { routes: self.routes.clone(), node: Arc::clone(&self.node), fallback: self.fallback.clone(), + prev_route_id: self.prev_route_id, } } } @@ -117,6 +106,7 @@ where routes: Default::default(), node: Default::default(), fallback: Fallback::Default(Route::new(NotFound)), + prev_route_id: Default::default(), } } @@ -134,7 +124,7 @@ where validate_path(path); - let id = RouteId::next(); + let id = self.next_route_id(); let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self .node @@ -189,7 +179,7 @@ where panic!("Paths must start with a `/`"); } - let id = RouteId::next(); + let id = self.next_route_id(); self.set_node(path, id); self.routes.insert(id, endpoint); self @@ -286,6 +276,7 @@ where routes, node, fallback, + prev_route_id: _, } = other.into(); for (id, route) in routes { @@ -335,6 +326,7 @@ where routes, node: self.node, fallback, + prev_route_id: self.prev_route_id, } } @@ -368,6 +360,7 @@ where routes, node: self.node, fallback: self.fallback, + prev_route_id: self.prev_route_id, } } @@ -419,6 +412,7 @@ where routes, node: self.node, fallback, + prev_route_id: self.prev_route_id, } } @@ -506,6 +500,17 @@ where Endpoint::NestedRouter(router) => router.call_with_state(req, state), } } + + fn next_route_id(&mut self) -> RouteId { + let next_id = self.prev_route_id.map_or(RouteId(0), |RouteId(id)| { + let next = id + .checked_add(1) + .expect("Over `u32::MAX` routes created. If you need this, please file an issue."); + RouteId(next) + }); + self.prev_route_id = Some(next_id); + next_id + } } impl Router<(), B>