diff --git a/router.go b/router.go index bad404a51d..79604eb8b1 100644 --- a/router.go +++ b/router.go @@ -305,16 +305,9 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler } // Duplicate Route Handling - app.mutex.Lock() if app.routeExists(method, pathRaw) { - // TODO does RemoveRoute also need these operations? - //Decrement global handler count - atomic.AddUint32(&app.handlersCount, ^uint32(len(handlers)-1)) //nolint:gosec // Not a concern - // Decrement global route position - atomic.AddUint32(&app.routesCount, ^uint32(0)) - app.deleteRoute(pathRaw, []string{method}) + app.deleteRoute(pathRaw, []string{method}, len(handlers)) } - app.mutex.Unlock() // is mounted app isMount := group != nil && group.app != app @@ -398,14 +391,22 @@ func (app *App) routeExists(method string, pathRaw string) bool { // RemoveRoute is used to remove a route from the stack // You should call RebuildTree after using this to ensure consistency of the tree -// TODO write tests for this that explicitally delete a route right after adding it -func (app *App) RemoveRoute(path string, methods ...string) { +func (app *App) RemoveRoute(path string, middlewareCount int, methods ...string) { + app.deleteRoute(path, methods, middlewareCount) +} + +func (app *App) deleteRoute(path string, methods []string, middlewareCount int) { app.mutex.Lock() defer app.mutex.Unlock() - app.deleteRoute(path, methods) -} -func (app *App) deleteRoute(path string, methods []string) { + if middlewareCount == 0 { + middlewareCount++ + } + + //Decrement global handler count + atomic.AddUint32(&app.handlersCount, ^uint32(middlewareCount-1)) //nolint:gosec // Not a concern + // Decrement global route position + atomic.AddUint32(&app.routesCount, ^uint32(0)) for _, method := range methods { // Uppercase HTTP methods method = utils.ToUpper(method) diff --git a/router_test.go b/router_test.go index 7b3ab87193..3e993474e5 100644 --- a/router_test.go +++ b/router_test.go @@ -387,6 +387,23 @@ func Test_App_Remove_Route(t *testing.T) { t.Parallel() app := New() + app.Get("/test", func(c Ctx) error { + return c.SendStatus(http.StatusOK) + }) + + app.RemoveRoute("/test", 0, http.MethodGet) + app.RebuildTree() + + verifyRequest(t, app, "/test", http.StatusNotFound) + require.Equal(t, uint32(0), app.handlersCount) + require.Equal(t, uint32(0), app.routesCount) + verifyRouteHandlerCounts(t, app, 0) +} + +func Test_App_Route_Registration_Prevent_Duplicate(t *testing.T) { + t.Parallel() + app := New() + registerTreeManipulationRoutes(app) registerTreeManipulationRoutes(app) @@ -406,10 +423,10 @@ func Test_App_Remove_Route(t *testing.T) { require.Equal(t, uint32(2), app.handlersCount) require.Equal(t, uint32(2), app.routesCount) - verifyRouteHandlerCounts(t, app) + verifyRouteHandlerCounts(t, app, 1) } -func Test_App_Remove_Route_With_Middleware(t *testing.T) { +func Test_Route_Registration_Prevent_Duplicate_With_Middleware(t *testing.T) { t.Parallel() app := New() @@ -436,10 +453,11 @@ func Test_App_Remove_Route_With_Middleware(t *testing.T) { require.Equal(t, uint32(3), app.handlersCount) require.Equal(t, uint32(2), app.routesCount) - verifyRouteHandlerCounts(t, app) + verifyRouteHandlerCounts(t, app, 1) } func registerTreeManipulationRoutes(app *App, middleware ...func(Ctx) error) { + app.Get("/test", func(c Ctx) error { app.Get("/dynamically-defined", func(c Ctx) error { return c.SendStatus(http.StatusOK) @@ -463,7 +481,7 @@ func verifyRequest(tb testing.TB, app *App, path string, expectedStatus int) *ht } // this is taken from listen.go's printRoutesMessage app method -func verifyRouteHandlerCounts(tb testing.TB, app *App) { +func verifyRouteHandlerCounts(tb testing.TB, app *App, expectedRoutesCount int) { tb.Helper() var routes []RouteMessage @@ -485,7 +503,7 @@ func verifyRouteHandlerCounts(tb testing.TB, app *App) { } for _, route := range routes { - require.Equal(tb, 1, strings.Count(route.handlers, " ")) + require.Equal(tb, expectedRoutesCount, strings.Count(route.handlers, " ")) } }