diff --git a/glide.lock b/glide.lock
index 533b25bd3..ef44d0cfa 100644
--- a/glide.lock
+++ b/glide.lock
@@ -77,8 +77,6 @@ imports:
version: c0795c8afcf41dd1d786bebce68636c199b3bb45
- name: github.com/jteeuwen/go-bindata
version: 6025e8de665b31fa74ab1a66f2cddd8c0abf887e
-- name: github.com/julienschmidt/httprouter
- version: 26a05976f9bf5c3aa992cc20e8588c359418ee58
- name: github.com/kardianos/osext
version: 2bc1f35cddc0cc527b4bc3dce8578fc2a6c11384
- name: github.com/kisielk/errcheck
diff --git a/glide.yaml b/glide.yaml
index 9edf0f49c..327bbf7df 100644
--- a/glide.yaml
+++ b/glide.yaml
@@ -16,9 +16,6 @@ import:
version: ^1.3.0
- package: github.com/uber-go/tally
version: ^3.3.8
-# update to master since the last semver release is 2 years old
-- package: github.com/julienschmidt/httprouter
- version: master
- package: github.com/kardianos/osext
version: master
- package: go.uber.org/thriftrw
diff --git a/runtime/router.go b/runtime/router.go
index 7af517a80..427559aea 100644
--- a/runtime/router.go
+++ b/runtime/router.go
@@ -25,11 +25,11 @@ import (
"fmt"
"net/http"
- "github.com/julienschmidt/httprouter"
"github.com/opentracing/opentracing-go"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"github.com/uber-go/tally"
+ zrouter "github.com/uber/zanzibar/runtime/router"
"go.uber.org/zap"
"net/url"
)
@@ -61,9 +61,9 @@ type HTTPRouter interface {
// ParamsFromContext extracts the URL parameters that are embedded in the context by the Zanzibar HTTP router implementation.
func ParamsFromContext(ctx context.Context) url.Values {
- julienParams := httprouter.ParamsFromContext(ctx)
+ params := zrouter.ParamsFromContext(ctx)
urlValues := make(url.Values)
- for _, paramValue := range julienParams {
+ for _, paramValue := range params {
urlValues.Add(paramValue.Key, paramValue.Value)
}
return urlValues
@@ -131,7 +131,7 @@ func (endpoint *RouterEndpoint) HandleRequest(
// httpRouter data structure to handle and register endpoints
type httpRouter struct {
gateway *Gateway
- httpRouter *httprouter.Router
+ httpRouter *zrouter.Router
notFoundEndpoint *RouterEndpoint
methodNotAllowedEndpoint *RouterEndpoint
panicCount tally.Counter
@@ -167,8 +167,7 @@ func NewHTTPRouter(gateway *Gateway) HTTPRouter {
requestUUIDHeaderKey: gateway.requestUUIDHeaderKey,
}
- router.httpRouter = &httprouter.Router{
- RedirectTrailingSlash: true,
+ router.httpRouter = &zrouter.Router{
HandleMethodNotAllowed: true,
NotFound: http.HandlerFunc(router.handleNotFound),
MethodNotAllowed: http.HandlerFunc(router.handleMethodNotAllowed),
@@ -179,13 +178,6 @@ func NewHTTPRouter(gateway *Gateway) HTTPRouter {
// Register register a handler function.
func (router *httpRouter) Handle(method, prefix string, handler http.Handler) (err error) {
- defer func() {
- recoveredValue := recover()
- if recoveredValue != nil {
- err = fmt.Errorf("caught error when registering %s %s: %+v", method, prefix, recoveredValue)
- }
- }()
-
h := func(w http.ResponseWriter, r *http.Request) {
reqUUID := r.Header.Get(router.requestUUIDHeaderKey)
if reqUUID == "" {
@@ -197,8 +189,7 @@ func (router *httpRouter) Handle(method, prefix string, handler http.Handler) (e
handler.ServeHTTP(w, r)
}
- router.httpRouter.Handler(method, prefix, http.HandlerFunc(h))
- return err
+ return router.httpRouter.Handle(method, prefix, http.HandlerFunc(h))
}
// ServeHTTP implements the http.Handle as a convenience to allow HTTPRouter to be invoked by the standard library HTTP server.
diff --git a/runtime/router/router.go b/runtime/router/router.go
new file mode 100644
index 000000000..54abfa2a7
--- /dev/null
+++ b/runtime/router/router.go
@@ -0,0 +1,155 @@
+// Copyright (c) 2019 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package router
+
+import (
+ "context"
+ "net/http"
+ "sort"
+ "strings"
+)
+
+// Router dispatches http requests to a registered http.Handler.
+// It implements a similar interface to the one in github.com/julienschmidt/httprouter,
+// the main differences are:
+// 1. this router does not treat "/a/:b" and "/a/b/c" as conflicts (https://github.com/julienschmidt/httprouter/issues/175)
+// 2. this router does not treat "/a/:b" and "/a/:c" as different routes and therefore does not allow them to be registered at the same time (https://github.com/julienschmidt/httprouter/issues/6)
+// 3. this router does not treat "/a" and "/a/" as different routes
+// Also the `*` pattern is greedy, if a handler is register for `/a/*`, then no handler
+// can be further registered for any path that starts with `/a/`
+type Router struct {
+ tries map[string]*Trie
+
+ // If enabled, the router checks if another method is allowed for the
+ // current route, if the current request can not be routed.
+ // If this is the case, the request is answered with 'Method Not Allowed'
+ // and HTTP status code 405.
+ // If no other Method is allowed, the request is delegated to the NotFound
+ // handler.
+ HandleMethodNotAllowed bool
+
+ // Configurable http.Handler which is called when a request
+ // cannot be routed and HandleMethodNotAllowed is true.
+ // If it is not set, http.Error with http.StatusMethodNotAllowed is used.
+ // The "Allow" header with allowed request methods is set before the handler
+ // is called.
+ MethodNotAllowed http.Handler
+
+ // Configurable http.Handler which is called when no matching route is
+ // found. If it is not set, http.NotFound is used.
+ NotFound http.Handler
+
+ // Function to handle panics recovered from http handlers.
+ // It should be used to generate a error page and return the http error code
+ // 500 (Internal Server Error).
+ // The handler can be used to keep your server from crashing because of
+ // unrecovered panics.
+ PanicHandler func(http.ResponseWriter, *http.Request, interface{})
+
+ // TODO: (clu) maybe support OPTIONS
+}
+
+type paramsKey string
+
+// urlParamsKey is the request context key under which URL params are stored.
+const urlParamsKey = paramsKey("urlParamsKey")
+
+// ParamsFromContext pulls the URL parameters from a request context,
+// or returns nil if none are present.
+func ParamsFromContext(ctx context.Context) []Param {
+ p, _ := ctx.Value(urlParamsKey).([]Param)
+ return p
+}
+
+// Handle registers a http.Handler for given method and path.
+func (r *Router) Handle(method, path string, handler http.Handler) error {
+ if r.tries == nil {
+ r.tries = make(map[string]*Trie)
+ }
+
+ trie, ok := r.tries[method]
+ if !ok {
+ trie = NewTrie()
+ r.tries[method] = trie
+ }
+ return trie.Set(path, handler)
+}
+
+// ServeHTTP dispatches the request to a register handler to handle.
+func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ if r.PanicHandler != nil {
+ defer func(w http.ResponseWriter, req *http.Request) {
+ if recovered := recover(); recovered != nil {
+ r.PanicHandler(w, req, recovered)
+ }
+ }(w, req)
+ }
+
+ reqPath := req.URL.Path
+ if trie, ok := r.tries[req.Method]; ok {
+ if handler, params, err := trie.Get(reqPath); err == nil {
+ ctx := context.WithValue(req.Context(), urlParamsKey, params)
+ req = req.WithContext(ctx)
+ handler.ServeHTTP(w, req)
+ return
+ }
+ }
+
+ if r.HandleMethodNotAllowed {
+ if allowed := r.allowed(reqPath, req.Method); allowed != "" {
+ w.Header().Set("Allow", allowed)
+ if r.MethodNotAllowed != nil {
+ r.MethodNotAllowed.ServeHTTP(w, req)
+ } else {
+ http.Error(w,
+ http.StatusText(http.StatusMethodNotAllowed),
+ http.StatusMethodNotAllowed,
+ )
+ }
+ return
+ }
+ }
+
+ if r.NotFound != nil {
+ r.NotFound.ServeHTTP(w, req)
+ } else {
+ http.NotFound(w, req)
+ }
+}
+
+func (r *Router) allowed(path, reqMethod string) string {
+ var allow []string
+
+ for method, trie := range r.tries {
+ if method == reqMethod || method == http.MethodOptions {
+ continue
+ }
+
+ if _, _, err := trie.Get(path); err == nil {
+ allow = append(allow, method)
+ }
+ }
+ sort.Slice(allow, func(i, j int) bool {
+ return allow[i] < allow[j]
+ })
+
+ return strings.Join(allow, ", ")
+}
diff --git a/runtime/router/router_test.go b/runtime/router/router_test.go
new file mode 100644
index 000000000..121a96dda
--- /dev/null
+++ b/runtime/router/router_test.go
@@ -0,0 +1,159 @@
+// Copyright (c) 2019 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package router
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestHandle(t *testing.T) {
+ r := &Router{}
+
+ handled := false
+ err := r.Handle("GET", "/*",
+ http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ handled = true
+ }))
+ assert.NoError(t, err, "unexpected error")
+
+ req, _ := http.NewRequest("GET", "/foo", nil)
+ r.ServeHTTP(nil, req)
+ assert.True(t, handled)
+}
+
+func TestParamsFromContext(t *testing.T) {
+ r := &Router{}
+
+ handled := false
+ err := r.Handle("GET", "/:var",
+ http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ params := ParamsFromContext(req.Context())
+ assert.Equal(t, 1, len(params))
+ assert.Equal(t, "var", params[0].Key)
+ assert.Equal(t, "foo", params[0].Value)
+ handled = true
+ }))
+ assert.NoError(t, err, "unexpected error")
+
+ req, _ := http.NewRequest("GET", "/foo", nil)
+ r.ServeHTTP(nil, req)
+ assert.True(t, handled)
+}
+
+func TestPanicHandler(t *testing.T) {
+ handled := false
+ r := &Router{
+ PanicHandler: func(writer http.ResponseWriter, req *http.Request, i interface{}) {
+ handled = true
+ },
+ }
+
+ err := r.Handle("GET", "/foo",
+ http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ panic("something went wrong")
+ }))
+ assert.NoError(t, err, "unexpected error")
+
+ req, _ := http.NewRequest("GET", "/foo", nil)
+ r.ServeHTTP(nil, req)
+ assert.True(t, handled)
+}
+
+func TestMethodNotAllowedDefault(t *testing.T) {
+ r := &Router{HandleMethodNotAllowed: true}
+
+ handled := false
+ err := r.Handle("GET", "/foo",
+ http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ handled = true
+ }))
+ assert.NoError(t, err, "unexpected error")
+ err = r.Handle("PUT", "/bar",
+ http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ handled = true
+ }))
+ assert.NoError(t, err, "unexpected error")
+
+ req, _ := http.NewRequest("PUT", "/foo", nil)
+ res := httptest.NewRecorder()
+ r.ServeHTTP(res, req)
+ assert.False(t, handled)
+ assert.Equal(t, http.StatusMethodNotAllowed, res.Result().StatusCode)
+ assert.Equal(t, "GET", res.Result().Header.Get("Allow"))
+}
+
+func TestMethodNotAllowedCustom(t *testing.T) {
+ r := &Router{
+ HandleMethodNotAllowed: true,
+ MethodNotAllowed: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ w.Header().Set("life", "42")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ }),
+ }
+
+ handled := false
+ err := r.Handle("GET", "/foo",
+ http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ handled = true
+ }))
+ assert.NoError(t, err, "unexpected error")
+ err = r.Handle("POST", "/foo",
+ http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ handled = true
+ }))
+ assert.NoError(t, err, "unexpected error")
+
+ req, _ := http.NewRequest("PUT", "/foo", nil)
+ res := httptest.NewRecorder()
+ r.ServeHTTP(res, req)
+ assert.False(t, handled)
+ assert.Equal(t, "42", res.Result().Header.Get("life"))
+ assert.Equal(t, "GET, POST", res.Result().Header.Get("Allow"))
+}
+
+func TestNotFoundDefault(t *testing.T) {
+ r := &Router{}
+
+ req, _ := http.NewRequest("GET", "/foo", nil)
+ res := httptest.NewRecorder()
+ r.ServeHTTP(res, req)
+ assert.Equal(t, http.StatusNotFound, res.Result().StatusCode)
+}
+
+func TestNotFoundCustom(t *testing.T) {
+ handled := false
+ r := &Router{
+ NotFound: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ handled = true
+ }),
+ }
+
+ req, _ := http.NewRequest("GET", "/foo", nil)
+ res := httptest.NewRecorder()
+ r.ServeHTTP(res, req)
+ assert.True(t, handled)
+ assert.Equal(t, http.StatusNotFound, res.Result().StatusCode)
+}
diff --git a/runtime/router/trie.go b/runtime/router/trie.go
new file mode 100644
index 000000000..4ecd22ede
--- /dev/null
+++ b/runtime/router/trie.go
@@ -0,0 +1,310 @@
+// Copyright (c) 2019 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package router
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "strings"
+)
+
+var (
+ errPath = errors.New("bad path")
+ errExist = errors.New("path value already set")
+ errNotFound = errors.New("not found")
+)
+
+type paramMismatch struct {
+ expected, actual string
+ existingPath string
+}
+
+// Error returns the error string
+func (e *paramMismatch) Error() string {
+ return fmt.Sprintf("param key mismatch: expected is %s but got %s", e.expected, e.actual)
+}
+
+// Param is a url parameter where key is the url segment pattern (without :) and
+// value is the actual segment of a matched url.
+// e.g. url /foo/123 matches /foo/:id, the url param has key "id" and value "123"
+type Param struct {
+ Key, Value string
+}
+
+// Trie is a radix trie to store string value at given url path,
+// a trie node corresponds to an arbitrary path substring.
+type Trie struct {
+ root *tnode
+}
+
+type tnode struct {
+ key string
+ value http.Handler
+ children []*tnode
+}
+
+// NewTrie creates a new trie.
+func NewTrie() *Trie {
+ return &Trie{
+ root: &tnode{
+ key: "",
+ },
+ }
+}
+
+// Set sets the value for given path, returns error if path already set.
+// When a http.Handler is registered for a given path, a subsequent Get returns the registered
+// handler if the url passed to Get call matches the set path. Match in this context could mean either
+// equality (e.g. url is "/foo" and path is "/foo") or url matches path pattern, which has two forms:
+// - path ends with "/*", e.g. url "/foo" and "/foo/bar" both matches path "/*"
+// - path contains colon wildcard ("/:"), e.g. url "/a/b" and "/a/c" bot matches path "/a/:var"
+func (t *Trie) Set(path string, value http.Handler) error {
+ if path == "" || strings.Contains(path, "//") {
+ return errPath
+ }
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ // ignore trailing slash
+ path = strings.TrimSuffix(path, "/")
+
+ // validate "*"
+ if strings.Contains(path, "*") && !strings.HasSuffix(path, "/*") {
+ return errors.New("/* must be the last path segment")
+ }
+ if strings.Count(path, "*") > 1 {
+ return errors.New("path can not contain more than one *")
+ }
+
+ err := t.root.set(path, value, false, false)
+ if e, ok := err.(*paramMismatch); ok {
+ return fmt.Errorf("path %q has a different param key %q, it should be the same key %q as in existing path %q", path, e.actual, e.expected, e.existingPath)
+ }
+ return err
+}
+
+// Get returns the http.Handler for given path, returns error if not found.
+// It also returns the url params if given path contains any, e.g. if a handler is registered for
+// "/:foo/bar", then calling Get with path "/xyz/bar" returns a param whose key is "foo" and value is "xyz".
+func (t *Trie) Get(path string) (http.Handler, []Param, error) {
+ if path == "" || strings.Contains(path, "//") {
+ return nil, nil, errPath
+ }
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ // ignore trailing slash
+ path = strings.TrimSuffix(path, "/")
+ return t.root.get(path, false, false, false)
+}
+
+// set sets the handler for given path, creates new child node if necessary
+// lastKeyCharSlash tracks whether the previous key char is a '/', used to decide it is a pattern or not
+// when the current key char is ':'. lastPathCharSlash tracks whether the previous path char is a '/',
+// used to decide it is a pattern or not when the current path char is ':'.
+func (t *tnode) set(path string, value http.Handler, lastKeyCharSlash, lastPathCharSlash bool) error {
+ // find the longest common prefix
+ var shorterLength, i int
+ keyLength, pathLength := len(t.key), len(path)
+ if keyLength > pathLength {
+ shorterLength = pathLength
+ } else {
+ shorterLength = keyLength
+ }
+ for i < shorterLength && t.key[i] == path[i] {
+ i++
+ }
+
+ // Find the first character that differs between "path" and this node's key, if it exists.
+ // If we encounter a colon wildcard, ensure that the wildcard in path matches the wildcard
+ // in this node's key for that segment. The segment is a colon wildcard only when the colon
+ // is immediately after slash, e.g. "/:foo", "/x/:y". "/a:b" is not a colon wildcard segment.
+ var keyMatchIdx, pathMatchIdx int
+ for keyMatchIdx < keyLength && pathMatchIdx < pathLength {
+ if (t.key[keyMatchIdx] == ':' && lastKeyCharSlash) ||
+ (path[pathMatchIdx] == ':' && lastPathCharSlash) {
+ keyStartIdx, pathStartIdx := keyMatchIdx, pathMatchIdx
+ same := t.key[keyMatchIdx] == path[pathMatchIdx]
+ for keyMatchIdx < keyLength && t.key[keyMatchIdx] != '/' {
+ keyMatchIdx++
+ }
+ for pathMatchIdx < pathLength && path[pathMatchIdx] != '/' {
+ pathMatchIdx++
+ }
+ if same && (keyMatchIdx-keyStartIdx) != (pathMatchIdx-pathStartIdx) {
+ return ¶mMismatch{
+ t.key[keyStartIdx:keyMatchIdx],
+ path[pathStartIdx:pathMatchIdx],
+ t.key,
+ }
+ }
+ } else if t.key[keyMatchIdx] == path[pathMatchIdx] {
+ keyMatchIdx++
+ pathMatchIdx++
+ } else {
+ break
+ }
+ lastKeyCharSlash = t.key[keyMatchIdx-1] == '/'
+ lastPathCharSlash = path[pathMatchIdx-1] == '/'
+ }
+
+ // If the node key is fully matched, we match the rest path with children nodes to see if a value
+ // already exists for the path.
+ if keyMatchIdx == keyLength {
+ for _, c := range t.children {
+ if _, _, err := c.get(path[pathMatchIdx:], lastKeyCharSlash, lastPathCharSlash, true); err == nil {
+ return errExist
+ }
+ }
+ }
+
+ // node key is longer than longest common prefix
+ if i < keyLength {
+ // key/path suffix being "*" means a conflict
+ if path[i:] == "*" || t.key[i:] == "*" {
+ return errExist
+ }
+
+ // split the node key, add new node with node key minus longest common prefix
+ split := &tnode{
+ key: t.key[i:],
+ value: t.value,
+ children: t.children,
+ }
+ t.key = t.key[:i]
+ t.value = nil
+ t.children = []*tnode{split}
+
+ // path is equal to longest common prefix
+ // set value on current node after split
+ if i == pathLength {
+ t.value = value
+ } else {
+ // path is longer than longest common prefix
+ // add new node with path minus longest common prefix
+ newNode := &tnode{
+ key: path[i:],
+ value: value,
+ }
+ t.children = append(t.children, newNode)
+ }
+ }
+
+ // node key is equal to longest common prefix
+ if i == keyLength {
+ // path is equal to longest common prefix
+ if i == pathLength {
+ // node is guaranteed to have zero value,
+ // otherwise it would have caused errExist earlier
+ t.value = value
+ } else {
+ // path is longer than node key, try to recurse on node children
+ for _, c := range t.children {
+ if c.key[0] == path[i] {
+ lastKeyCharSlash = i > 0 && t.key[i-1] == '/'
+ lastPathCharSlash = i > 0 && path[i-1] == '/'
+ err := c.set(path[i:], value, lastKeyCharSlash, lastPathCharSlash)
+ if e, ok := err.(*paramMismatch); ok {
+ e.existingPath = t.key + e.existingPath
+ return e
+ }
+ return err
+ }
+ }
+ // no children to recurse, add node with path minus longest common path
+ newNode := &tnode{
+ key: path[i:],
+ value: value,
+ }
+ t.children = append(t.children, newNode)
+ }
+ }
+
+ return nil
+}
+
+func (t *tnode) get(path string, lastKeyCharSlash, lastPathCharSlash, colonAsPattern bool) (http.Handler, []Param, error) {
+ keyLength, pathLength := len(t.key), len(path)
+ var params []Param
+
+ // find the longest matched prefix
+ var keyIdx, pathIdx int
+ for keyIdx < keyLength && pathIdx < pathLength {
+ if t.key[keyIdx] == ':' && lastKeyCharSlash {
+ // wildcard starts - match until next slash
+ keyStartIdx, pathStartIdx := keyIdx+1, pathIdx
+ for keyIdx < keyLength && t.key[keyIdx] != '/' {
+ keyIdx++
+ }
+ for pathIdx < pathLength && path[pathIdx] != '/' {
+ pathIdx++
+ }
+ params = append(params, Param{t.key[keyStartIdx:keyIdx], path[pathStartIdx:pathIdx]})
+ } else if path[pathIdx] == ':' && lastPathCharSlash && colonAsPattern {
+ // necessary for conflict check used in set call
+ for keyIdx < keyLength && t.key[keyIdx] != '/' {
+ keyIdx++
+ }
+ for pathIdx < pathLength && path[pathIdx] != '/' {
+ pathIdx++
+ }
+ } else if t.key[keyIdx] == path[pathIdx] {
+ keyIdx++
+ pathIdx++
+ } else {
+ break
+ }
+ lastKeyCharSlash = t.key[keyIdx-1] == '/'
+ lastPathCharSlash = path[pathIdx-1] == '/'
+ }
+
+ if keyIdx < keyLength {
+ // path matches up to node key's second to last character,
+ // the last char of node key is "*" and path is no shorter than longest matched prefix
+ if t.key[keyIdx:] == "*" && pathIdx < pathLength {
+ return t.value, params, nil
+ }
+ return nil, nil, errNotFound
+ }
+
+ // ':' in path matches '*' in node key
+ if keyIdx > 0 && t.key[keyIdx-1] == '*' {
+ return t.value, params, nil
+ }
+
+ // longest matched prefix matches up to node key length and path length
+ if pathIdx == pathLength {
+ if t.value != nil {
+ return t.value, params, nil
+ }
+ return nil, nil, errNotFound
+ }
+
+ // longest matched prefix matches up to node key length but not path length
+ for _, c := range t.children {
+ if v, ps, err := c.get(path[pathIdx:], lastKeyCharSlash, lastPathCharSlash, colonAsPattern); err == nil {
+ return v, append(params, ps...), nil
+ }
+ }
+
+ return nil, nil, errNotFound
+}
diff --git a/runtime/router/trie_test.go b/runtime/router/trie_test.go
new file mode 100644
index 000000000..dee131aeb
--- /dev/null
+++ b/runtime/router/trie_test.go
@@ -0,0 +1,302 @@
+// Copyright (c) 2019 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package router
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+const (
+ get = "get"
+ set = "set"
+)
+
+type ts struct {
+ op string
+ path string
+ value string
+ errMsg string
+ expectedValue string
+ expectedParams []Param
+}
+
+type namedHandler struct {
+ id string
+}
+
+func (n namedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {}
+
+func runTrieTests(t *testing.T, trie *Trie, tests []ts) {
+ for _, test := range tests {
+ if test.op == set {
+ err := trie.Set(test.path, namedHandler{id: test.value})
+ if test.errMsg == "" {
+ assert.NoError(t, err)
+ } else {
+ assert.EqualError(t, err, test.errMsg)
+ }
+ }
+ if test.op == get {
+ v, ps, err := trie.Get(test.path)
+ if test.errMsg == "" {
+ assert.NoError(t, err, test.path)
+ assert.Equal(t, test.expectedValue, v.(namedHandler).id)
+ assert.Equal(t, test.expectedParams, ps)
+ } else {
+ assert.EqualError(t, err, test.errMsg)
+ }
+ }
+ }
+ //printTrie(trie)
+}
+
+func TestTrieLiteralPath(t *testing.T) {
+ tree := NewTrie()
+ tests := []ts{
+ // test blank path
+ {op: set, path: "", value: "foo", errMsg: errPath.Error()},
+ {op: get, path: "", errMsg: errPath.Error()},
+ // test root path
+ {op: set, path: "/", value: "foo"},
+ {op: get, path: "/", expectedValue: "foo"},
+ // test set
+ {op: set, path: "/a/b", value: "bar"},
+ {op: set, path: "/a/b/c", value: "bar"},
+ // test set conflict
+ {op: set, path: "/a/b/c", value: "baz", errMsg: errExist.Error()},
+ // test trailing slash when set
+ {op: set, path: "/a/b/c/", value: "baz", errMsg: errExist.Error()},
+ // test not found
+ {op: get, path: "/a", errMsg: errNotFound.Error()},
+ {op: get, path: "/a/b/d", errMsg: errNotFound.Error()},
+ // test get
+ {op: get, path: "/a/b", expectedValue: "bar"},
+ {op: get, path: "/a/b/c", expectedValue: "bar"},
+ // test trailing slash when get
+ {op: get, path: "/a/b/c/", expectedValue: "bar"},
+ // test missing starting slash
+ {op: set, path: "a", value: "foo"},
+ {op: get, path: "a", expectedValue: "foo"},
+ // test branching
+ {op: set, path: "/a/e/f", value: "quxx"},
+ {op: get, path: "/a/e/f", expectedValue: "quxx"},
+ // test segment overlap
+ {op: set, path: "/a/good", value: "good"},
+ {op: set, path: "/a/goto", value: "goto"},
+ {op: get, path: "/a/good", expectedValue: "good"},
+ {op: get, path: "/a/goto", expectedValue: "goto"},
+ }
+
+ runTrieTests(t, tree, tests)
+}
+
+func TestTriePathsWithPatten(t *testing.T) {
+ trie := NewTrie()
+ tests := []ts{
+ // test setting "/*/a" is not allowed
+ {op: set, path: "/*/a", value: "foo", errMsg: "/* must be the last path segment"},
+ // test setting path with multiple "*" is not allowed
+ {op: set, path: "/*/*", value: "foo", errMsg: "path can not contain more than one *"},
+ // test "/a" does not collide with "/a/*"
+ {op: set, path: "/a", value: "foo"},
+ // test "/a/*" match all paths starts with "/a/"
+ {op: set, path: "/a/*", value: "bar"},
+ {op: get, path: "a", expectedValue: "foo"},
+ {op: get, path: "/a/b", expectedValue: "bar"},
+ {op: get, path: "/a/b/c/d", expectedValue: "bar"},
+ {op: get, path: "/a/b/c/d", expectedValue: "bar"},
+ // test paths starts with "/a/" collides with "/a/*"
+ {op: set, path: "/a/b/", value: "baz", errMsg: errExist.Error()},
+ {op: set, path: "/a/b/c", value: "baz", errMsg: errExist.Error()},
+ {op: set, path: "/a/:b", value: "baz", errMsg: errExist.Error()},
+ // test "/*" collides with "/a"
+ {op: set, path: "/*", value: "baz", errMsg: errExist.Error()},
+ // test "/:" collides with "/a"
+ {op: set, path: "/:x", value: "baz", errMsg: errExist.Error()},
+ // test "/:/b" collides with "/a/*"
+ {op: set, path: "/:x/b", value: "baz", errMsg: errExist.Error()},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // test ":a" is not treated as a pattern when queried as a url
+ {op: set, path: "/a", value: "foo"},
+ {op: get, path: "/:a", errMsg: errNotFound.Error()},
+
+ {op: set, path: "/a:b", value: "bar"},
+ {op: set, path: "/a:c", value: "baz"},
+ {op: get, path: "/a:b", expectedValue: "bar"},
+ {op: get, path: "/ac", errMsg: errNotFound.Error()},
+ {op: get, path: "/a:", errMsg: errNotFound.Error()},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ {op: set, path: "/:a", value: "foo"},
+ {op: get, path: "/:a", expectedValue: "foo", expectedParams: []Param{{"a", ":a"}}},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // test "/a" does not collide with "/:a/b"
+ {op: set, path: "/:a/b", value: "foo"},
+ {op: set, path: "/a", value: "bar"},
+ {op: get, path: "/a", expectedValue: "bar"},
+ {op: get, path: "/x/b/", expectedValue: "foo", expectedParams: []Param{{"a", "x"}}},
+ {op: get, path: "/a/", expectedValue: "bar"},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // test "/:a/b" does not collide with "/a"
+ {op: set, path: "/a", value: "bar"},
+ {op: set, path: "/:a/b", value: "foo"},
+ {op: get, path: "/a", expectedValue: "bar"},
+ {op: get, path: "/x/b/", expectedValue: "foo", expectedParams: []Param{{"a", "x"}}},
+ {op: get, path: "/a/", expectedValue: "bar"},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // test "/b" collides with "/:"
+ {op: set, path: "/:a", value: "foo"},
+ {op: set, path: "/b", errMsg: errExist.Error()},
+ {op: get, path: "/a/", expectedValue: "foo", expectedParams: []Param{{"a", "a"}}},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // test "/:" collides with "/b"
+ {op: set, path: "/b", value: "foo"},
+ {op: set, path: "/:a", errMsg: errExist.Error()},
+ {op: get, path: "/b/", expectedValue: "foo"},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // more ":" tests
+ {op: set, path: "/a/b", value: "1"},
+ {op: set, path: "/a/b/:cc/d", value: "2"},
+ {op: set, path: "/a/b/:x/e", errMsg: "path \"/a/b/:x/e\" has a different param key \":x\", it should be the same key \":cc\" as in existing path \"/a/b/:cc/d\""},
+ {op: set, path: "/a/b/c/x", value: "2.1"},
+ {op: set, path: "/a/b/:cc/:d/e", value: "3"},
+ {op: set, path: "/a/b/c/d/f", value: "4"},
+ {op: set, path: "/a/:b/c/d", errMsg: errExist.Error()},
+ {op: get, path: "/a/b/some/d", expectedValue: "2", expectedParams: []Param{{"cc", "some"}}},
+ {op: get, path: "/a/b/c/x", expectedValue: "2.1"},
+ {op: get, path: "/a/b/other/data/e", expectedValue: "3",
+ expectedParams: []Param{
+ {"cc", "other"},
+ {"d", "data"},
+ }},
+ {op: get, path: "/a/b/c/d/f", expectedValue: "4"},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // more ":" tests
+ {op: set, path: "/a/b", value: "1"},
+ {op: set, path: "/a/b/ccc/x", value: "2"},
+ {op: set, path: "/a/b/c/dope/f", value: "3"},
+ {op: set, path: "/a/b/ccc/:", errMsg: errExist.Error()},
+ {op: set, path: "/a/b/c/:/:/", errMsg: errExist.Error()},
+ {op: get, path: "/a/b/ccc", errMsg: errNotFound.Error()},
+ {op: get, path: "/a/b/:", errMsg: errNotFound.Error()},
+ }
+ runTrieTests(t, trie, tests)
+
+ trie = NewTrie()
+ tests = []ts{
+ // more ":" tests
+ {op: set, path: "/a/:b/c", value: "1"},
+ {op: set, path: "/a/:b/d", value: "2"},
+ {op: get, path: "/a/b/c", expectedValue: "1", expectedParams: []Param{{"b", "b"}}},
+ {op: get, path: "/a/b/d", expectedValue: "2", expectedParams: []Param{{"b", "b"}}},
+ }
+ runTrieTests(t, trie, tests)
+}
+
+// simple test for coverage
+func TestParamMismatch(t *testing.T) {
+ pm := paramMismatch{
+ expected: "foo",
+ actual: "bar",
+ }
+ assert.Equal(t, "param key mismatch: expected is foo but got bar", pm.Error())
+}
+
+// utilities for debugging
+func printTrie(t *Trie) {
+ buf := new(bytes.Buffer)
+ var levelsEnded []int
+ printNodes(buf, t.root.children, 0, levelsEnded)
+ fmt.Println(string(buf.Bytes()))
+}
+
+func printNodes(w io.Writer, nodes []*tnode, level int, levelsEnded []int) {
+ for i, node := range nodes {
+ edge := "├──"
+ if i == len(nodes)-1 {
+ levelsEnded = append(levelsEnded, level)
+ edge = "└──"
+ }
+ printNode(w, node, level, levelsEnded, edge)
+ if len(node.children) > 0 {
+ printNodes(w, node.children, level+1, levelsEnded)
+ }
+ }
+}
+
+func printNode(w io.Writer, node *tnode, level int, levelsEnded []int, edge string) {
+ for i := 0; i < level; i++ {
+ isEnded := false
+ for _, l := range levelsEnded {
+ if l == i {
+ isEnded = true
+ break
+ }
+ }
+ if isEnded {
+ _, _ = fmt.Fprint(w, " ")
+ } else {
+ _, _ = fmt.Fprint(w, "│ ")
+ }
+ }
+ if node.value != nil {
+ _, _ = fmt.Fprintf(w, "%s %v (%v)\n", edge, node.key, node.value)
+ } else {
+ _, _ = fmt.Fprintf(w, "%s %v\n", edge, node.key)
+ }
+}
diff --git a/runtime/router_test.go b/runtime/router_test.go
index 122dbec06..b3c41e978 100644
--- a/runtime/router_test.go
+++ b/runtime/router_test.go
@@ -81,8 +81,8 @@ func (s *routerSuite) TestRouter() {
}{
{"GET", "/notfound", nil, http.StatusNotFound, []byte("404 page not found\n")},
{"GET", "/noslash", nil, http.StatusOK, []byte("noslash\n")},
- {"GET", "/noslash/", nil, http.StatusMovedPermanently, []byte(`Moved Permanently.` + "\n\n")},
- {"GET", "/withslash", nil, http.StatusMovedPermanently, []byte(`Moved Permanently.` + "\n\n")},
+ {"GET", "/noslash/", nil, http.StatusOK, []byte("noslash\n")},
+ {"GET", "/withslash", nil, http.StatusOK, []byte("withslash\n")},
{"GET", "/withslash/", nil, http.StatusOK, []byte("withslash\n")},
{"GET", "/postonly", nil, http.StatusMethodNotAllowed, []byte("Method Not Allowed\n")},
{"GET", "/panicerror", nil, http.StatusInternalServerError, []byte("Internal Server Error\n")},
diff --git a/test/lib/test_backend/test_http_backend.go b/test/lib/test_backend/test_http_backend.go
index 96144af1d..05e0007ef 100644
--- a/test/lib/test_backend/test_http_backend.go
+++ b/test/lib/test_backend/test_http_backend.go
@@ -25,8 +25,8 @@ import (
"strconv"
"sync"
- "github.com/julienschmidt/httprouter"
"github.com/uber/zanzibar/runtime"
+ zrouter "github.com/uber/zanzibar/runtime/router"
"go.uber.org/zap"
)
@@ -38,7 +38,7 @@ type TestHTTPBackend struct {
RealPort int32
RealAddr string
WaitGroup *sync.WaitGroup
- router *httprouter.Router
+ router *zrouter.Router
}
// BuildHTTPBackends returns a map of backends based on config
@@ -83,7 +83,7 @@ func (backend *TestHTTPBackend) Bootstrap() error {
func (backend *TestHTTPBackend) HandleFunc(
method string, path string, handler http.HandlerFunc,
) {
- backend.router.HandlerFunc(method, path, handler)
+ _ = backend.router.Handle(method, path, handler)
}
// Close ...
@@ -103,7 +103,7 @@ func CreateHTTPBackend(port int32) *TestHTTPBackend {
IP: "127.0.0.1",
Port: port,
WaitGroup: &sync.WaitGroup{},
- router: &httprouter.Router{
+ router: &zrouter.Router{
HandleMethodNotAllowed: true,
},
}