From 60434e01882710c0fd19ccf5a0b840f25da39624 Mon Sep 17 00:00:00 2001 From: AnikHasibul Date: Mon, 15 Jul 2019 19:57:18 +0600 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20FIX:=20multiple=20sub-router=20g?= =?UTF-8?q?roup=20chaining=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- router.go | 17 +++++++++-------- router_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/router.go b/router.go index a04b317..d0d2611 100644 --- a/router.go +++ b/router.go @@ -220,20 +220,21 @@ func (r *Router) Handle(method, path string, handle fasthttp.RequestHandler) { if r.beginPath != "/" { path = r.beginPath + path } - var route *Router + + // Call to the parent recursively until main router to register paths in it if r.parent != nil { - route = r.parent - } else { - route = r + r.parent.Handle(method, path, handle) + return } - if route.trees == nil { - route.trees = make(map[string]*node) + + if r.trees == nil { + r.trees = make(map[string]*node) } - root := route.trees[method] + root := r.trees[method] if root == nil { root = new(node) - route.trees[method] = root + r.trees[method] = root } root.addRoute(path, handle) diff --git a/router_test.go b/router_test.go index 5d984b3..f3eef56 100644 --- a/router_test.go +++ b/router_test.go @@ -314,6 +314,8 @@ func TestRouterGroup(t *testing.T) { r2 := r1.Group("/boo") r3 := r1.Group("/goo") r4 := r1.Group("/moo") + r5 := r4.Group("/foo") + r6 := r5.Group("/foo") fooHit := false r1.POST("/foo", func(ctx *fasthttp.RequestCtx) { fooHit = true @@ -333,6 +335,14 @@ func TestRouterGroup(t *testing.T) { barHit = true ctx.SetStatusCode(fasthttp.StatusOK) }) + r5.POST("/bar", func(ctx *fasthttp.RequestCtx) { + barHit = true + ctx.SetStatusCode(fasthttp.StatusOK) + }) + r6.POST("/bar", func(ctx *fasthttp.RequestCtx) { + barHit = true + ctx.SetStatusCode(fasthttp.StatusOK) + }) s := &fasthttp.Server{ Handler: r1.Handler, } @@ -422,6 +432,45 @@ func TestRouterGroup(t *testing.T) { t.FailNow() } + rw.r.WriteString("POST /moo/foo/bar HTTP/1.1\r\n\r\n") + go func() { + ch <- s.ServeConn(rw) + }() + select { + case err := <-ch: + if err != nil { + t.Fatalf("return error %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when reading response: %s", err) + } + if !(resp.Header.StatusCode() == fasthttp.StatusOK && barHit) { + t.Errorf("Chained routing failed with subrouter grouping.") + t.FailNow() + } + rw.r.WriteString("POST /moo/foo/foo/bar HTTP/1.1\r\n\r\n") + go func() { + ch <- s.ServeConn(rw) + }() + select { + case err := <-ch: + if err != nil { + t.Fatalf("return error %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timeout") + } + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when reading response: %s", err) + } + if !(resp.Header.StatusCode() == fasthttp.StatusOK && barHit) { + t.Errorf("Chained routing failed with subrouter grouping.") + t.FailNow() + } + rw.r.WriteString("POST /qax HTTP/1.1\r\n\r\n") go func() { ch <- s.ServeConn(rw)