Skip to content

Commit

Permalink
feat: measure external latency (#779)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Apr 19, 2024
1 parent f29ab1d commit a7a3c8a
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
go-version: "1.21"
go-version: "1.22"
- run: make format
- name: Indicate formatting issues
run: git diff HEAD --exit-code --color
2 changes: 1 addition & 1 deletion .github/workflows/licenses.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-go@v2
with:
go-version: "1.21"
go-version: "1.22"
- uses: actions/setup-node@v2
with:
node-version: "18"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-go@v2
with:
go-version: "1.21"
go-version: "1.22"
- run: |
go test -tags sqlite -failfast -short -timeout=20m $(go list ./... | grep -v sqlcon | grep -v watcherx | grep -v pkgerx | grep -v configx)
shell: bash
Expand Down Expand Up @@ -55,7 +55,7 @@ jobs:
uses: actions/checkout@v2
- uses: actions/setup-go@v2
with:
go-version: "1.21"
go-version: "1.22"
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
Expand Down
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/ory/x

go 1.21
go 1.22

toolchain go1.22.2

require (
code.dny.dev/ssrf v0.2.0
Expand Down
29 changes: 29 additions & 0 deletions httpx/external_latency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package httpx

import (
"net/http"
"time"

"github.com/ory/x/reqlog"
)

// MeasureExternalLatencyTransport is an http.RoundTripper that measures the latency of all requests as external latency.
type MeasureExternalLatencyTransport struct {
Transport http.RoundTripper
}

var _ http.RoundTripper = (*MeasureExternalLatencyTransport)(nil)

func (m *MeasureExternalLatencyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
upstreamHostPath := req.URL.Scheme + "://" + req.URL.Host + req.URL.Path
defer reqlog.StartMeasureExternalCall(req.Context(), "http_request", upstreamHostPath, time.Now())

t := m.Transport
if t == nil {
t = http.DefaultTransport
}
return t.RoundTrip(req)
}
94 changes: 57 additions & 37 deletions proxy/proxy_full_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,20 @@ func (rt *testingRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
}

func TestFullIntegration(t *testing.T) {
upstream, upstreamHandler := httpx.NewChanHandler(0)
upstream, upstreamHandler := httpx.NewChanHandler(1)
upstreamServer := httptest.NewTLSServer(upstream)
defer upstreamServer.Close()

// create the proxy
hostMapper := make(chan func(*http.Request) (*HostConfig, error))
reqMiddleware := make(chan ReqMiddleware)
respMiddleware := make(chan RespMiddleware)
hostMapper := make(chan func(*http.Request) (*HostConfig, error), 1)
reqMiddleware := make(chan ReqMiddleware, 1)
respMiddleware := make(chan RespMiddleware, 1)

type CustomErrorReq func(*http.Request, error)
type CustomErrorResp func(*http.Response, error) error

onErrorReq := make(chan CustomErrorReq)
onErrorResp := make(chan CustomErrorResp)
onErrorReq := make(chan CustomErrorReq, 1)
onErrorResp := make(chan CustomErrorResp, 1)

proxy := httptest.NewTLSServer(New(
func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
Expand All @@ -122,17 +122,20 @@ func TestFullIntegration(t *testing.T) {
return f(resp, config, body)
}),
WithOnError(func(request *http.Request, err error) {
f := <-onErrorReq
if f == nil {
return
select {
case f := <-onErrorReq:
f(request, err)
default:
t.Errorf("unexpected error: %+v", err)
}
f(request, err)
}, func(response *http.Response, err error) error {
f := <-onErrorResp
if f == nil {
return nil
select {
case f := <-onErrorResp:
return f(response, err)
default:
t.Errorf("unexpected error: %+v", err)
return err
}
return f(response, err)
})))

cl := proxy.Client()
Expand Down Expand Up @@ -315,8 +318,7 @@ func TestFullIntegration(t *testing.T) {
req.Host = "auth.example.com"
return req
},
assertResponse: func(t *testing.T, r *http.Response) {
},
assertResponse: func(t *testing.T, r *http.Response) {},
respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) {
return nil, errors.New("some response middleware error")
},
Expand Down Expand Up @@ -495,37 +497,55 @@ func TestFullIntegration(t *testing.T) {
},
} {
t.Run("case="+tc.desc, func(t *testing.T) {
go func() {
hostMapper <- func(r *http.Request) (*HostConfig, error) {
host := r.Host
hc, err := tc.hostMapper(host)
if err == nil {
hc.UpstreamHost = urlx.ParseOrPanic(upstreamServer.URL).Host
hc.UpstreamScheme = urlx.ParseOrPanic(upstreamServer.URL).Scheme
hc.TargetHost = hc.UpstreamHost
hc.TargetScheme = hc.UpstreamScheme
}
return hc, err
hostMapper <- func(r *http.Request) (*HostConfig, error) {
host := r.Host
hc, err := tc.hostMapper(host)
if err == nil {
hc.UpstreamHost = urlx.ParseOrPanic(upstreamServer.URL).Host
hc.UpstreamScheme = urlx.ParseOrPanic(upstreamServer.URL).Scheme
hc.TargetHost = hc.UpstreamHost
hc.TargetScheme = hc.UpstreamScheme
}
return hc, err
}
if tc.onErrReq != nil {
onErrorReq <- tc.onErrReq
}
if tc.onErrResp != nil {
onErrorResp <- tc.onErrResp
}

if tc.onErrReq == nil {
// we will only send a request if there is no request error
reqMiddleware <- tc.reqMiddleware
respMiddleware <- tc.respMiddleware
upstreamHandler <- func(w http.ResponseWriter, r *http.Request) {
t := &remoteT{t: t, w: w, r: r}
tc.handler(assert.New(t), t, r)
}
respMiddleware <- tc.respMiddleware
}()

go func() {
onErrorReq <- tc.onErrReq
}()

go func() {
onErrorResp <- tc.onErrResp
}()
}

resp, err := cl.Do(tc.request(t))
require.NoError(t, err)
tc.assertResponse(t, resp)

select {
case <-hostMapper:
t.Fatal("host mapper not consumed")
case <-reqMiddleware:
t.Fatal("req middleware not consumed")
case <-respMiddleware:
t.Fatal("resp middleware not consumed")
case <-onErrorReq:
t.Fatal("req error not consumed")
case <-onErrorResp:
t.Fatal("resp error not consumed")
default:
if len(upstreamHandler) != 0 {
t.Fatal("upstream handler not consumed")
}
return
}
})
}
}
Expand Down
79 changes: 79 additions & 0 deletions reqlog/external_latency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package reqlog

import (
"context"
"sync"
"time"
)

// WithEnableExternalLatencyMeasurement returns a context that measures external latencies.
func WithEnableExternalLatencyMeasurement(ctx context.Context) context.Context {
container := contextContainer{
latencies: make([]externalLatency, 0),
}
return context.WithValue(ctx, externalLatencyKey, &container)
}

// StartMeasureExternalCall starts measuring the duration of an external call.
// The returned function has to be called to record the duration.
func StartMeasureExternalCall(ctx context.Context, cause, detail string, start time.Time) {
container, ok := ctx.Value(externalLatencyKey).(*contextContainer)
if !ok {
return
}
if _, ok := ctx.Value(disableExternalLatencyMeasurement).(bool); ok {
return
}

container.Lock()
defer container.Unlock()
container.latencies = append(container.latencies, externalLatency{
Took: time.Since(start),
Cause: cause,
Detail: detail,
})
}

// totalExternalLatency returns the total duration of all external calls.
func totalExternalLatency(ctx context.Context) (total time.Duration) {
if _, ok := ctx.Value(disableExternalLatencyMeasurement).(bool); ok {
return 0
}
container, ok := ctx.Value(externalLatencyKey).(*contextContainer)
if !ok {
return 0
}

container.Lock()
defer container.Unlock()
for _, l := range container.latencies {
total += l.Took
}
return total
}

// WithDisableExternalLatencyMeasurement returns a context that does not measure external latencies.
// Use this when you want to disable external latency measurements for a specific request.
func WithDisableExternalLatencyMeasurement(ctx context.Context) context.Context {
return context.WithValue(ctx, disableExternalLatencyMeasurement, true)
}

type (
externalLatency = struct {
Took time.Duration
Cause, Detail string
}
contextContainer = struct {
latencies []externalLatency
sync.Mutex
}
contextKey int
)

const (
externalLatencyKey contextKey = 1
disableExternalLatencyMeasurement contextKey = 2
)
71 changes: 71 additions & 0 deletions reqlog/external_latency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package reqlog

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"golang.org/x/sync/errgroup"
)

func TestExternalLatencyMiddleware(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
NewMiddleware().ServeHTTP(w, r, func(w http.ResponseWriter, r *http.Request) {
var wg sync.WaitGroup

wg.Add(3)
for i := range 3 {
ctx := r.Context()
if i%3 == 0 {
ctx = WithDisableExternalLatencyMeasurement(ctx)
}
go func() {
defer StartMeasureExternalCall(ctx, "", "", time.Now())
time.Sleep(100 * time.Millisecond)
wg.Done()
}()
}
wg.Wait()
total := totalExternalLatency(r.Context())
_ = json.NewEncoder(w).Encode(map[string]any{
"total": total,
})
})
}))
defer ts.Close()

