Skip to content

Commit

Permalink
Refactoring http trace (#296)
Browse files Browse the repository at this point in the history
* refact: improvements http trace

* segment: new tests

* add helper functions for handler

* chore: rename functions/update comments

by code review: #296
  • Loading branch information
nicolascb authored Apr 29, 2021
1 parent a6b81de commit 015dc2e
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 35 deletions.
90 changes: 55 additions & 35 deletions xray/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,49 +117,18 @@ func Handler(sn SegmentNamer, h http.Handler) http.Handler {
}

func httpTrace(seg *Segment, h http.Handler, w http.ResponseWriter, r *http.Request, traceHeader *header.Header) {
seg.Lock()

scheme := "https://"
if r.TLS == nil {
scheme = "http://"
}
seg.GetHTTP().GetRequest().Method = r.Method
seg.GetHTTP().GetRequest().URL = scheme + r.Host + r.URL.Path
seg.GetHTTP().GetRequest().ClientIP, seg.GetHTTP().GetRequest().XForwardedFor = clientIP(r)
seg.GetHTTP().GetRequest().UserAgent = r.UserAgent()

// Don't use the segment's header here as we only want to
// send back the root and possibly sampled values.
var respHeader bytes.Buffer
respHeader.WriteString("Root=")
respHeader.WriteString(seg.TraceID)

if traceHeader.SamplingDecision == header.Requested {
respHeader.WriteString(";Sampled=")
respHeader.WriteString(strconv.Itoa(btoi(seg.Sampled)))
}

w.Header().Set(TraceIDHeaderKey, respHeader.String())
seg.Unlock()
httpCaptureRequest(seg, r)
traceIDHeaderValue := generateTraceIDHeaderValue(seg, traceHeader)
w.Header().Set(TraceIDHeaderKey, traceIDHeaderValue)

capturer := &responseCapturer{w, 200, 0}
resp := capturer.wrappedResponseWriter()
h.ServeHTTP(resp, r)

seg.Lock()
seg.GetHTTP().GetResponse().Status = capturer.status
seg.GetHTTP().GetResponse().ContentLength, _ = strconv.Atoi(capturer.Header().Get("Content-Length"))

if capturer.status >= 400 && capturer.status < 500 {
seg.Error = true
}
if capturer.status == 429 {
seg.Throttle = true
}
if capturer.status >= 500 && capturer.status < 600 {
seg.Fault = true
}
seg.Unlock()
httpCaptureResponse(seg, capturer.status)
}

func clientIP(r *http.Request) (string, bool) {
Expand All @@ -180,3 +149,54 @@ func btoi(b bool) int {
}
return 0
}

// generateTraceIDHeaderValue generates value for _x_amzn_trace_id header key
func generateTraceIDHeaderValue(seg *Segment, traceHeader *header.Header) string {
seg.Lock()
defer seg.Unlock()

var respHeader bytes.Buffer
respHeader.WriteString("Root=")
respHeader.WriteString(seg.TraceID)

if traceHeader.SamplingDecision == header.Requested {
respHeader.WriteString(";Sampled=")
respHeader.WriteString(strconv.Itoa(btoi(seg.Sampled)))
}

return respHeader.String()
}

// httpCaptureResponse fill response by http status code
func httpCaptureResponse(seg *Segment, statusCode int) {
seg.Lock()
defer seg.Unlock()

seg.GetHTTP().GetResponse().Status = statusCode

if statusCode >= 400 && statusCode < 500 {
seg.Error = true
}
if statusCode == 429 {
seg.Throttle = true
}
if statusCode >= 500 && statusCode < 600 {
seg.Fault = true
}
}

// httpCaptureRequest fill request data by http.Request
func httpCaptureRequest(seg *Segment, r *http.Request) {
seg.Lock()
defer seg.Unlock()

scheme := "https://"
if r.TLS == nil {
scheme = "http://"
}

seg.GetHTTP().GetRequest().Method = r.Method
seg.GetHTTP().GetRequest().URL = scheme + r.Host + r.URL.Path
seg.GetHTTP().GetRequest().ClientIP, seg.GetHTTP().GetRequest().XForwardedFor = clientIP(r)
seg.GetHTTP().GetRequest().UserAgent = r.UserAgent()
}
132 changes: 132 additions & 0 deletions xray/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import (
"net/http"
"net/http/httptest"
"os"
"reflect"
"strings"
"testing"

"github.com/aws/aws-xray-sdk-go/header"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -171,3 +173,133 @@ func BenchmarkHandler(b *testing.B) {
ts.Close()
}
}

func TestGenerateTraceIDHeaderValue(t *testing.T) {
type args struct {
seg *Segment
traceHeader *header.Header
}
tests := []struct {
name string
args func(t *testing.T) args
want1 string
}{
{
name: "TraceID with sampling decision",
args: func(*testing.T) args {
return args{
seg: &Segment{
TraceID: "x-traceid",
Sampled: true,
},
traceHeader: &header.Header{
SamplingDecision: header.Requested,
},
}
},
want1: "Root=x-traceid;Sampled=1",
},
{
name: "TraceID without Sampled",
args: func(*testing.T) args {
return args{
seg: &Segment{
TraceID: "x-traceid",
Sampled: true,
},
traceHeader: &header.Header{},
}
},
want1: "Root=x-traceid",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tArgs := tt.args(t)
got1 := generateTraceIDHeaderValue(tArgs.seg, tArgs.traceHeader)
if !reflect.DeepEqual(got1, tt.want1) {
t.Errorf("Segment.TraceHeaderID got1 = %v, want1: %v", got1, tt.want1)
}
})
}
}

func TestHTTPCaptureResponse(t *testing.T) {
type args struct {
seg *Segment
statusCode int
}
tests := []struct {
name string
inspect func(r *Segment, t *testing.T) //inspects receiver after test run

args func(t *testing.T) args
}{
{
name: "StatudCode 400 >= 400 and < 500 is a error",
inspect: func(s *Segment, t *testing.T) {
if !s.Error {
t.Errorf("Segment error, got = false, want1: true")
}
},
args: func(*testing.T) args {
return args{
seg: &Segment{},
statusCode: 401,
}
},
},
{
name: "StatudCode 429 set error/throttle",
inspect: func(s *Segment, t *testing.T) {
if !s.Error {
t.Errorf("Segment error, got = false, want1: true")
}

if !s.Throttle {
t.Errorf("Segment.Throttle error, got = false, want1: true")
}

},
args: func(*testing.T) args {
return args{
seg: &Segment{},
statusCode: 429,
}
},
},
{
name: "StatusCode 500 is a fault error",
inspect: func(s *Segment, t *testing.T) {
if !s.Fault {
t.Errorf("Segment.Fault error, got = false, want1: true")
}

},
args: func(*testing.T) args {
return args{
seg: &Segment{},
statusCode: 500,
}
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tArgs := tt.args(t)

httpCaptureResponse(tArgs.seg, tArgs.statusCode)

if tt.inspect != nil {
tt.inspect(tArgs.seg, t)
}

if tArgs.seg.GetHTTP().GetResponse().Status != tArgs.statusCode {
t.Errorf("Status code error, got = %d, want1: %d", tArgs.seg.GetHTTP().GetResponse().Status, tArgs.statusCode)
}

})
}
}

0 comments on commit 015dc2e

Please sign in to comment.