Skip to content

Commit

Permalink
Add RemoveRouteByName method, add more tests, cleanup deleteRoute method
Browse files Browse the repository at this point in the history
  • Loading branch information
ckoch786 committed Dec 8, 2024
1 parent 590152e commit d411cb1
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 77 deletions.
56 changes: 37 additions & 19 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler

// Duplicate Route Handling
if app.routeExists(method, pathRaw) {
app.deleteRoute(pathRaw, []string{method}, len(handlers))
matchPathFunc := func(r *Route) bool { return r.Path == pathRaw }
app.deleteRoute([]string{method}, matchPathFunc)
}

// is mounted app
Expand Down Expand Up @@ -364,7 +365,6 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler
Method: method,
Handlers: handlers,
}
// Should this use a mutex lock?
// Increment global handler count
atomic.AddUint32(&app.handlersCount, uint32(len(handlers))) //nolint:gosec // Not a concern

Expand All @@ -384,29 +384,41 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler
}

func (app *App) routeExists(method string, pathRaw string) bool {
pathToCheck := pathRaw
if !app.config.CaseSensitive {
pathToCheck = utils.ToLower(pathToCheck)
}

return slices.ContainsFunc(app.stack[app.methodInt(method)], func(r *Route) bool {
return r.path == pathRaw
routePath := r.path
if !app.config.CaseSensitive {
routePath = utils.ToLower(routePath)
}

return routePath == pathToCheck
})
}

