From 59fe66e8e2ff751954c38ad9f1a6ad1acde110f5 Mon Sep 17 00:00:00 2001 From: Rohith Date: Wed, 11 Jul 2018 15:26:50 +0100 Subject: [PATCH] - fixing up to use github.com/satori/go.uuid instead of and internal one, lose 20ns but hey :-) --- forwarding.go | 2 +- middleware.go | 8 +++++++- server.go | 1 + utils.go | 50 -------------------------------------------------- utils_test.go | 41 +++-------------------------------------- 5 files changed, 12 insertions(+), 90 deletions(-) diff --git a/forwarding.go b/forwarding.go index a5f14c0ad..51bcfa126 100644 --- a/forwarding.go +++ b/forwarding.go @@ -41,7 +41,7 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { // @step: add the proxy forwarding headers req.Header.Add("X-Forwarded-For", realIP(req)) - req.Header.Set("X-Forwarded-Host", req.URL.Host) + req.Header.Set("X-Forwarded-Host", req.Host) req.Header.Set("X-Forwarded-Proto", req.Header.Get("X-Forwarded-Proto")) // @step: add any custom headers to the request diff --git a/middleware.go b/middleware.go index ab957ab12..903278970 100644 --- a/middleware.go +++ b/middleware.go @@ -26,6 +26,7 @@ import ( "github.com/PuerkitoBio/purell" "github.com/gambol99/go-oidc/jose" "github.com/go-chi/chi/middleware" + uuid "github.com/satori/go.uuid" "github.com/unrolled/secure" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -71,7 +72,12 @@ func (r *oauthProxy) requestIDMiddleware(header string) func(http.Handler) http. return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if v := req.Header.Get(header); v == "" { - req.Header.Set(header, randomUUID()) + uid, err := uuid.NewV1() + if err != nil { + r.log.Error("failed to generatet correlation id for request", zap.Error(err)) + } else { + req.Header.Set(header, uid.String()) + } } next.ServeHTTP(w, req) diff --git a/server.go b/server.go index 84ff0c322..2c710dc56 100644 --- a/server.go +++ b/server.go @@ -164,6 +164,7 @@ func (r *oauthProxy) createReverseProxy() error { engine.Use(middleware.Recoverer) // @check if the request tracking id middleware is enabled if r.config.EnableRequestID { + r.log.Info("enabled the correlation request id middlware") engine.Use(r.requestIDMiddleware(r.config.RequestIDHeader)) } // @step: enable the entrypoint middleware diff --git a/utils.go b/utils.go index cd1836112..a66e61693 100644 --- a/utils.go +++ b/utils.go @@ -28,7 +28,6 @@ import ( "fmt" "io" "io/ioutil" - mrand "math/rand" "net" "net/http" "net/url" @@ -78,55 +77,6 @@ func getRequestHostURL(r *http.Request) string { return fmt.Sprintf("%s://%s", scheme, hostname) } -const ( - letterBytes = "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz0123456789" - letterIdxBits = 6 - letterIdxMask = 1<= 0; { - if remain == 0 { - cache, remain = randomSource.Int63(), letterIdxMax - } - if idx := int(cache & letterIdxMask); idx < len(letterBytes) { - b[i] = letterBytes[idx] - i-- - } - cache >>= letterIdxBits - remain-- - } - - return b -} - -// randomString returns a random string of x length -func randomString(length int) string { - return string(randomBytes(length)) -} - -// randomUUID returns a uuid from the random string -func randomUUID() string { - uuid := make([]byte, 36) - r := randomBytes(32) - i := 0 - for x := range []int{8, 4, 4, 4, 12} { - copy(uuid, r[i:i+x]) - if x != 12 { - copy(uuid, []byte("-")) - i = i + x - } - } - - return string(uuid) -} - // readConfigFile reads and parses the configuration file func readConfigFile(filename string, config *Config) error { content, err := ioutil.ReadFile(filename) diff --git a/utils_test.go b/utils_test.go index f61fb01e8..cfab77740 100644 --- a/utils_test.go +++ b/utils_test.go @@ -26,7 +26,7 @@ import ( "testing" "time" - "github.com/google/uuid" + uuid "github.com/satori/go.uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -65,45 +65,10 @@ func TestDecodeKeyPairs(t *testing.T) { } } -func TestRandom(t *testing.T) { - s := randomBytes(6) - assert.NotEmpty(t, s) - assert.Equal(t, 6, len(s)) -} - -func TestRandomString(t *testing.T) { - s := randomString(6) - assert.NotEmpty(t, s) - assert.Equal(t, 6, len(s)) -} - -func TestRandomUUID(t *testing.T) { - s := randomUUID() - assert.NotEmpty(t, s) - assert.Equal(t, 36, len(s)) -} - -func BenchmarkRandomBytes36(b *testing.B) { - for n := 0; n < b.N; n++ { - randomString(36) - } -} - -func BenchmarkRandomString36(b *testing.B) { - for n := 0; n < b.N; n++ { - randomString(36) - } -} - func BenchmarkUUID(b *testing.B) { for n := 0; n < b.N; n++ { - uuid.New() - } -} - -func BenchmarkRandomUUID(b *testing.B) { - for n := 0; n < b.N; n++ { - randomUUID() + s, _ := uuid.NewV1() + s.String() } }