From 664cf8c1060bb9fe97de6e7f13bd65fbaf292402 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sat, 6 Mar 2021 01:43:59 +0200 Subject: [PATCH] Refactor router for readability (#1796) * refactor router tests to table driven (this way it is easier to debug test cases with breakpoints) * refactor router variables to be more readable --- router.go | 305 ++++---- router_test.go | 1869 ++++++++++++++++++++++++++++++------------------ 2 files changed, 1328 insertions(+), 846 deletions(-) diff --git a/router.go b/router.go index f0e9e51f4..2dd09fae2 100644 --- a/router.go +++ b/router.go @@ -13,16 +13,16 @@ type ( echo *Echo } node struct { - kind kind - label byte - prefix string - parent *node - staticChildrens children - ppath string - pnames []string - methodHandler *methodHandler - paramChildren *node - anyChildren *node + kind kind + label byte + prefix string + parent *node + staticChildren children + ppath string + pnames []string + methodHandler *methodHandler + paramChild *node + anyChild *node } kind uint8 children []*node @@ -42,9 +42,9 @@ type ( ) const ( - skind kind = iota - pkind - akind + staticKind kind = iota + paramKind + anyKind paramLabel = byte(':') anyLabel = byte('*') @@ -73,137 +73,147 @@ func (r *Router) Add(method, path string, h HandlerFunc) { pnames := []string{} // Param names ppath := path // Pristine path - for i, l := 0, len(path); i < l; i++ { + for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { if path[i] == ':' { j := i + 1 - r.insert(method, path[:i], nil, skind, "", nil) - for ; i < l && path[i] != '/'; i++ { + r.insert(method, path[:i], nil, staticKind, "", nil) + for ; i < lcpIndex && path[i] != '/'; i++ { } pnames = append(pnames, path[j:i]) path = path[:j] + path[i:] - i, l = j, len(path) + i, lcpIndex = j, len(path) - if i == l { - r.insert(method, path[:i], h, pkind, ppath, pnames) + if i == lcpIndex { + r.insert(method, path[:i], h, paramKind, ppath, pnames) } else { - r.insert(method, path[:i], nil, pkind, "", nil) + r.insert(method, path[:i], nil, paramKind, "", nil) } } else if path[i] == '*' { - r.insert(method, path[:i], nil, skind, "", nil) + r.insert(method, path[:i], nil, staticKind, "", nil) pnames = append(pnames, "*") - r.insert(method, path[:i+1], h, akind, ppath, pnames) + r.insert(method, path[:i+1], h, anyKind, ppath, pnames) } } - r.insert(method, path, h, skind, ppath, pnames) + r.insert(method, path, h, staticKind, ppath, pnames) } func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) { // Adjust max param - l := len(pnames) - if *r.echo.maxParam < l { - *r.echo.maxParam = l + paramLen := len(pnames) + if *r.echo.maxParam < paramLen { + *r.echo.maxParam = paramLen } - cn := r.tree // Current node as root - if cn == nil { + currentNode := r.tree // Current node as root + if currentNode == nil { panic("echo: invalid method") } search := path for { - sl := len(search) - pl := len(cn.prefix) - l := 0 - - // LCP - max := pl - if sl < max { - max = sl + searchLen := len(search) + prefixLen := len(currentNode.prefix) + lcpLen := 0 + + // LCP - Longest Common Prefix (https://en.wikipedia.org/wiki/LCP_array) + max := prefixLen + if searchLen < max { + max = searchLen } - for ; l < max && search[l] == cn.prefix[l]; l++ { + for ; lcpLen < max && search[lcpLen] == currentNode.prefix[lcpLen]; lcpLen++ { } - if l == 0 { + if lcpLen == 0 { // At root node - cn.label = search[0] - cn.prefix = search + currentNode.label = search[0] + currentNode.prefix = search if h != nil { - cn.kind = t - cn.addHandler(method, h) - cn.ppath = ppath - cn.pnames = pnames + currentNode.kind = t + currentNode.addHandler(method, h) + currentNode.ppath = ppath + currentNode.pnames = pnames } - } else if l < pl { + } else if lcpLen < prefixLen { // Split node - n := newNode(cn.kind, cn.prefix[l:], cn, cn.staticChildrens, cn.methodHandler, cn.ppath, cn.pnames, cn.paramChildren, cn.anyChildren) + n := newNode( + currentNode.kind, + currentNode.prefix[lcpLen:], + currentNode, + currentNode.staticChildren, + currentNode.methodHandler, + currentNode.ppath, + currentNode.pnames, + currentNode.paramChild, + currentNode.anyChild, + ) // Update parent path for all children to new node - for _, child := range cn.staticChildrens { + for _, child := range currentNode.staticChildren { child.parent = n } - if cn.paramChildren != nil { - cn.paramChildren.parent = n + if currentNode.paramChild != nil { + currentNode.paramChild.parent = n } - if cn.anyChildren != nil { - cn.anyChildren.parent = n + if currentNode.anyChild != nil { + currentNode.anyChild.parent = n } // Reset parent node - cn.kind = skind - cn.label = cn.prefix[0] - cn.prefix = cn.prefix[:l] - cn.staticChildrens = nil - cn.methodHandler = new(methodHandler) - cn.ppath = "" - cn.pnames = nil - cn.paramChildren = nil - cn.anyChildren = nil + currentNode.kind = staticKind + currentNode.label = currentNode.prefix[0] + currentNode.prefix = currentNode.prefix[:lcpLen] + currentNode.staticChildren = nil + currentNode.methodHandler = new(methodHandler) + currentNode.ppath = "" + currentNode.pnames = nil + currentNode.paramChild = nil + currentNode.anyChild = nil // Only Static children could reach here - cn.addStaticChild(n) + currentNode.addStaticChild(n) - if l == sl { + if lcpLen == searchLen { // At parent node - cn.kind = t - cn.addHandler(method, h) - cn.ppath = ppath - cn.pnames = pnames + currentNode.kind = t + currentNode.addHandler(method, h) + currentNode.ppath = ppath + currentNode.pnames = pnames } else { // Create child node - n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames, nil, nil) + n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n.addHandler(method, h) // Only Static children could reach here - cn.addStaticChild(n) + currentNode.addStaticChild(n) } - } else if l < sl { - search = search[l:] - c := cn.findChildWithLabel(search[0]) + } else if lcpLen < searchLen { + search = search[lcpLen:] + c := currentNode.findChildWithLabel(search[0]) if c != nil { // Go deeper - cn = c + currentNode = c continue } // Create child node - n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames, nil, nil) + n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n.addHandler(method, h) switch t { - case skind: - cn.addStaticChild(n) - case pkind: - cn.paramChildren = n - case akind: - cn.anyChildren = n + case staticKind: + currentNode.addStaticChild(n) + case paramKind: + currentNode.paramChild = n + case anyKind: + currentNode.anyChild = n } } else { // Node already exists if h != nil { - cn.addHandler(method, h) - cn.ppath = ppath - if len(cn.pnames) == 0 { // Issue #729 - cn.pnames = pnames + currentNode.addHandler(method, h) + currentNode.ppath = ppath + if len(currentNode.pnames) == 0 { // Issue #729 + currentNode.pnames = pnames } } } @@ -213,25 +223,25 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node { return &node{ - kind: t, - label: pre[0], - prefix: pre, - parent: p, - staticChildrens: sc, - ppath: ppath, - pnames: pnames, - methodHandler: mh, - paramChildren: paramChildren, - anyChildren: anyChildren, + kind: t, + label: pre[0], + prefix: pre, + parent: p, + staticChildren: sc, + ppath: ppath, + pnames: pnames, + methodHandler: mh, + paramChild: paramChildren, + anyChild: anyChildren, } } func (n *node) addStaticChild(c *node) { - n.staticChildrens = append(n.staticChildrens, c) + n.staticChildren = append(n.staticChildren, c) } func (n *node) findStaticChild(l byte) *node { - for _, c := range n.staticChildrens { + for _, c := range n.staticChildren { if c.label == l { return c } @@ -240,16 +250,16 @@ func (n *node) findStaticChild(l byte) *node { } func (n *node) findChildWithLabel(l byte) *node { - for _, c := range n.staticChildrens { + for _, c := range n.staticChildren { if c.label == l { return c } } if l == paramLabel { - return n.paramChildren + return n.paramChild } if l == anyLabel { - return n.anyChildren + return n.anyChild } return nil } @@ -330,13 +340,15 @@ func (n *node) checkMethodNotAllowed() HandlerFunc { func (r *Router) Find(method, path string, c Context) { ctx := c.(*context) ctx.path = path - cn := r.tree // Current node as root + currentNode := r.tree // Current node as root var ( + // search stores the remaining path to check for match. By each iteration we move from start of path to end of the path + // and search value gets shorter and shorter. search = path searchIndex = 0 - n int // Param counter - pvalues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice + paramIndex int // Param counter + paramValues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice ) // Backtracking is needed when a dead end (leaf node) is reached in the router tree. @@ -345,9 +357,9 @@ func (r *Router) Find(method, path string, c Context) { // For example if there is no static node match we should check parent next sibling by kind (param). // Backtracking itself does not check if there is a next sibling, this is done by the router logic. backtrackToNextNodeKind := func(fromKind kind) (nextNodeKind kind, valid bool) { - previous := cn - cn = previous.parent - valid = cn != nil + previous := currentNode + currentNode = previous.parent + valid = currentNode != nil // Next node type by priority // NOTE: With the current implementation we never backtrack from an `any` route, so `previous.kind` is @@ -355,51 +367,57 @@ func (r *Router) Find(method, path string, c Context) { // If this is changed then for any route next kind would be `static` and this statement should be changed nextNodeKind = previous.kind + 1 - if fromKind == skind { + if fromKind == staticKind { // when backtracking is done from static kind block we did not change search so nothing to restore return } // restore search to value it was before we move to current node we are backtracking from. - if previous.kind == skind { + if previous.kind == staticKind { searchIndex -= len(previous.prefix) } else { - n-- + paramIndex-- // for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue // for that index as it would also contain part of path we cut off before moving into node we are backtracking from - searchIndex -= len(pvalues[n]) + searchIndex -= len(paramValues[paramIndex]) } search = path[searchIndex:] return } - // Search order static > param > any + // Router tree is implemented by longest common prefix array (LCP array) https://en.wikipedia.org/wiki/LCP_array + // Tree search is implemented as for loop where one loop iteration is divided into 3 separate blocks + // Each of these blocks checks specific kind of node (static/param/any). Order of blocks reflex their priority in routing. + // Search order/priority is: static > param > any. + // + // Note: backtracking in tree is implemented by replacing/switching currentNode to previous node + // and hoping to (goto statement) next block by priority to check if it is the match. for { - pl := 0 // Prefix length - l := 0 // LCP length + prefixLen := 0 // Prefix length + lcpLen := 0 // LCP (longest common prefix) length - if cn.label != ':' { - sl := len(search) - pl = len(cn.prefix) + if currentNode.kind == staticKind { + searchLen := len(search) + prefixLen = len(currentNode.prefix) - // LCP - max := pl - if sl < max { - max = sl + // LCP - Longest Common Prefix (https://en.wikipedia.org/wiki/LCP_array) + max := prefixLen + if searchLen < max { + max = searchLen } - for ; l < max && search[l] == cn.prefix[l]; l++ { + for ; lcpLen < max && search[lcpLen] == currentNode.prefix[lcpLen]; lcpLen++ { } } - if l != pl { + if lcpLen != prefixLen { // No matching prefix, let's backtrack to the first possible alternative node of the decision path - nk, ok := backtrackToNextNodeKind(skind) + nk, ok := backtrackToNextNodeKind(staticKind) if !ok { return // No other possibilities on the decision path - } else if nk == pkind { + } else if nk == paramKind { goto Param // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently - //} else if nk == akind { + //} else if nk == anyKind { // goto Any } else { // Not found (this should never be possible for static node we are looking currently) @@ -408,31 +426,32 @@ func (r *Router) Find(method, path string, c Context) { } // The full prefix has matched, remove the prefix from the remaining search - search = search[l:] - searchIndex = searchIndex + l + search = search[lcpLen:] + searchIndex = searchIndex + lcpLen // Finish routing if no remaining search and we are on an leaf node - if search == "" && cn.ppath != "" { + if search == "" && currentNode.ppath != "" { break } // Static node if search != "" { - if child := cn.findStaticChild(search[0]); child != nil { - cn = child + if child := currentNode.findStaticChild(search[0]); child != nil { + currentNode = child continue } } Param: // Param node - if child := cn.paramChildren; search != "" && child != nil { - cn = child + if child := currentNode.paramChild; search != "" && child != nil { + currentNode = child + // FIXME: when param node does not have any children then param node should act similarly to any node - consider all remaining search as match i, l := 0, len(search) for ; i < l && search[i] != '/'; i++ { } - pvalues[n] = search[:i] - n++ + paramValues[paramIndex] = search[:i] + paramIndex++ search = search[i:] searchIndex = searchIndex + i continue @@ -440,20 +459,20 @@ func (r *Router) Find(method, path string, c Context) { Any: // Any node - if child := cn.anyChildren; child != nil { - // If any node is found, use remaining path for pvalues - cn = child - pvalues[len(cn.pnames)-1] = search + if child := currentNode.anyChild; child != nil { + // If any node is found, use remaining path for paramValues + currentNode = child + paramValues[len(currentNode.pnames)-1] = search break } // Let's backtrack to the first possible alternative node of the decision path - nk, ok := backtrackToNextNodeKind(akind) + nk, ok := backtrackToNextNodeKind(anyKind) if !ok { return // No other possibilities on the decision path - } else if nk == pkind { + } else if nk == paramKind { goto Param - } else if nk == akind { + } else if nk == anyKind { goto Any } else { // Not found @@ -461,12 +480,12 @@ func (r *Router) Find(method, path string, c Context) { } } - ctx.handler = cn.findHandler(method) - ctx.path = cn.ppath - ctx.pnames = cn.pnames + ctx.handler = currentNode.findHandler(method) + ctx.path = currentNode.ppath + ctx.pnames = currentNode.pnames if ctx.handler == nil { - ctx.handler = cn.checkMethodNotAllowed() + ctx.handler = currentNode.checkMethodNotAllowed() } return } diff --git a/router_test.go b/router_test.go index ba1890bd1..47e499402 100644 --- a/router_test.go +++ b/router_test.go @@ -640,42 +640,90 @@ var ( return nil } } + handlerFunc = func(c Context) error { + c.Set("path", c.Path()) + return nil + } ) +func checkUnusedParamValues(t *testing.T, c *context, expectParam map[string]string) { + for i, p := range c.pnames { + value := c.pvalues[i] + if value != "" { + if expectParam == nil { + t.Errorf("pValue '%v' is set for param name '%v' but we are not expecting it with expectParam", value, p) + } else { + if _, ok := expectParam[p]; !ok { + t.Errorf("pValue '%v' is set for param name '%v' but we are not expecting it with expectParam", value, p) + } + } + } + } +} + func TestRouterStatic(t *testing.T) { e := New() r := e.router path := "/folders/a/files/echo.gif" - r.Add(http.MethodGet, path, func(c Context) error { - c.Set("path", path) - return nil - }) + r.Add(http.MethodGet, path, handlerFunc) c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, path, c) c.handler(c) + assert.Equal(t, path, c.Get("path")) } func TestRouterParam(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/users/:id", func(c Context) error { - return nil - }) - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/1", c) - assert.Equal(t, "1", c.Param("id")) + + r.Add(http.MethodGet, "/users/:id", handlerFunc) + + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + name: "route /users/1 to /users/:id", + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { // FIXME: this documents current implementation (slash at end is problematic) + name: "route /users/1/ to /users/:id", + whenURL: "/users/1/", + expectRoute: nil, // FIXME: should be "/users/:id", + expectParam: nil, // FIXME: should be map[string]string{"id": "1/"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterTwoParam(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/users/:uid/files/:fid", func(Context) error { - return nil - }) + r.Add(http.MethodGet, "/users/:uid/files/:fid", handlerFunc) c := e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/1/files/1", c) + assert.Equal(t, "1", c.Param("uid")) assert.Equal(t, "1", c.Param("fid")) } @@ -685,18 +733,279 @@ func TestRouterParamWithSlash(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/a/:b/c/d/:e", func(c Context) error { - return nil - }) + r.Add(http.MethodGet, "/a/:b/c/d/:e", handlerFunc) + r.Add(http.MethodGet, "/a/:b/c/:d/:f", handlerFunc) - r.Add(http.MethodGet, "/a/:b/c/:d/:f", func(c Context) error { - return nil - }) + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/a/1/c/d/2/3", c) // `2/3` should mapped to path `/a/:b/c/d/:e` and into `:e` + + err := c.handler(c) + assert.Equal(t, nil, c.Get("path")) // FIXME: should be "/a/:b/c/d/:e" + assert.EqualError(t, err, "code=404, message=Not Found") // FIXME: should be .NoError() +} + +// Issue #1754 - router needs to backtrack multiple levels upwards in tree to find the matching route +// route evaluation order +// +// Routes: +// 1) /a/:b/c +// 2) /a/c/d +// 3) /a/c/df +// +// 4) /a/*/f +// 5) /:e/c/f +// +// 6) /* +// +// Searching route for "/a/c/f" should match "/a/*/f" +// When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f" +// +// +----------+ +// +-----+ "/" root +--------------------+--------------------------+ +// | +----------+ | | +// | | | +// +-------v-------+ +---v---------+ +-------v---+ +// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | +// +-+----------+--+ | +-----------+-+ +-----------+ +// | | | | +// +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ +// | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | +// +---------+------+ +--------+----+ +----------++ +-----------------+ +// | | | +// | | | +// +---------v----+ +------v--------+ +------v--------+ +// | "f" (static) | | "/c" (static) | | "/f" (static) | +// +--------------+ +---------------+ +---------------+ +func TestRouteMultiLevelBacktracking(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + }, + { + name: "route /a/x/df to /a/:b/c", + whenURL: "/a/x/c", + expectRoute: "/a/:b/c", + expectParam: map[string]string{"b": "x"}, + }, + { + name: "route /a/x/f to /a/*/f", + whenURL: "/a/x/f", + expectRoute: "/a/*/f", + expectParam: map[string]string{"*": "x/f"}, // NOTE: `x` would be probably more suitable + }, + { + name: "route /b/c/f to /:e/c/f", + whenURL: "/b/c/f", + expectRoute: "/:e/c/f", + expectParam: map[string]string{"e": "b"}, + }, + { + name: "route /b/c/c to /*", + whenURL: "/b/c/c", + expectRoute: "/*", + expectParam: map[string]string{"*": "b/c/c"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/a/:b/c", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/a/c/d", handlerHelper("case", 2)) + r.Add(http.MethodGet, "/a/c/df", handlerHelper("case", 3)) + r.Add(http.MethodGet, "/a/*/f", handlerHelper("case", 4)) + r.Add(http.MethodGet, "/:e/c/f", handlerHelper("case", 5)) + r.Add(http.MethodGet, "/*", handlerHelper("case", 6)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +// Issue #1754 - router needs to backtrack multiple levels upwards in tree to find the matching route +// route evaluation order +// +// Request for "/a/c/f" should match "/:e/c/f" +// +// +-0,7--------+ +// | "/" (root) |----------------------------------+ +// +------------+ | +// | | | +// | | | +// +-1,6-----------+ | | +-8-----------+ +------v----+ +// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | +// +---------------+ +-------------+ +-----------+ +// | | | +// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ +// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | +// +----------------+ +-------------+ +-----------------+ +// | +// +-4--v----------+ +// | "/c" (static) | +// +---------------+ +func TestRouteMultiLevelBacktracking2(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/a/:b/c", handlerFunc) + r.Add(http.MethodGet, "/a/c/d", handlerFunc) + r.Add(http.MethodGet, "/a/c/df", handlerFunc) + r.Add(http.MethodGet, "/:e/c/f", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + + var testCases = []struct { + name string + whenURL string + expectRoute string + expectParam map[string]string + }{ + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + }, + { + name: "route /a/x/df to /a/:b/c", + whenURL: "/a/x/c", + expectRoute: "/a/:b/c", + expectParam: map[string]string{"b": "x"}, + }, + { + name: "route /a/c/f to /:e/c/f", + whenURL: "/a/c/f", + expectRoute: "/:e/c/f", + expectParam: map[string]string{"e": "a"}, + }, + { + name: "route /b/c/f to /:e/c/f", + whenURL: "/b/c/f", + expectRoute: "/:e/c/f", + expectParam: map[string]string{"e": "b"}, + }, + { + name: "route /b/c/c to /*", + whenURL: "/b/c/c", + expectRoute: "/*", + expectParam: map[string]string{"*": "b/c/c"}, + }, + { // this traverses `/a/:b/c` and `/:e/c/f` branches and eventually backtracks to `/*` + name: "route /a/c/cf to /*", + whenURL: "/a/c/cf", + expectRoute: "/*", + expectParam: map[string]string{"*": "a/c/cf"}, + }, + { + name: "route /anyMatch to /*", + whenURL: "/anyMatch", + expectRoute: "/*", + expectParam: map[string]string{"*": "anyMatch"}, + }, + { + name: "route /anyMatch/withSlash to /*", + whenURL: "/anyMatch/withSlash", + expectRoute: "/*", + expectParam: map[string]string{"*": "anyMatch/withSlash"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestRouterBacktrackingFromMultipleParamKinds(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/*", handlerFunc) // this can match only path that does not have slash in it + r.Add(http.MethodGet, "/:1/second", handlerFunc) + r.Add(http.MethodGet, "/:1/:2", handlerFunc) // this acts as match ANY for all routes that have at least one slash + r.Add(http.MethodGet, "/:1/:2/third", handlerFunc) + r.Add(http.MethodGet, "/:1/:2/:3/fourth", handlerFunc) + r.Add(http.MethodGet, "/:1/:2/:3/:4/fifth", handlerFunc) c := e.NewContext(nil, nil).(*context) - assert.NotPanics(t, func() { - r.Find(http.MethodGet, "/a/1/c/d/2/3", c) - }) + var testCases = []struct { + name string + whenURL string + expectRoute string + expectParam map[string]string + }{ + { + name: "route /first to /*", + whenURL: "/first", + expectRoute: "/*", + expectParam: map[string]string{"*": "first"}, + }, + { + name: "route /first/second to /:1/second", + whenURL: "/first/second", + expectRoute: "/:1/second", + expectParam: map[string]string{"1": "first"}, + }, + { + name: "route /first/second-new to /:1/:2", + whenURL: "/first/second-new", + expectRoute: "/:1/:2", + expectParam: map[string]string{ + "1": "first", + "2": "second-new", + }, + }, + { // FIXME: should match `/:1/:2` when backtracking in tree. this 1 level backtracking fails even with old implementation + name: "route /first/second/ to /:1/:2", + whenURL: "/first/second/", + expectRoute: "/*", // "/:1/:2", + expectParam: map[string]string{"*": "first/second/"}, // map[string]string{"1": "first", "2": "second/"}, + }, + { // FIXME: should match `/:1/:2`. same backtracking problem. when backtracking is at `/:1/:2` during backtracking this node should be match as it has executable handler + name: "route /first/second/third/fourth/fifth/nope to /:1/:2", + whenURL: "/first/second/third/fourth/fifth/nope", + expectRoute: "/*", // "/:1/:2", + expectParam: map[string]string{"*": "first/second/third/fourth/fifth/nope"}, // map[string]string{"1": "first", "2": "second/third/fourth/fifth/nope"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #1509 @@ -713,16 +1022,37 @@ func TestRouterParamStaticConflict(t *testing.T) { g.GET("/status", handler) g.GET("/:name", handler) - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/g/s", c) - c.handler(c) - assert.Equal(t, "s", c.Param("name")) - assert.Equal(t, "/g/:name", c.Get("path")) + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/g/s", + expectRoute: "/g/:name", + expectParam: map[string]string{"name": "s"}, + }, + { + whenURL: "/g/status", + expectRoute: "/g/status", + expectParam: map[string]string{"name": ""}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/g/status", c) - c.handler(c) - assert.Equal(t, "/g/status", c.Get("path")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) + + assert.NoError(t, err) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMatchAny(t *testing.T) { @@ -730,28 +1060,46 @@ func TestRouterMatchAny(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/*", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/users/*", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) - c.handler(c) - - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/", c.Get("path")) + r.Add(http.MethodGet, "/", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + r.Add(http.MethodGet, "/users/*", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/", + expectRoute: "/", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/download", + expectRoute: "/*", + expectParam: map[string]string{"*": "download"}, + }, + { + whenURL: "/users/joe", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "joe"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/download", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "download", c.Param("*")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - r.Find(http.MethodGet, "/users/joe", c) - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "joe", c.Param("*")) + assert.NoError(t, err) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // NOTE: this is to document current implementation. Last added route with `*` asterisk is always the match and no @@ -796,155 +1144,53 @@ func TestRouterMatchAnyPrefixIssue(t *testing.T) { c.Set("path", c.Path()) return nil }) - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users", c.Param("*")) - - r.Find(http.MethodGet, "/users/", c) - c.handler(c) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/users_prefix", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users_prefix", c.Param("*")) - r.Find(http.MethodGet, "/users_prefix/", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users_prefix/", c.Param("*")) -} - -func TestRouteMultiLevelBacktracking(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/a/:b/c", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/a/c/d", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/:e/c/f", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/c/f", c) - - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/:e/c/f", c.Get("path")) -} - -// Issue # -func TestRouterBacktrackingFromParam(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/users/:name/", handlerHelper("case", 2)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/users/firstname/no-match", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users/firstname/no-match", c.Param("*")) - - r.Find(http.MethodGet, "/users/firstname/", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/users/:name/", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) -} - -func TestRouterBacktrackingFromParamAny(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/:name/lastname", handlerHelper("case", 2)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/firstname/test", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "firstname/test", c.Param("*")) - - r.Find(http.MethodGet, "/firstname", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "firstname", c.Param("*")) - - r.Find(http.MethodGet, "/firstname/lastname", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/:name/lastname", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) -} - -func TestRouterBacktrackingFromParamAny2(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/:name", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/:name/lastname", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/firstname/test", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "firstname/test", c.Param("*")) - - r.Find(http.MethodGet, "/firstname", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/:name", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) - - r.Find(http.MethodGet, "/firstname/lastname", c) - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/:name/lastname", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) -} - -func TestRouterAnyCommonPath(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/ab*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/abcd", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/abcd*", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/abee", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/ab*", c.Get("path")) - assert.Equal(t, "ee", c.Param("*")) + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/", + expectRoute: "/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users", + expectRoute: "/*", + expectParam: map[string]string{"*": "users"}, + }, + { + whenURL: "/users/", + expectRoute: "/users/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users_prefix", + expectRoute: "/*", + expectParam: map[string]string{"*": "users_prefix"}, + }, + { + whenURL: "/users_prefix/", + expectRoute: "/*", + expectParam: map[string]string{"*": "users_prefix/"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/abcd", c) - c.handler(c) - assert.Equal(t, "/abcd", c.Get("path")) - assert.Equal(t, 2, c.Get("case")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - r.Find(http.MethodGet, "/abcde", c) - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/abcd*", c.Get("path")) - assert.Equal(t, "e", c.Param("*")) + assert.NoError(t, err) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // TestRouterMatchAnySlash shall verify finding the best route @@ -953,168 +1199,226 @@ func TestRouterMatchAnySlash(t *testing.T) { e := New() r := e.router - handler := func(c Context) error { - c.Set("path", c.Path()) - return nil - } - // Routes - r.Add(http.MethodGet, "/users", handler) - r.Add(http.MethodGet, "/users/*", handler) - r.Add(http.MethodGet, "/img/*", handler) - r.Add(http.MethodGet, "/img/load", handler) - r.Add(http.MethodGet, "/img/load/*", handler) - r.Add(http.MethodGet, "/assets/*", handler) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) - assert.Equal(t, "", c.Param("*")) - - // Test trailing slash request for simple any route (see #1526) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/", c) - c.handler(c) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/joe", c) - c.handler(c) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "joe", c.Param("*")) - - // Test trailing slash request for nested any route (see #1526) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/img/load", c) - c.handler(c) - assert.Equal(t, "/img/load", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/img/load/", c) - c.handler(c) - assert.Equal(t, "/img/load/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/img/load/ben", c) - c.handler(c) - assert.Equal(t, "/img/load/*", c.Get("path")) - assert.Equal(t, "ben", c.Param("*")) - - // Test /assets/* any route - // ... without trailing slash must not match - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/assets", c) - c.handler(c) - assert.Equal(t, nil, c.Get("path")) - assert.Equal(t, "", c.Param("*")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/*", handlerFunc) + r.Add(http.MethodGet, "/img/*", handlerFunc) + r.Add(http.MethodGet, "/img/load", handlerFunc) + r.Add(http.MethodGet, "/img/load/*", handlerFunc) + r.Add(http.MethodGet, "/assets/*", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/", + expectRoute: nil, + expectParam: map[string]string{"*": ""}, + expectError: ErrNotFound, + }, + { // Test trailing slash request for simple any route (see #1526) + whenURL: "/users/", + expectRoute: "/users/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users/joe", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "joe"}, + }, + // Test trailing slash request for nested any route (see #1526) + { + whenURL: "/img/load", + expectRoute: "/img/load", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/img/load/", + expectRoute: "/img/load/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/img/load/ben", + expectRoute: "/img/load/*", + expectParam: map[string]string{"*": "ben"}, + }, + // Test /assets/* any route + { // ... without trailing slash must not match + whenURL: "/assets", + expectRoute: nil, + expectParam: map[string]string{"*": ""}, + expectError: ErrNotFound, + }, + + { // ... with trailing slash must match + whenURL: "/assets/", + expectRoute: "/assets/*", + expectParam: map[string]string{"*": ""}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // ... with trailing slash must match - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/assets/", c) - c.handler(c) - assert.Equal(t, "/assets/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMatchAnyMultiLevel(t *testing.T) { e := New() r := e.router - handler := func(c Context) error { - c.Set("path", c.Path()) - return nil - } // Routes - r.Add(http.MethodGet, "/api/users/jack", handler) - r.Add(http.MethodGet, "/api/users/jill", handler) - r.Add(http.MethodGet, "/api/users/*", handler) - r.Add(http.MethodGet, "/api/*", handler) - r.Add(http.MethodGet, "/other/*", handler) - r.Add(http.MethodGet, "/*", handler) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/api/users/jack", c) - c.handler(c) - assert.Equal(t, "/api/users/jack", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/api/users/jill", c) - c.handler(c) - assert.Equal(t, "/api/users/jill", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/api/users/joe", c) - c.handler(c) - assert.Equal(t, "/api/users/*", c.Get("path")) - assert.Equal(t, "joe", c.Param("*")) - - r.Find(http.MethodGet, "/api/nousers/joe", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "nousers/joe", c.Param("*")) + r.Add(http.MethodGet, "/api/users/jack", handlerFunc) + r.Add(http.MethodGet, "/api/users/jill", handlerFunc) + r.Add(http.MethodGet, "/api/users/*", handlerFunc) + r.Add(http.MethodGet, "/api/*", handlerFunc) + r.Add(http.MethodGet, "/other/*", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/api/users/jack", + expectRoute: "/api/users/jack", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/api/users/jill", + expectRoute: "/api/users/jill", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/api/users/joe", + expectRoute: "/api/users/*", + expectParam: map[string]string{"*": "joe"}, + }, + { + whenURL: "/api/nousers/joe", + expectRoute: "/api/*", + expectParam: map[string]string{"*": "nousers/joe"}, + }, + { + whenURL: "/api/none", + expectRoute: "/api/*", + expectParam: map[string]string{"*": "none"}, + }, + { + whenURL: "/api/none", + expectRoute: "/api/*", + expectParam: map[string]string{"*": "none"}, + }, + { + whenURL: "/noapi/users/jim", + expectRoute: "/*", + expectParam: map[string]string{"*": "noapi/users/jim"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/api/none", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "none", c.Param("*")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - r.Find(http.MethodGet, "/noapi/users/jim", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "noapi/users/jim", c.Param("*")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMatchAnyMultiLevelWithPost(t *testing.T) { e := New() r := e.router - handler := func(c Context) error { - c.Set("path", c.Path()) - return nil - } // Routes - e.POST("/api/auth/login", handler) - e.POST("/api/auth/forgotPassword", handler) - e.Any("/api/*", handler) - e.Any("/*", handler) - - // POST /api/auth/login shall choose login method - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodPost, "/api/auth/login", c) - c.handler(c) - assert.Equal(t, "/api/auth/login", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - // GET /api/auth/login shall choose any route - // c = e.NewContext(nil, nil).(*context) - // r.Find(http.MethodGet, "/api/auth/login", c) - // c.handler(c) - // assert.Equal(t, "/api/*", c.Get("path")) - // assert.Equal(t, "auth/login", c.Param("*")) - - // POST /api/auth/logout shall choose nearest any route - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodPost, "/api/auth/logout", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "auth/logout", c.Param("*")) - - // POST to /api/other/test shall choose nearest any route - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodPost, "/api/other/test", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "other/test", c.Param("*")) + e.POST("/api/auth/login", handlerFunc) + e.POST("/api/auth/forgotPassword", handlerFunc) + e.Any("/api/*", handlerFunc) + e.Any("/*", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { // POST /api/auth/login shall choose login method + whenURL: "/api/auth/login", + whenMethod: http.MethodPost, + expectRoute: "/api/auth/login", + expectParam: map[string]string{"*": ""}, + }, + { // POST /api/auth/logout shall choose nearest any route + whenURL: "/api/auth/logout", + whenMethod: http.MethodPost, + expectRoute: "/api/*", + expectParam: map[string]string{"*": "auth/logout"}, + }, + { // POST to /api/other/test shall choose nearest any route + whenURL: "/api/other/test", + whenMethod: http.MethodPost, + expectRoute: "/api/*", + expectParam: map[string]string{"*": "other/test"}, + }, + { // GET to /api/other/test shall choose nearest any route + whenURL: "/api/other/test", + whenMethod: http.MethodGet, + expectRoute: "/api/*", + expectParam: map[string]string{"*": "other/test"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // GET to /api/other/test shall choose nearest any route - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/api/other/test", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "other/test", c.Param("*")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMicroParam(t *testing.T) { @@ -1150,29 +1454,56 @@ func TestRouterMultiRoute(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/users", func(c Context) error { - c.Set("path", "/users") - return nil - }) - r.Add(http.MethodGet, "/users/:id", func(c Context) error { - return nil - }) - c := e.NewContext(nil, nil).(*context) - - // Route > /users - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, "/users", c.Get("path")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/users", + expectRoute: "/users", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/user", + expectRoute: nil, + expectParam: map[string]string{"*": ""}, + expectError: ErrNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // Route > /users/:id - r.Find(http.MethodGet, "/users/1", c) - assert.Equal(t, "1", c.Param("id")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - // Route > /user - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/user", c) - he := c.handler(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterPriority(t *testing.T) { @@ -1180,123 +1511,112 @@ func TestRouterPriority(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/users", handlerHelper("a", 1)) - r.Add(http.MethodGet, "/users/new", handlerHelper("b", 2)) - r.Add(http.MethodGet, "/users/:id", handlerHelper("c", 3)) - r.Add(http.MethodGet, "/users/dew", handlerHelper("d", 4)) - r.Add(http.MethodGet, "/users/:id/files", handlerHelper("e", 5)) - r.Add(http.MethodGet, "/users/newsee", handlerHelper("f", 6)) - r.Add(http.MethodGet, "/users/*", handlerHelper("g", 7)) - r.Add(http.MethodGet, "/users/new/*", handlerHelper("h", 8)) - r.Add(http.MethodGet, "/*", handlerHelper("i", 9)) - - // Route > /users - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, 1, c.Get("a")) - assert.Equal(t, "/users", c.Get("path")) - - // Route > /users/new - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/new", c) - c.handler(c) - assert.Equal(t, 2, c.Get("b")) - assert.Equal(t, "/users/new", c.Get("path")) - - // Route > /users/:id - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/1", c) - c.handler(c) - assert.Equal(t, 3, c.Get("c")) - assert.Equal(t, "/users/:id", c.Get("path")) - - // Route > /users/dew - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/dew", c) - c.handler(c) - assert.Equal(t, 4, c.Get("d")) - assert.Equal(t, "/users/dew", c.Get("path")) - - // Route > /users/:id/files - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/1/files", c) - c.handler(c) - assert.Equal(t, 5, c.Get("e")) - assert.Equal(t, "/users/:id/files", c.Get("path")) - - // Route > /users/:id - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/news", c) - c.handler(c) - assert.Equal(t, 3, c.Get("c")) - assert.Equal(t, "/users/:id", c.Get("path")) - - // Route > /users/newsee - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/newsee", c) - c.handler(c) - assert.Equal(t, 6, c.Get("f")) - assert.Equal(t, "/users/newsee", c.Get("path")) - - // Route > /users/newsee - r.Find(http.MethodGet, "/users/newsee", c) - c.handler(c) - assert.Equal(t, 6, c.Get("f")) - - // Route > /users/newsee - r.Find(http.MethodGet, "/users/newsee", c) - c.handler(c) - assert.Equal(t, 6, c.Get("f")) - - // Route > /users/* - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/joe/books", c) - c.handler(c) - assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "joe/books", c.Param("*")) - - // Route > /users/new/* should be matched - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/new/someone", c) - c.handler(c) - assert.Equal(t, 8, c.Get("h")) - assert.Equal(t, "/users/new/*", c.Get("path")) - assert.Equal(t, "someone", c.Param("*")) - - // Route > /users/* should be matched although /users/dew exists - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/dew/someone", c) - c.handler(c) - assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "/users/*", c.Get("path")) - - assert.Equal(t, "dew/someone", c.Param("*")) - - // Route > /users/* should be matched although /users/dew exists - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/notexists/someone", c) - c.handler(c) - assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "notexists/someone", c.Param("*")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/new", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + r.Add(http.MethodGet, "/users/dew", handlerFunc) + r.Add(http.MethodGet, "/users/:id/files", handlerFunc) + r.Add(http.MethodGet, "/users/newsee", handlerFunc) + r.Add(http.MethodGet, "/users/*", handlerFunc) + r.Add(http.MethodGet, "/users/new/*", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/users", + expectRoute: "/users", + }, + { + whenURL: "/users/new", + expectRoute: "/users/new", + }, + { + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/users/dew", + expectRoute: "/users/dew", + }, + { + whenURL: "/users/1/files", + expectRoute: "/users/:id/files", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/users/new", + expectRoute: "/users/new", + }, + { + whenURL: "/users/news", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "news"}, + }, + { + whenURL: "/users/newsee", + expectRoute: "/users/newsee", + }, + { + whenURL: "/users/joe/books", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "joe/books"}, + }, + { + whenURL: "/users/new/someone", + expectRoute: "/users/new/*", + expectParam: map[string]string{"*": "someone"}, + }, + { + whenURL: "/users/dew/someone", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "dew/someone"}, + }, + { // Route > /users/* should be matched although /users/dew exists + whenURL: "/users/notexists/someone", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "notexists/someone"}, + }, + { + whenURL: "/nousers", + expectRoute: "/*", + expectParam: map[string]string{"*": "nousers"}, + }, + { + whenURL: "/nousers/new", + expectRoute: "/*", + expectParam: map[string]string{"*": "nousers/new"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // Route > * - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/nousers", c) - c.handler(c) - assert.Equal(t, 9, c.Get("i")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "nousers", c.Param("*")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - // Route > * - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/nousers/new", c) - c.handler(c) - assert.Equal(t, 9, c.Get("i")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "nousers/new", c.Param("*")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterIssue1348(t *testing.T) { @@ -1315,31 +1635,55 @@ func TestRouterIssue1348(t *testing.T) { func TestRouterPriorityNotFound(t *testing.T) { e := New() r := e.router - c := e.NewContext(nil, nil).(*context) // Add - r.Add(http.MethodGet, "/a/foo", func(c Context) error { - c.Set("a", 1) - return nil - }) - r.Add(http.MethodGet, "/a/bar", func(c Context) error { - c.Set("b", 2) - return nil - }) - - // Find - r.Find(http.MethodGet, "/a/foo", c) - c.handler(c) - assert.Equal(t, 1, c.Get("a")) + r.Add(http.MethodGet, "/a/foo", handlerFunc) + r.Add(http.MethodGet, "/a/bar", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/a/foo", + expectRoute: "/a/foo", + }, + { + whenURL: "/a/bar", + expectRoute: "/a/bar", + }, + { + whenURL: "/abc/def", + expectRoute: nil, + expectError: ErrNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bar", c) - c.handler(c) - assert.Equal(t, 2, c.Get("b")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/abc/def", c) - he := c.handler(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterParamNames(t *testing.T) { @@ -1347,34 +1691,58 @@ func TestRouterParamNames(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/users", func(c Context) error { - c.Set("path", "/users") - return nil - }) - r.Add(http.MethodGet, "/users/:id", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:uid/files/:fid", func(c Context) error { - return nil - }) - c := e.NewContext(nil, nil).(*context) - - // Route > /users - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, "/users", c.Get("path")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + r.Add(http.MethodGet, "/users/:uid/files/:fid", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/users", + expectRoute: "/users", + }, + { + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/users/1/files/1", + expectRoute: "/users/:uid/files/:fid", + expectParam: map[string]string{ + "uid": "1", + "fid": "1", + }, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // Route > /users/:id - r.Find(http.MethodGet, "/users/1", c) - assert.Equal(t, "id", c.pnames[0]) - assert.Equal(t, "1", c.Param("id")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - // Route > /users/:uid/files/:fid - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal(t, "uid", c.pnames[0]) - assert.Equal(t, "1", c.Param("uid")) - assert.Equal(t, "fid", c.pnames[1]) - assert.Equal(t, "1", c.Param("fid")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #623 and #1406 @@ -1389,47 +1757,69 @@ func TestRouterStaticDynamicConflict(t *testing.T) { r.Add(http.MethodGet, "/server", handlerHelper("c", 3)) r.Add(http.MethodGet, "/", handlerHelper("f", 6)) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/dictionary/skills", c) - c.Handler()(c) - assert.Equal(t, 1, c.Get("a")) - assert.Equal(t, "/dictionary/skills", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/dictionary/skillsnot", c) - c.Handler()(c) - assert.Equal(t, 2, c.Get("b")) - assert.Equal(t, "/dictionary/:name", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/dictionary/type", c) - c.Handler()(c) - assert.Equal(t, 2, c.Get("b")) - assert.Equal(t, "/dictionary/:name", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/server", c) - c.Handler()(c) - assert.Equal(t, 3, c.Get("c")) - assert.Equal(t, "/server", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/new", c) - c.Handler()(c) - assert.Equal(t, 4, c.Get("d")) - assert.Equal(t, "/users/new", c.Get("path")) + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/dictionary/skills", + expectRoute: "/dictionary/skills", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/dictionary/skillsnot", + expectRoute: "/dictionary/:name", + expectParam: map[string]string{"name": "skillsnot"}, + }, + { + whenURL: "/dictionary/type", + expectRoute: "/dictionary/:name", + expectParam: map[string]string{"name": "type"}, + }, + { + whenURL: "/server", + expectRoute: "/server", + }, + { + whenURL: "/users/new", + expectRoute: "/users/new", + }, + { + whenURL: "/users/new2", + expectRoute: "/users/:name", + expectParam: map[string]string{"name": "new2"}, + }, + { + whenURL: "/", + expectRoute: "/", + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/new2", c) - c.Handler()(c) - assert.Equal(t, 5, c.Get("e")) - assert.Equal(t, "/users/:name", c.Get("path")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/", c) - c.Handler()(c) - assert.Equal(t, 6, c.Get("f")) - assert.Equal(t, "/", c.Get("path")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #1348 @@ -1438,42 +1828,76 @@ func TestRouterParamBacktraceNotFound(t *testing.T) { r := e.router // Add - r.Add(http.MethodGet, "/:param1", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/:param1/foo", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/:param1/bar", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/:param1/bar/:param2", func(c Context) error { - return nil - }) - - c := e.NewContext(nil, nil).(*context) - - //Find - r.Find(http.MethodGet, "/a", c) - assert.Equal(t, "a", c.Param("param1")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/foo", c) - assert.Equal(t, "a", c.Param("param1")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bar", c) - assert.Equal(t, "a", c.Param("param1")) + r.Add(http.MethodGet, "/:param1", handlerFunc) + r.Add(http.MethodGet, "/:param1/foo", handlerFunc) + r.Add(http.MethodGet, "/:param1/bar", handlerFunc) + r.Add(http.MethodGet, "/:param1/bar/:param2", handlerFunc) + + var testCases = []struct { + name string + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + name: "route /a to /:param1", + whenURL: "/a", + expectRoute: "/:param1", + expectParam: map[string]string{"param1": "a"}, + }, + { + name: "route /a/foo to /:param1/foo", + whenURL: "/a/foo", + expectRoute: "/:param1/foo", + expectParam: map[string]string{"param1": "a"}, + }, + { + name: "route /a/bar to /:param1/bar", + whenURL: "/a/bar", + expectRoute: "/:param1/bar", + expectParam: map[string]string{"param1": "a"}, + }, + { + name: "route /a/bar/b to /:param1/bar/:param2", + whenURL: "/a/bar/b", + expectRoute: "/:param1/bar/:param2", + expectParam: map[string]string{ + "param1": "a", + "param2": "b", + }, + }, + { + name: "route /a/bbbbb should return 404", + whenURL: "/a/bbbbb", + expectRoute: nil, + expectError: ErrNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bar/b", c) - assert.Equal(t, "a", c.Param("param1")) - assert.Equal(t, "b", c.Param("param2")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bbbbb", c) - he := c.handler(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func testRouterAPI(t *testing.T, api []*Route) { @@ -1487,13 +1911,15 @@ func testRouterAPI(t *testing.T, api []*Route) { } c := e.NewContext(nil, nil).(*context) for _, route := range api { - r.Find(route.Method, route.Path, c) - tokens := strings.Split(route.Path[1:], "/") - for _, token := range tokens { - if token[0] == ':' { - assert.Equal(t, c.Param(token[1:]), token) + t.Run(route.Path, func(t *testing.T) { + r.Find(route.Method, route.Path, c) + tokens := strings.Split(route.Path[1:], "/") + for _, token := range tokens { + if token[0] == ':' { + assert.Equal(t, c.Param(token[1:]), token) + } } - } + }) } } @@ -1552,79 +1978,93 @@ func TestRouterParam1466(t *testing.T) { e := New() r := e.router - r.Add(http.MethodPost, "/users/signup", func(c Context) error { - return nil - }) - r.Add(http.MethodPost, "/users/signup/bulk", func(c Context) error { - return nil - }) - r.Add(http.MethodPost, "/users/survey", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:username", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/interests/:name/users", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/skills/:name/users", func(c Context) error { - return nil - }) + r.Add(http.MethodPost, "/users/signup", handlerFunc) + r.Add(http.MethodPost, "/users/signup/bulk", handlerFunc) + r.Add(http.MethodPost, "/users/survey", handlerFunc) + r.Add(http.MethodGet, "/users/:username", handlerFunc) + r.Add(http.MethodGet, "/interests/:name/users", handlerFunc) + r.Add(http.MethodGet, "/skills/:name/users", handlerFunc) // Additional routes for Issue 1479 - r.Add(http.MethodGet, "/users/:username/likes/projects/ids", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:username/profile", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:username/uploads/:type", func(c Context) error { - return nil - }) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/users/ajitem", c) - assert.Equal(t, "ajitem", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme", c) - assert.Equal(t, "sharewithme", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/signup", c) - assert.Equal(t, "", c.Param("username")) - // Additional assertions for #1479 - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme/likes/projects/ids", c) - assert.Equal(t, "sharewithme", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/ajitem/likes/projects/ids", c) - assert.Equal(t, "ajitem", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme/profile", c) - assert.Equal(t, "sharewithme", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/ajitem/profile", c) - assert.Equal(t, "ajitem", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme/uploads/self", c) - assert.Equal(t, "sharewithme", c.Param("username")) - assert.Equal(t, "self", c.Param("type")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/ajitem/uploads/self", c) - assert.Equal(t, "ajitem", c.Param("username")) - assert.Equal(t, "self", c.Param("type")) - - // Issue #1493 - check for routing loop - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/tree/free", c) - assert.Equal(t, "", c.Param("id")) - assert.Equal(t, 0, c.response.Status) + r.Add(http.MethodGet, "/users/:username/likes/projects/ids", handlerFunc) + r.Add(http.MethodGet, "/users/:username/profile", handlerFunc) + r.Add(http.MethodGet, "/users/:username/uploads/:type", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/users/ajitem", + expectRoute: "/users/:username", + expectParam: map[string]string{"username": "ajitem"}, + }, + { + whenURL: "/users/sharewithme", + expectRoute: "/users/:username", + expectParam: map[string]string{"username": "sharewithme"}, + }, + { + whenURL: "/users/signup", + expectRoute: nil, // method not found as this route is for POST but request is for GET + expectParam: map[string]string{"username": ""}, + }, + // Additional assertions for #1479 + { + whenURL: "/users/sharewithme/likes/projects/ids", + expectRoute: "/users/:username/likes/projects/ids", + expectParam: map[string]string{"username": "sharewithme"}, + }, + { + whenURL: "/users/ajitem/likes/projects/ids", + expectRoute: "/users/:username/likes/projects/ids", + expectParam: map[string]string{"username": "ajitem"}, + }, + { + whenURL: "/users/sharewithme/profile", + expectRoute: "/users/:username/profile", + expectParam: map[string]string{"username": "sharewithme"}, + }, + { + whenURL: "/users/ajitem/profile", + expectRoute: "/users/:username/profile", + expectParam: map[string]string{"username": "ajitem"}, + }, + { + whenURL: "/users/sharewithme/uploads/self", + expectRoute: "/users/:username/uploads/:type", + expectParam: map[string]string{ + "username": "sharewithme", + "type": "self", + }, + }, + { + whenURL: "/users/ajitem/uploads/self", + expectRoute: "/users/:username/uploads/:type", + expectParam: map[string]string{ + "username": "ajitem", + "type": "self", + }, + }, + { + whenURL: "/users/tree/free", + expectRoute: nil, // not found + expectParam: map[string]string{"id": ""}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, tc.whenURL, c) + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #1655 @@ -1669,33 +2109,56 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1)) - r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:id/active", func(c Context) error { - return nil - }) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/alice/edit", c) - assert.Equal(t, "alice", c.Param("id")) + r.Add(http.MethodGet, "/users/create", handlerFunc) + r.Add(http.MethodGet, "/users/:id/edit", handlerFunc) + r.Add(http.MethodGet, "/users/:id/active", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + expectStatus int + }{ + { + whenURL: "/users/alice/edit", + expectRoute: "/users/:id/edit", + expectParam: map[string]string{"id": "alice"}, + }, + { + whenURL: "/users/bob/active", + expectRoute: "/users/:id/active", + expectParam: map[string]string{"id": "bob"}, + }, + { + whenURL: "/users/create", + expectRoute: "/users/create", + expectParam: nil, + }, + //This panic before the fix for Issue #1653 + { + whenURL: "/users/createNotFound", + expectStatus: http.StatusNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/bob/active", c) - assert.Equal(t, "bob", c.Param("id")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/create", c) - c.Handler()(c) - assert.Equal(t, 1, c.Get("create")) - assert.Equal(t, "/users/create", c.Get("path")) - - //This panic before the fix for Issue #1653 - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/createNotFound", c) - he := c.Handler()(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectStatus != 0 { + assert.Error(t, err) + he := err.(*HTTPError) + assert.Equal(t, tc.expectStatus, he.Code) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { @@ -1765,14 +2228,14 @@ func (n *node) printTree(pfx string, tail bool) { p = prefix(tail, pfx, " ", "│ ") - children := n.staticChildrens + children := n.staticChildren l := len(children) - if n.paramChildren != nil { - n.paramChildren.printTree(p, n.anyChildren == nil && l == 0) + if n.paramChild != nil { + n.paramChild.printTree(p, n.anyChild == nil && l == 0) } - if n.anyChildren != nil { - n.anyChildren.printTree(p, l == 0) + if n.anyChild != nil { + n.anyChild.printTree(p, l == 0) } for i := 0; i < l-1; i++ { children[i].printTree(p, false)