// RemoveRoute is used to remove a route from the stack
// You should call RebuildTree after using this to ensure consistency of the tree
func (app *App) RemoveRoute(path string, middlewareCount int, methods ...string) {
app.deleteRoute(path, methods, middlewareCount)
// RemoveRoute is used to remove a route from the stack by path.
// This only needs to be called to remove a route, route registration prevents duplicate routes.
// You should call RebuildTree after using this to ensure consistency of the tree.
func (app *App) RemoveRoute(path string, methods ...string) {
pathMatchFunc := func(r *Route) bool { return r.Path == path }
app.deleteRoute(methods, pathMatchFunc)
}

func (app *App) deleteRoute(path string, methods []string, middlewareCount int) {
// RemoveRouteByName is used to remove a route from the stack by name.
// This only needs to be called to remove a route, route registration prevents duplicate routes.
// You should call RebuildTree after using this to ensure consistency of the tree.
func (app *App) RemoveRouteByName(name string, methods ...string) {
matchFunc := func(r *Route) bool { return r.Name == name }
app.deleteRoute(methods, matchFunc)
}

func (app *App) deleteRoute(methods []string, matchFunc func(r *Route) bool) {
app.mutex.Lock()
defer app.mutex.Unlock()

if middlewareCount == 0 {
middlewareCount++
}

//Decrement global handler count
atomic.AddUint32(&app.handlersCount, ^uint32(middlewareCount-1)) //nolint:gosec // Not a concern
// Decrement global route position
atomic.AddUint32(&app.routesCount, ^uint32(0))
for _, method := range methods {
// Uppercase HTTP methods
method = utils.ToUpper(method)
Expand All @@ -418,13 +430,19 @@ func (app *App) deleteRoute(path string, methods []string, middlewareCount int)
}

// Find the index of the route to remove
index := slices.IndexFunc(app.stack[m], func(r *Route) bool {
return r.Path == path
})
index := slices.IndexFunc(app.stack[m], matchFunc)
if index == -1 {
continue // Route not found
}

route := app.stack[m][index]

handlerCount := len(route.Handlers)
//Decrement global handler count
atomic.AddUint32(&app.handlersCount, ^uint32(handlerCount-1)) // nolint:gosec // Not a concern
// Decrement global route position
atomic.AddUint32(&app.routesCount, ^uint32(0))

// Remove route from tree stack
app.stack[m] = slices.Delete(app.stack[m], index, index+1)
}
Expand Down
202 changes: 144 additions & 58 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"reflect"
"runtime"
"strings"
"sync"
"testing"

"github.com/gofiber/utils/v2"
Expand Down Expand Up @@ -372,6 +373,65 @@ func Test_Router_NotFound_HTML_Inject(t *testing.T) {
require.Equal(t, "Cannot DELETE /does/not/exist<script>alert('foo');</script>", string(c.Response.Body()))
}

func sendStatusOK(c Ctx) error {
return c.SendStatus(http.StatusOK)
}

func registerTreeManipulationRoutes(app *App, middleware ...func(Ctx) error) {

app.Get("/test", func(c Ctx) error {
app.Get("/dynamically-defined", sendStatusOK)

app.RebuildTree()

return c.SendStatus(http.StatusOK)
}, middleware...)

}

func verifyRequest(tb testing.TB, app *App, path string, expectedStatus int) *http.Response {
tb.Helper()

resp, err := app.Test(httptest.NewRequest(MethodGet, path, nil))
require.NoError(tb, err, "app.Test(req)")
require.Equal(tb, expectedStatus, resp.StatusCode, "Status code")

return resp
}

func verifyRouteHandlerCounts(tb testing.TB, app *App, expectedRoutesCount int) {
tb.Helper()

// this is taken from listen.go's printRoutesMessage app method
var routes []RouteMessage
for _, routeStack := range app.stack {

for _, route := range routeStack {
routeMsg := RouteMessage{
name: route.Name,
method: route.Method,
path: route.Path,
}

for _, handler := range route.Handlers {
routeMsg.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " "
}

routes = append(routes, routeMsg)
}
}

for _, route := range routes {
require.Equal(tb, expectedRoutesCount, strings.Count(route.handlers, " "))
}
}

func verifyThereAreNoRoutes(tb testing.TB, app *App) {
require.Equal(tb, uint32(0), app.handlersCount)
require.Equal(tb, uint32(0), app.routesCount)
verifyRouteHandlerCounts(tb, app, 0)
}

func Test_App_Rebuild_Tree(t *testing.T) {
t.Parallel()
app := New()
Expand All @@ -383,21 +443,98 @@ func Test_App_Rebuild_Tree(t *testing.T) {
verifyRequest(t, app, "/dynamically-defined", http.StatusOK)
}

func Test_App_Remove_Route_By_Name(t *testing.T) {
t.Parallel()
app := New()

app.Get("/api/test", sendStatusOK).Name("test")

app.RemoveRouteByName("test", http.MethodGet)
app.RebuildTree()

verifyRequest(t, app, "/test", http.StatusNotFound)
verifyThereAreNoRoutes(t, app)
}

func Test_App_Remove_Route_By_Name_Non_Existing_Route(t *testing.T) {
t.Parallel()
app := New()

app.RemoveRouteByName("test", http.MethodGet)
app.RebuildTree()

verifyThereAreNoRoutes(t, app)
}

func Test_App_Remove_Route_Nested(t *testing.T) {
t.Parallel()
app := New()

api := app.Group("/api")

v1 := api.Group("/v1")
v1.Get("/test", sendStatusOK)

app.RemoveRoute("/api/v1/test", http.MethodGet)

verifyThereAreNoRoutes(t, app)
}

func Test_App_Remove_Route_Parameterized(t *testing.T) {
t.Parallel()
app := New()

app.Get("/test/:id", sendStatusOK)
app.RemoveRoute("/test/:id", http.MethodGet)

verifyThereAreNoRoutes(t, app)

}

func Test_App_Remove_Route(t *testing.T) {
t.Parallel()
app := New()

app.Get("/test", func(c Ctx) error {
return c.SendStatus(http.StatusOK)
})
app.Get("/test", sendStatusOK)

app.RemoveRoute("/test", 0, http.MethodGet)
app.RemoveRoute("/test", http.MethodGet)
app.RebuildTree()

verifyRequest(t, app, "/test", http.StatusNotFound)
require.Equal(t, uint32(0), app.handlersCount)
require.Equal(t, uint32(0), app.routesCount)
verifyRouteHandlerCounts(t, app, 0)
}

func Test_App_Remove_Route_Non_Existing_Route(t *testing.T) {
t.Parallel()
app := New()

app.RemoveRoute("/test", http.MethodGet, http.MethodHead)
app.RebuildTree()

verifyThereAreNoRoutes(t, app)
}

func Test_App_Remove_Route_Concurrent(t *testing.T) {
t.Parallel()
app := New()

// Add test route
app.Get("/test", sendStatusOK)

// Concurrently remove and add routes
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
app.RemoveRoute("/test", http.MethodGet)
app.Get("/test", sendStatusOK)
}()
}
wg.Wait()

// Verify final state
app.RebuildTree()
verifyRequest(t, app, "/test", http.StatusOK)
}

func Test_App_Route_Registration_Prevent_Duplicate(t *testing.T) {
Expand Down Expand Up @@ -456,57 +593,6 @@ func Test_Route_Registration_Prevent_Duplicate_With_Middleware(t *testing.T) {
verifyRouteHandlerCounts(t, app, 1)
}

func registerTreeManipulationRoutes(app *App, middleware ...func(Ctx) error) {

app.Get("/test", func(c Ctx) error {
app.Get("/dynamically-defined", func(c Ctx) error {
return c.SendStatus(http.StatusOK)
})

app.RebuildTree()

return c.SendStatus(http.StatusOK)
}, middleware...)

}

func verifyRequest(tb testing.TB, app *App, path string, expectedStatus int) *http.Response {
tb.Helper()

resp, err := app.Test(httptest.NewRequest(MethodGet, path, nil))
require.NoError(tb, err, "app.Test(req)")
require.Equal(tb, expectedStatus, resp.StatusCode, "Status code")

return resp
}

// this is taken from listen.go's printRoutesMessage app method
func verifyRouteHandlerCounts(tb testing.TB, app *App, expectedRoutesCount int) {
tb.Helper()

var routes []RouteMessage
for _, routeStack := range app.stack {

for _, route := range routeStack {
routeMsg := RouteMessage{
name: route.Name,
method: route.Method,
path: route.Path,
}

for _, handler := range route.Handlers {
routeMsg.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " "
}

routes = append(routes, routeMsg)
}
}

for _, route := range routes {
require.Equal(tb, expectedRoutesCount, strings.Count(route.handlers, " "))
}
}

//////////////////////////////////////////////
///////////////// BENCHMARKS /////////////////
//////////////////////////////////////////////
Expand Down

0 comments on commit d411cb1

Please sign in to comment.