diff --git a/glide.lock b/glide.lock index 533b25bd3..ef44d0cfa 100644 --- a/glide.lock +++ b/glide.lock @@ -77,8 +77,6 @@ imports: version: c0795c8afcf41dd1d786bebce68636c199b3bb45 - name: github.com/jteeuwen/go-bindata version: 6025e8de665b31fa74ab1a66f2cddd8c0abf887e -- name: github.com/julienschmidt/httprouter - version: 26a05976f9bf5c3aa992cc20e8588c359418ee58 - name: github.com/kardianos/osext version: 2bc1f35cddc0cc527b4bc3dce8578fc2a6c11384 - name: github.com/kisielk/errcheck diff --git a/glide.yaml b/glide.yaml index 9edf0f49c..327bbf7df 100644 --- a/glide.yaml +++ b/glide.yaml @@ -16,9 +16,6 @@ import: version: ^1.3.0 - package: github.com/uber-go/tally version: ^3.3.8 -# update to master since the last semver release is 2 years old -- package: github.com/julienschmidt/httprouter - version: master - package: github.com/kardianos/osext version: master - package: go.uber.org/thriftrw diff --git a/runtime/router.go b/runtime/router.go index 7af517a80..427559aea 100644 --- a/runtime/router.go +++ b/runtime/router.go @@ -25,11 +25,11 @@ import ( "fmt" "net/http" - "github.com/julienschmidt/httprouter" "github.com/opentracing/opentracing-go" "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/uber-go/tally" + zrouter "github.com/uber/zanzibar/runtime/router" "go.uber.org/zap" "net/url" ) @@ -61,9 +61,9 @@ type HTTPRouter interface { // ParamsFromContext extracts the URL parameters that are embedded in the context by the Zanzibar HTTP router implementation. func ParamsFromContext(ctx context.Context) url.Values { - julienParams := httprouter.ParamsFromContext(ctx) + params := zrouter.ParamsFromContext(ctx) urlValues := make(url.Values) - for _, paramValue := range julienParams { + for _, paramValue := range params { urlValues.Add(paramValue.Key, paramValue.Value) } return urlValues @@ -131,7 +131,7 @@ func (endpoint *RouterEndpoint) HandleRequest( // httpRouter data structure to handle and register endpoints type httpRouter struct { gateway *Gateway - httpRouter *httprouter.Router + httpRouter *zrouter.Router notFoundEndpoint *RouterEndpoint methodNotAllowedEndpoint *RouterEndpoint panicCount tally.Counter @@ -167,8 +167,7 @@ func NewHTTPRouter(gateway *Gateway) HTTPRouter { requestUUIDHeaderKey: gateway.requestUUIDHeaderKey, } - router.httpRouter = &httprouter.Router{ - RedirectTrailingSlash: true, + router.httpRouter = &zrouter.Router{ HandleMethodNotAllowed: true, NotFound: http.HandlerFunc(router.handleNotFound), MethodNotAllowed: http.HandlerFunc(router.handleMethodNotAllowed), @@ -179,13 +178,6 @@ func NewHTTPRouter(gateway *Gateway) HTTPRouter { // Register register a handler function. func (router *httpRouter) Handle(method, prefix string, handler http.Handler) (err error) { - defer func() { - recoveredValue := recover() - if recoveredValue != nil { - err = fmt.Errorf("caught error when registering %s %s: %+v", method, prefix, recoveredValue) - } - }() - h := func(w http.ResponseWriter, r *http.Request) { reqUUID := r.Header.Get(router.requestUUIDHeaderKey) if reqUUID == "" { @@ -197,8 +189,7 @@ func (router *httpRouter) Handle(method, prefix string, handler http.Handler) (e handler.ServeHTTP(w, r) } - router.httpRouter.Handler(method, prefix, http.HandlerFunc(h)) - return err + return router.httpRouter.Handle(method, prefix, http.HandlerFunc(h)) } // ServeHTTP implements the http.Handle as a convenience to allow HTTPRouter to be invoked by the standard library HTTP server. diff --git a/runtime/router/router.go b/runtime/router/router.go new file mode 100644 index 000000000..54abfa2a7 --- /dev/null +++ b/runtime/router/router.go @@ -0,0 +1,155 @@ +// Copyright (c) 2019 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package router + +import ( + "context" + "net/http" + "sort" + "strings" +) + +// Router dispatches http requests to a registered http.Handler. +// It implements a similar interface to the one in github.com/julienschmidt/httprouter, +// the main differences are: +// 1. this router does not treat "/a/:b" and "/a/b/c" as conflicts (https://github.com/julienschmidt/httprouter/issues/175) +// 2. this router does not treat "/a/:b" and "/a/:c" as different routes and therefore does not allow them to be registered at the same time (https://github.com/julienschmidt/httprouter/issues/6) +// 3. this router does not treat "/a" and "/a/" as different routes +// Also the `*` pattern is greedy, if a handler is register for `/a/*`, then no handler +// can be further registered for any path that starts with `/a/` +type Router struct { + tries map[string]*Trie + + // If enabled, the router checks if another method is allowed for the + // current route, if the current request can not be routed. + // If this is the case, the request is answered with 'Method Not Allowed' + // and HTTP status code 405. + // If no other Method is allowed, the request is delegated to the NotFound + // handler. + HandleMethodNotAllowed bool + + // Configurable http.Handler which is called when a request + // cannot be routed and HandleMethodNotAllowed is true. + // If it is not set, http.Error with http.StatusMethodNotAllowed is used. + // The "Allow" header with allowed request methods is set before the handler + // is called. + MethodNotAllowed http.Handler + + // Configurable http.Handler which is called when no matching route is + // found. If it is not set, http.NotFound is used. + NotFound http.Handler + + // Function to handle panics recovered from http handlers. + // It should be used to generate a error page and return the http error code + // 500 (Internal Server Error). + // The handler can be used to keep your server from crashing because of + // unrecovered panics. + PanicHandler func(http.ResponseWriter, *http.Request, interface{}) + + // TODO: (clu) maybe support OPTIONS +} + +type paramsKey string + +// urlParamsKey is the request context key under which URL params are stored. +const urlParamsKey = paramsKey("urlParamsKey") + +// ParamsFromContext pulls the URL parameters from a request context, +// or returns nil if none are present. +func ParamsFromContext(ctx context.Context) []Param { + p, _ := ctx.Value(urlParamsKey).([]Param) + return p +} + +// Handle registers a http.Handler for given method and path. +func (r *Router) Handle(method, path string, handler http.Handler) error { + if r.tries == nil { + r.tries = make(map[string]*Trie) + } + + trie, ok := r.tries[method] + if !ok { + trie = NewTrie() + r.tries[method] = trie + } + return trie.Set(path, handler) +} + +// ServeHTTP dispatches the request to a register handler to handle. +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if r.PanicHandler != nil { + defer func(w http.ResponseWriter, req *http.Request) { + if recovered := recover(); recovered != nil { + r.PanicHandler(w, req, recovered) + } + }(w, req) + } + + reqPath := req.URL.Path + if trie, ok := r.tries[req.Method]; ok { + if handler, params, err := trie.Get(reqPath); err == nil { + ctx := context.WithValue(req.Context(), urlParamsKey, params) + req = req.WithContext(ctx) + handler.ServeHTTP(w, req) + return + } + } + + if r.HandleMethodNotAllowed { + if allowed := r.allowed(reqPath, req.Method); allowed != "" { + w.Header().Set("Allow", allowed) + if r.MethodNotAllowed != nil { + r.MethodNotAllowed.ServeHTTP(w, req) + } else { + http.Error(w, + http.StatusText(http.StatusMethodNotAllowed), + http.StatusMethodNotAllowed, + ) + } + return + } + } + + if r.NotFound != nil { + r.NotFound.ServeHTTP(w, req) + } else { + http.NotFound(w, req) + } +} + +func (r *Router) allowed(path, reqMethod string) string { + var allow []string + + for method, trie := range r.tries { + if method == reqMethod || method == http.MethodOptions { + continue + } + + if _, _, err := trie.Get(path); err == nil { + allow = append(allow, method) + } + } + sort.Slice(allow, func(i, j int) bool { + return allow[i] < allow[j] + }) + + return strings.Join(allow, ", ") +} diff --git a/runtime/router/router_test.go b/runtime/router/router_test.go new file mode 100644 index 000000000..121a96dda --- /dev/null +++ b/runtime/router/router_test.go @@ -0,0 +1,159 @@ +// Copyright (c) 2019 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package router + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHandle(t *testing.T) { + r := &Router{} + + handled := false + err := r.Handle("GET", "/*", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handled = true + })) + assert.NoError(t, err, "unexpected error") + + req, _ := http.NewRequest("GET", "/foo", nil) + r.ServeHTTP(nil, req) + assert.True(t, handled) +} + +func TestParamsFromContext(t *testing.T) { + r := &Router{} + + handled := false + err := r.Handle("GET", "/:var", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + params := ParamsFromContext(req.Context()) + assert.Equal(t, 1, len(params)) + assert.Equal(t, "var", params[0].Key) + assert.Equal(t, "foo", params[0].Value) + handled = true + })) + assert.NoError(t, err, "unexpected error") + + req, _ := http.NewRequest("GET", "/foo", nil) + r.ServeHTTP(nil, req) + assert.True(t, handled) +} + +func TestPanicHandler(t *testing.T) { + handled := false + r := &Router{ + PanicHandler: func(writer http.ResponseWriter, req *http.Request, i interface{}) { + handled = true + }, + } + + err := r.Handle("GET", "/foo", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + panic("something went wrong") + })) + assert.NoError(t, err, "unexpected error") + + req, _ := http.NewRequest("GET", "/foo", nil) + r.ServeHTTP(nil, req) + assert.True(t, handled) +} + +func TestMethodNotAllowedDefault(t *testing.T) { + r := &Router{HandleMethodNotAllowed: true} + + handled := false + err := r.Handle("GET", "/foo", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handled = true + })) + assert.NoError(t, err, "unexpected error") + err = r.Handle("PUT", "/bar", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handled = true + })) + assert.NoError(t, err, "unexpected error") + + req, _ := http.NewRequest("PUT", "/foo", nil) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + assert.False(t, handled) + assert.Equal(t, http.StatusMethodNotAllowed, res.Result().StatusCode) + assert.Equal(t, "GET", res.Result().Header.Get("Allow")) +} + +func TestMethodNotAllowedCustom(t *testing.T) { + r := &Router{ + HandleMethodNotAllowed: true, + MethodNotAllowed: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("life", "42") + w.WriteHeader(http.StatusMethodNotAllowed) + }), + } + + handled := false + err := r.Handle("GET", "/foo", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handled = true + })) + assert.NoError(t, err, "unexpected error") + err = r.Handle("POST", "/foo", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handled = true + })) + assert.NoError(t, err, "unexpected error") + + req, _ := http.NewRequest("PUT", "/foo", nil) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + assert.False(t, handled) + assert.Equal(t, "42", res.Result().Header.Get("life")) + assert.Equal(t, "GET, POST", res.Result().Header.Get("Allow")) +} + +func TestNotFoundDefault(t *testing.T) { + r := &Router{} + + req, _ := http.NewRequest("GET", "/foo", nil) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + assert.Equal(t, http.StatusNotFound, res.Result().StatusCode) +} + +func TestNotFoundCustom(t *testing.T) { + handled := false + r := &Router{ + NotFound: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusNotFound) + handled = true + }), + } + + req, _ := http.NewRequest("GET", "/foo", nil) + res := httptest.NewRecorder() + r.ServeHTTP(res, req) + assert.True(t, handled) + assert.Equal(t, http.StatusNotFound, res.Result().StatusCode) +} diff --git a/runtime/router/trie.go b/runtime/router/trie.go new file mode 100644 index 000000000..4ecd22ede --- /dev/null +++ b/runtime/router/trie.go @@ -0,0 +1,310 @@ +// Copyright (c) 2019 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package router + +import ( + "errors" + "fmt" + "net/http" + "strings" +) + +var ( + errPath = errors.New("bad path") + errExist = errors.New("path value already set") + errNotFound = errors.New("not found") +) + +type paramMismatch struct { + expected, actual string + existingPath string +} + +// Error returns the error string +func (e *paramMismatch) Error() string { + return fmt.Sprintf("param key mismatch: expected is %s but got %s", e.expected, e.actual) +} + +// Param is a url parameter where key is the url segment pattern (without :) and +// value is the actual segment of a matched url. +// e.g. url /foo/123 matches /foo/:id, the url param has key "id" and value "123" +type Param struct { + Key, Value string +} + +// Trie is a radix trie to store string value at given url path, +// a trie node corresponds to an arbitrary path substring. +type Trie struct { + root *tnode +} + +type tnode struct { + key string + value http.Handler + children []*tnode +} + +// NewTrie creates a new trie. +func NewTrie() *Trie { + return &Trie{ + root: &tnode{ + key: "", + }, + } +} + +// Set sets the value for given path, returns error if path already set. +// When a http.Handler is registered for a given path, a subsequent Get returns the registered +// handler if the url passed to Get call matches the set path. Match in this context could mean either +// equality (e.g. url is "/foo" and path is "/foo") or url matches path pattern, which has two forms: +// - path ends with "/*", e.g. url "/foo" and "/foo/bar" both matches path "/*" +// - path contains colon wildcard ("/:"), e.g. url "/a/b" and "/a/c" bot matches path "/a/:var" +func (t *Trie) Set(path string, value http.Handler) error { + if path == "" || strings.Contains(path, "//") { + return errPath + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + // ignore trailing slash + path = strings.TrimSuffix(path, "/") + + // validate "*" + if strings.Contains(path, "*") && !strings.HasSuffix(path, "/*") { + return errors.New("/* must be the last path segment") + } + if strings.Count(path, "*") > 1 { + return errors.New("path can not contain more than one *") + } + + err := t.root.set(path, value, false, false) + if e, ok := err.(*paramMismatch); ok { + return fmt.Errorf("path %q has a different param key %q, it should be the same key %q as in existing path %q", path, e.actual, e.expected, e.existingPath) + } + return err +} + +// Get returns the http.Handler for given path, returns error if not found. +// It also returns the url params if given path contains any, e.g. if a handler is registered for +// "/:foo/bar", then calling Get with path "/xyz/bar" returns a param whose key is "foo" and value is "xyz". +func (t *Trie) Get(path string) (http.Handler, []Param, error) { + if path == "" || strings.Contains(path, "//") { + return nil, nil, errPath + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + // ignore trailing slash + path = strings.TrimSuffix(path, "/") + return t.root.get(path, false, false, false) +} + +// set sets the handler for given path, creates new child node if necessary +// lastKeyCharSlash tracks whether the previous key char is a '/', used to decide it is a pattern or not +// when the current key char is ':'. lastPathCharSlash tracks whether the previous path char is a '/', +// used to decide it is a pattern or not when the current path char is ':'. +func (t *tnode) set(path string, value http.Handler, lastKeyCharSlash, lastPathCharSlash bool) error { + // find the longest common prefix + var shorterLength, i int + keyLength, pathLength := len(t.key), len(path) + if keyLength > pathLength { + shorterLength = pathLength + } else { + shorterLength = keyLength + } + for i < shorterLength && t.key[i] == path[i] { + i++ + } + + // Find the first character that differs between "path" and this node's key, if it exists. + // If we encounter a colon wildcard, ensure that the wildcard in path matches the wildcard + // in this node's key for that segment. The segment is a colon wildcard only when the colon + // is immediately after slash, e.g. "/:foo", "/x/:y". "/a:b" is not a colon wildcard segment. + var keyMatchIdx, pathMatchIdx int + for keyMatchIdx < keyLength && pathMatchIdx < pathLength { + if (t.key[keyMatchIdx] == ':' && lastKeyCharSlash) || + (path[pathMatchIdx] == ':' && lastPathCharSlash) { + keyStartIdx, pathStartIdx := keyMatchIdx, pathMatchIdx + same := t.key[keyMatchIdx] == path[pathMatchIdx] + for keyMatchIdx < keyLength && t.key[keyMatchIdx] != '/' { + keyMatchIdx++ + } + for pathMatchIdx < pathLength && path[pathMatchIdx] != '/' { + pathMatchIdx++ + } + if same && (keyMatchIdx-keyStartIdx) != (pathMatchIdx-pathStartIdx) { + return ¶mMismatch{ + t.key[keyStartIdx:keyMatchIdx], + path[pathStartIdx:pathMatchIdx], + t.key, + } + } + } else if t.key[keyMatchIdx] == path[pathMatchIdx] { + keyMatchIdx++ + pathMatchIdx++ + } else { + break + } + lastKeyCharSlash = t.key[keyMatchIdx-1] == '/' + lastPathCharSlash = path[pathMatchIdx-1] == '/' + } + + // If the node key is fully matched, we match the rest path with children nodes to see if a value + // already exists for the path. + if keyMatchIdx == keyLength { + for _, c := range t.children { + if _, _, err := c.get(path[pathMatchIdx:], lastKeyCharSlash, lastPathCharSlash, true); err == nil { + return errExist + } + } + } + + // node key is longer than longest common prefix + if i < keyLength { + // key/path suffix being "*" means a conflict + if path[i:] == "*" || t.key[i:] == "*" { + return errExist + } + + // split the node key, add new node with node key minus longest common prefix + split := &tnode{ + key: t.key[i:], + value: t.value, + children: t.children, + } + t.key = t.key[:i] + t.value = nil + t.children = []*tnode{split} + + // path is equal to longest common prefix + // set value on current node after split + if i == pathLength { + t.value = value + } else { + // path is longer than longest common prefix + // add new node with path minus longest common prefix + newNode := &tnode{ + key: path[i:], + value: value, + } + t.children = append(t.children, newNode) + } + } + + // node key is equal to longest common prefix + if i == keyLength { + // path is equal to longest common prefix + if i == pathLength { + // node is guaranteed to have zero value, + // otherwise it would have caused errExist earlier + t.value = value + } else { + // path is longer than node key, try to recurse on node children + for _, c := range t.children { + if c.key[0] == path[i] { + lastKeyCharSlash = i > 0 && t.key[i-1] == '/' + lastPathCharSlash = i > 0 && path[i-1] == '/' + err := c.set(path[i:], value, lastKeyCharSlash, lastPathCharSlash) + if e, ok := err.(*paramMismatch); ok { + e.existingPath = t.key + e.existingPath + return e + } + return err + } + } + // no children to recurse, add node with path minus longest common path + newNode := &tnode{ + key: path[i:], + value: value, + } + t.children = append(t.children, newNode) + } + } + + return nil +} + +func (t *tnode) get(path string, lastKeyCharSlash, lastPathCharSlash, colonAsPattern bool) (http.Handler, []Param, error) { + keyLength, pathLength := len(t.key), len(path) + var params []Param + + // find the longest matched prefix + var keyIdx, pathIdx int + for keyIdx < keyLength && pathIdx < pathLength { + if t.key[keyIdx] == ':' && lastKeyCharSlash { + // wildcard starts - match until next slash + keyStartIdx, pathStartIdx := keyIdx+1, pathIdx + for keyIdx < keyLength && t.key[keyIdx] != '/' { + keyIdx++ + } + for pathIdx < pathLength && path[pathIdx] != '/' { + pathIdx++ + } + params = append(params, Param{t.key[keyStartIdx:keyIdx], path[pathStartIdx:pathIdx]}) + } else if path[pathIdx] == ':' && lastPathCharSlash && colonAsPattern { + // necessary for conflict check used in set call + for keyIdx < keyLength && t.key[keyIdx] != '/' { + keyIdx++ + } + for pathIdx < pathLength && path[pathIdx] != '/' { + pathIdx++ + } + } else if t.key[keyIdx] == path[pathIdx] { + keyIdx++ + pathIdx++ + } else { + break + } + lastKeyCharSlash = t.key[keyIdx-1] == '/' + lastPathCharSlash = path[pathIdx-1] == '/' + } + + if keyIdx < keyLength { + // path matches up to node key's second to last character, + // the last char of node key is "*" and path is no shorter than longest matched prefix + if t.key[keyIdx:] == "*" && pathIdx < pathLength { + return t.value, params, nil + } + return nil, nil, errNotFound + } + + // ':' in path matches '*' in node key + if keyIdx > 0 && t.key[keyIdx-1] == '*' { + return t.value, params, nil + } + + // longest matched prefix matches up to node key length and path length + if pathIdx == pathLength { + if t.value != nil { + return t.value, params, nil + } + return nil, nil, errNotFound + } + + // longest matched prefix matches up to node key length but not path length + for _, c := range t.children { + if v, ps, err := c.get(path[pathIdx:], lastKeyCharSlash, lastPathCharSlash, colonAsPattern); err == nil { + return v, append(params, ps...), nil + } + } + + return nil, nil, errNotFound +} diff --git a/runtime/router/trie_test.go b/runtime/router/trie_test.go new file mode 100644 index 000000000..dee131aeb --- /dev/null +++ b/runtime/router/trie_test.go @@ -0,0 +1,302 @@ +// Copyright (c) 2019 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package router + +import ( + "bytes" + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + get = "get" + set = "set" +) + +type ts struct { + op string + path string + value string + errMsg string + expectedValue string + expectedParams []Param +} + +type namedHandler struct { + id string +} + +func (n namedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {} + +func runTrieTests(t *testing.T, trie *Trie, tests []ts) { + for _, test := range tests { + if test.op == set { + err := trie.Set(test.path, namedHandler{id: test.value}) + if test.errMsg == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, test.errMsg) + } + } + if test.op == get { + v, ps, err := trie.Get(test.path) + if test.errMsg == "" { + assert.NoError(t, err, test.path) + assert.Equal(t, test.expectedValue, v.(namedHandler).id) + assert.Equal(t, test.expectedParams, ps) + } else { + assert.EqualError(t, err, test.errMsg) + } + } + } + //printTrie(trie) +} + +func TestTrieLiteralPath(t *testing.T) { + tree := NewTrie() + tests := []ts{ + // test blank path + {op: set, path: "", value: "foo", errMsg: errPath.Error()}, + {op: get, path: "", errMsg: errPath.Error()}, + // test root path + {op: set, path: "/", value: "foo"}, + {op: get, path: "/", expectedValue: "foo"}, + // test set + {op: set, path: "/a/b", value: "bar"}, + {op: set, path: "/a/b/c", value: "bar"}, + // test set conflict + {op: set, path: "/a/b/c", value: "baz", errMsg: errExist.Error()}, + // test trailing slash when set + {op: set, path: "/a/b/c/", value: "baz", errMsg: errExist.Error()}, + // test not found + {op: get, path: "/a", errMsg: errNotFound.Error()}, + {op: get, path: "/a/b/d", errMsg: errNotFound.Error()}, + // test get + {op: get, path: "/a/b", expectedValue: "bar"}, + {op: get, path: "/a/b/c", expectedValue: "bar"}, + // test trailing slash when get + {op: get, path: "/a/b/c/", expectedValue: "bar"}, + // test missing starting slash + {op: set, path: "a", value: "foo"}, + {op: get, path: "a", expectedValue: "foo"}, + // test branching + {op: set, path: "/a/e/f", value: "quxx"}, + {op: get, path: "/a/e/f", expectedValue: "quxx"}, + // test segment overlap + {op: set, path: "/a/good", value: "good"}, + {op: set, path: "/a/goto", value: "goto"}, + {op: get, path: "/a/good", expectedValue: "good"}, + {op: get, path: "/a/goto", expectedValue: "goto"}, + } + + runTrieTests(t, tree, tests) +} + +func TestTriePathsWithPatten(t *testing.T) { + trie := NewTrie() + tests := []ts{ + // test setting "/*/a" is not allowed + {op: set, path: "/*/a", value: "foo", errMsg: "/* must be the last path segment"}, + // test setting path with multiple "*" is not allowed + {op: set, path: "/*/*", value: "foo", errMsg: "path can not contain more than one *"}, + // test "/a" does not collide with "/a/*" + {op: set, path: "/a", value: "foo"}, + // test "/a/*" match all paths starts with "/a/" + {op: set, path: "/a/*", value: "bar"}, + {op: get, path: "a", expectedValue: "foo"}, + {op: get, path: "/a/b", expectedValue: "bar"}, + {op: get, path: "/a/b/c/d", expectedValue: "bar"}, + {op: get, path: "/a/b/c/d", expectedValue: "bar"}, + // test paths starts with "/a/" collides with "/a/*" + {op: set, path: "/a/b/", value: "baz", errMsg: errExist.Error()}, + {op: set, path: "/a/b/c", value: "baz", errMsg: errExist.Error()}, + {op: set, path: "/a/:b", value: "baz", errMsg: errExist.Error()}, + // test "/*" collides with "/a" + {op: set, path: "/*", value: "baz", errMsg: errExist.Error()}, + // test "/:" collides with "/a" + {op: set, path: "/:x", value: "baz", errMsg: errExist.Error()}, + // test "/:/b" collides with "/a/*" + {op: set, path: "/:x/b", value: "baz", errMsg: errExist.Error()}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // test ":a" is not treated as a pattern when queried as a url + {op: set, path: "/a", value: "foo"}, + {op: get, path: "/:a", errMsg: errNotFound.Error()}, + + {op: set, path: "/a:b", value: "bar"}, + {op: set, path: "/a:c", value: "baz"}, + {op: get, path: "/a:b", expectedValue: "bar"}, + {op: get, path: "/ac", errMsg: errNotFound.Error()}, + {op: get, path: "/a:", errMsg: errNotFound.Error()}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + {op: set, path: "/:a", value: "foo"}, + {op: get, path: "/:a", expectedValue: "foo", expectedParams: []Param{{"a", ":a"}}}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // test "/a" does not collide with "/:a/b" + {op: set, path: "/:a/b", value: "foo"}, + {op: set, path: "/a", value: "bar"}, + {op: get, path: "/a", expectedValue: "bar"}, + {op: get, path: "/x/b/", expectedValue: "foo", expectedParams: []Param{{"a", "x"}}}, + {op: get, path: "/a/", expectedValue: "bar"}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // test "/:a/b" does not collide with "/a" + {op: set, path: "/a", value: "bar"}, + {op: set, path: "/:a/b", value: "foo"}, + {op: get, path: "/a", expectedValue: "bar"}, + {op: get, path: "/x/b/", expectedValue: "foo", expectedParams: []Param{{"a", "x"}}}, + {op: get, path: "/a/", expectedValue: "bar"}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // test "/b" collides with "/:" + {op: set, path: "/:a", value: "foo"}, + {op: set, path: "/b", errMsg: errExist.Error()}, + {op: get, path: "/a/", expectedValue: "foo", expectedParams: []Param{{"a", "a"}}}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // test "/:" collides with "/b" + {op: set, path: "/b", value: "foo"}, + {op: set, path: "/:a", errMsg: errExist.Error()}, + {op: get, path: "/b/", expectedValue: "foo"}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // more ":" tests + {op: set, path: "/a/b", value: "1"}, + {op: set, path: "/a/b/:cc/d", value: "2"}, + {op: set, path: "/a/b/:x/e", errMsg: "path \"/a/b/:x/e\" has a different param key \":x\", it should be the same key \":cc\" as in existing path \"/a/b/:cc/d\""}, + {op: set, path: "/a/b/c/x", value: "2.1"}, + {op: set, path: "/a/b/:cc/:d/e", value: "3"}, + {op: set, path: "/a/b/c/d/f", value: "4"}, + {op: set, path: "/a/:b/c/d", errMsg: errExist.Error()}, + {op: get, path: "/a/b/some/d", expectedValue: "2", expectedParams: []Param{{"cc", "some"}}}, + {op: get, path: "/a/b/c/x", expectedValue: "2.1"}, + {op: get, path: "/a/b/other/data/e", expectedValue: "3", + expectedParams: []Param{ + {"cc", "other"}, + {"d", "data"}, + }}, + {op: get, path: "/a/b/c/d/f", expectedValue: "4"}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // more ":" tests + {op: set, path: "/a/b", value: "1"}, + {op: set, path: "/a/b/ccc/x", value: "2"}, + {op: set, path: "/a/b/c/dope/f", value: "3"}, + {op: set, path: "/a/b/ccc/:", errMsg: errExist.Error()}, + {op: set, path: "/a/b/c/:/:/", errMsg: errExist.Error()}, + {op: get, path: "/a/b/ccc", errMsg: errNotFound.Error()}, + {op: get, path: "/a/b/:", errMsg: errNotFound.Error()}, + } + runTrieTests(t, trie, tests) + + trie = NewTrie() + tests = []ts{ + // more ":" tests + {op: set, path: "/a/:b/c", value: "1"}, + {op: set, path: "/a/:b/d", value: "2"}, + {op: get, path: "/a/b/c", expectedValue: "1", expectedParams: []Param{{"b", "b"}}}, + {op: get, path: "/a/b/d", expectedValue: "2", expectedParams: []Param{{"b", "b"}}}, + } + runTrieTests(t, trie, tests) +} + +// simple test for coverage +func TestParamMismatch(t *testing.T) { + pm := paramMismatch{ + expected: "foo", + actual: "bar", + } + assert.Equal(t, "param key mismatch: expected is foo but got bar", pm.Error()) +} + +// utilities for debugging +func printTrie(t *Trie) { + buf := new(bytes.Buffer) + var levelsEnded []int + printNodes(buf, t.root.children, 0, levelsEnded) + fmt.Println(string(buf.Bytes())) +} + +func printNodes(w io.Writer, nodes []*tnode, level int, levelsEnded []int) { + for i, node := range nodes { + edge := "├──" + if i == len(nodes)-1 { + levelsEnded = append(levelsEnded, level) + edge = "└──" + } + printNode(w, node, level, levelsEnded, edge) + if len(node.children) > 0 { + printNodes(w, node.children, level+1, levelsEnded) + } + } +} + +func printNode(w io.Writer, node *tnode, level int, levelsEnded []int, edge string) { + for i := 0; i < level; i++ { + isEnded := false + for _, l := range levelsEnded { + if l == i { + isEnded = true + break + } + } + if isEnded { + _, _ = fmt.Fprint(w, " ") + } else { + _, _ = fmt.Fprint(w, "│ ") + } + } + if node.value != nil { + _, _ = fmt.Fprintf(w, "%s %v (%v)\n", edge, node.key, node.value) + } else { + _, _ = fmt.Fprintf(w, "%s %v\n", edge, node.key) + } +} diff --git a/runtime/router_test.go b/runtime/router_test.go index 122dbec06..b3c41e978 100644 --- a/runtime/router_test.go +++ b/runtime/router_test.go @@ -81,8 +81,8 @@ func (s *routerSuite) TestRouter() { }{ {"GET", "/notfound", nil, http.StatusNotFound, []byte("404 page not found\n")}, {"GET", "/noslash", nil, http.StatusOK, []byte("noslash\n")}, - {"GET", "/noslash/", nil, http.StatusMovedPermanently, []byte(`Moved Permanently.` + "\n\n")}, - {"GET", "/withslash", nil, http.StatusMovedPermanently, []byte(`Moved Permanently.` + "\n\n")}, + {"GET", "/noslash/", nil, http.StatusOK, []byte("noslash\n")}, + {"GET", "/withslash", nil, http.StatusOK, []byte("withslash\n")}, {"GET", "/withslash/", nil, http.StatusOK, []byte("withslash\n")}, {"GET", "/postonly", nil, http.StatusMethodNotAllowed, []byte("Method Not Allowed\n")}, {"GET", "/panicerror", nil, http.StatusInternalServerError, []byte("Internal Server Error\n")}, diff --git a/test/lib/test_backend/test_http_backend.go b/test/lib/test_backend/test_http_backend.go index 96144af1d..05e0007ef 100644 --- a/test/lib/test_backend/test_http_backend.go +++ b/test/lib/test_backend/test_http_backend.go @@ -25,8 +25,8 @@ import ( "strconv" "sync" - "github.com/julienschmidt/httprouter" "github.com/uber/zanzibar/runtime" + zrouter "github.com/uber/zanzibar/runtime/router" "go.uber.org/zap" ) @@ -38,7 +38,7 @@ type TestHTTPBackend struct { RealPort int32 RealAddr string WaitGroup *sync.WaitGroup - router *httprouter.Router + router *zrouter.Router } // BuildHTTPBackends returns a map of backends based on config @@ -83,7 +83,7 @@ func (backend *TestHTTPBackend) Bootstrap() error { func (backend *TestHTTPBackend) HandleFunc( method string, path string, handler http.HandlerFunc, ) { - backend.router.HandlerFunc(method, path, handler) + _ = backend.router.Handle(method, path, handler) } // Close ... @@ -103,7 +103,7 @@ func CreateHTTPBackend(port int32) *TestHTTPBackend { IP: "127.0.0.1", Port: port, WaitGroup: &sync.WaitGroup{}, - router: &httprouter.Router{ + router: &zrouter.Router{ HandleMethodNotAllowed: true, }, }