bodies := make([][]byte, 100)
eg := errgroup.Group{}
for i := range bodies {
eg.Go(func() error {
res, err := http.Get(ts.URL)
if err != nil {
return err
}
defer res.Body.Close()
bodies[i], err = io.ReadAll(res.Body)
if err != nil {
return err
}
return nil
})
}

require.NoError(t, eg.Wait())

for _, body := range bodies {
actualTotal := gjson.GetBytes(body, "total").Int()
assert.GreaterOrEqual(t, actualTotal, int64(200*time.Millisecond), string(body))
assert.Less(t, actualTotal, int64(300*time.Millisecond), string(body))
}
}
10 changes: 9 additions & 1 deletion reqlog/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func (m *Middleware) ServeHTTP(rw http.ResponseWriter, r *http.Request, next htt
nrw = negroni.NewResponseWriter(rw)
}

r = r.WithContext(WithEnableExternalLatencyMeasurement(r.Context()))
next(nrw, r)

latency := m.clock.Since(start)
Expand All @@ -161,11 +162,18 @@ func DefaultBefore(entry *logrusx.Logger, req *http.Request, remoteAddr string)

// DefaultAfter is the default func assigned to *Middleware.After
func DefaultAfter(entry *logrusx.Logger, req *http.Request, res negroni.ResponseWriter, latency time.Duration, name string) *logrusx.Logger {
return entry.WithRequest(req).WithField("http_response", map[string]interface{}{
e := entry.WithRequest(req).WithField("http_response", map[string]any{
"status": res.Status(),
"size": res.Size(),
"text_status": http.StatusText(res.Status()),
"took": latency,
"headers": entry.HTTPHeadersRedacted(res.Header()),
})
if el := totalExternalLatency(req.Context()); el > 0 {
e = e.WithFields(map[string]any{
"took_internal": latency - el,
"took_external": el,
})
}
return e
}

0 comments on commit a7a3c8a

Please sign in to comment.