Skip to content

Commit

Permalink
remove sapm handler and update tests to cover parser
Browse files Browse the repository at this point in the history
  • Loading branch information
charless-splunk committed Dec 6, 2019
1 parent 212de9c commit 6957d7c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 152 deletions.
105 changes: 10 additions & 95 deletions sapmrequestparser/parser.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package sapmhandler
package sapmrequestparser

import (
"bytes"
"compress/gzip"
"context"
"errors"
"io"
"io/ioutil"
Expand All @@ -16,50 +15,40 @@ import (

const (
// TraceEndpointV2 is the endpoint used for SAPM v2 traces. The SAPM protocol started with v2. There is no v1.
TraceEndpointV2 = "/v2/trace"
TraceEndpointV2 = "/v2/trace"
ContentTypeHeaderName = "Content-Type"
ContentTypeHeaderValue = "application/x-protobuf"

contentTypeHeader = "Content-Type"
xprotobuf = "application/x-protobuf"

acceptEncodingHeader = "Accept-Encoding"
contentEncodingHeader = "Content-Encoding"
gzipEncoding = "gzip"
AcceptEncodingHeaderName = "Accept-Encoding"
ContentEncodingHeaderName = "Content-Encoding"
GZipEncodingHeaderValue = "gzip"
)

var (
// ErrBadContentType indicates an incompatible content type was received
ErrBadContentType = errors.New("bad content type")

// ErrBadRequest indicates that the request couldn't be decoded
ErrBadRequest = errors.New("bad request")

gzipReaderPool = &sync.Pool{
New: func() interface{} {
// create a new gzip reader with a bytes reader and array of bytes containing only the gzip header
r, _ := gzip.NewReader(bytes.NewReader([]byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 255, 255, 1, 0, 0, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}))
return r
},
}

gzipWriterPool = &sync.Pool{
New: func() interface{} {
return gzip.NewWriter(ioutil.Discard)
},
}
)

// ParseTraceV2Request processes an http request request into SAPM
func ParseTraceV2Request(req *http.Request) (*splunksapm.PostSpansRequest, error) {
// content type MUST be application/x-protobuf
if req.Header.Get(contentTypeHeader) != xprotobuf {
if req.Header.Get(ContentTypeHeaderName) != ContentTypeHeaderValue {
return nil, ErrBadContentType
}

var err error
var reader io.Reader

// content encoding SHOULD be gzip
if req.Header.Get(contentEncodingHeader) == gzipEncoding {
if req.Header.Get(ContentEncodingHeaderName) == GZipEncodingHeaderValue {
// get the gzip reader
reader = gzipReaderPool.Get().(*gzip.Reader)
defer gzipReaderPool.Put(reader)
Expand All @@ -86,82 +75,8 @@ func ParseTraceV2Request(req *http.Request) (*splunksapm.PostSpansRequest, error
// unmarshal request body
err = proto.Unmarshal(reqBytes, sapm)
if err != nil {
return sapm, err
return nil, err
}

return sapm, err
}

// NewTraceHandlerV2 returns an http.HandlerFunc for receiving SAPM requests and passing the SAPM to a receiving function
func NewTraceHandlerV2(receiver func(ctx context.Context, sapm *splunksapm.PostSpansRequest, err error) error) func(rw http.ResponseWriter, req *http.Request) {
return func(rw http.ResponseWriter, req *http.Request) {
sapm, err := ParseTraceV2Request(req)
// errors processing the request should return http.StatusBadRequest
if err != nil {
rw.WriteHeader(http.StatusBadRequest)
}

// pass the SAPM and error to the receiver function
err = receiver(req.Context(), sapm, err)

// handle errors from the receiver function
if err != nil {
// write a 500 error and return if the error isn't ErrBadRequest
if err == ErrBadRequest {
rw.WriteHeader(http.StatusBadRequest)
} else {
// return a 500 when an unknown error occurs in receiver
rw.WriteHeader(http.StatusInternalServerError)
}
return
}

// respBytes are bytes to write to the http.Response
var respBytes []byte

// build the response message
respBytes, err = proto.Marshal(&splunksapm.PostSpansResponse{})
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}
rw.Header().Set(contentTypeHeader, xprotobuf)

// write the response if client does not accept gzip encoding
if req.Header.Get(acceptEncodingHeader) != gzipEncoding {
// write the response bytes
rw.Write(respBytes)
return
}

// gzip the response

// get the gzip writer
writer := gzipWriterPool.Get().(*gzip.Writer)
defer gzipWriterPool.Put(writer)

var gzipBuffer bytes.Buffer

// reset the writer with the gzip buffer
writer.Reset(&gzipBuffer)

// gzip the responseBytes
_, err = writer.Write(respBytes)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}

// close the gzip writer and write gzip footer
err = writer.Close()
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}

// write the successfully gzipped payload
rw.Header().Set(contentEncodingHeader, gzipEncoding)
rw.Write(gzipBuffer.Bytes())
return
}
}
98 changes: 41 additions & 57 deletions sapmrequestparser/parser_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package sapmhandler
package sapmrequestparser

import (
"bytes"
"compress/gzip"
"context"
"net/http"
"net/http/httptest"
"path"
"reflect"
"testing"
"testing/iotest"

Expand All @@ -16,132 +16,116 @@ import (

func TestNewV2TraceHandler(t *testing.T) {
var zipper *gzip.Writer
validProto, _ := proto.Marshal(&splunksapm.PostSpansRequest{})
validSapm := &splunksapm.PostSpansRequest{}
validProto, _ := proto.Marshal(validSapm)
uncompressedValidProtobufReq := httptest.NewRequest(http.MethodPost, path.Join("http://localhost", TraceEndpointV2), bytes.NewReader(validProto))
uncompressedValidProtobufReq.Header.Set(contentTypeHeader, xprotobuf)
uncompressedValidProtobufReq.Header.Set(ContentTypeHeaderName, ContentTypeHeaderValue)

var gzippedValidProtobufBuf bytes.Buffer
zipper = gzip.NewWriter(&gzippedValidProtobufBuf)
zipper.Write(validProto)
zipper.Close()
gzippedValidProtobufReq := httptest.NewRequest(http.MethodPost, path.Join("http://localhost", TraceEndpointV2), bytes.NewReader(gzippedValidProtobufBuf.Bytes()))
gzippedValidProtobufReq.Header.Set(contentTypeHeader, xprotobuf)
gzippedValidProtobufReq.Header.Set(contentEncodingHeader, gzipEncoding)
gzippedValidProtobufReq.Header.Set(acceptEncodingHeader, gzipEncoding)
gzippedValidProtobufReq.Header.Set(ContentTypeHeaderName, ContentTypeHeaderValue)
gzippedValidProtobufReq.Header.Set(ContentEncodingHeaderName, GZipEncodingHeaderValue)
gzippedValidProtobufReq.Header.Set(AcceptEncodingHeaderName, GZipEncodingHeaderValue)

badContentTypeReq := httptest.NewRequest(http.MethodPost, path.Join("http://localhost", TraceEndpointV2), bytes.NewReader([]byte{}))
badContentTypeReq.Header.Set(contentTypeHeader, "application/json")
badContentTypeReq.Header.Set(ContentTypeHeaderName, "application/json")

errReader := iotest.TimeoutReader(bytes.NewReader([]byte{}))
errReader.Read([]byte{}) // read once so that subsequent reads return an error

badBodyReq := httptest.NewRequest(http.MethodPost, path.Join("http://localhost", TraceEndpointV2), errReader)
badBodyReq.Header.Set(contentTypeHeader, xprotobuf)
badBodyReq.Header.Set(ContentTypeHeaderName, ContentTypeHeaderValue)

badGzipReq := httptest.NewRequest(http.MethodPost, path.Join("http://localhost", TraceEndpointV2), bytes.NewBuffer([]byte("hello world")))
badGzipReq.Header.Set(contentTypeHeader, xprotobuf)
badGzipReq.Header.Set(contentEncodingHeader, gzipEncoding)
badGzipReq.Header.Set(ContentTypeHeaderName, ContentTypeHeaderValue)
badGzipReq.Header.Set(ContentEncodingHeaderName, GZipEncodingHeaderValue)

var emptyGZipBuf bytes.Buffer
zipper = gzip.NewWriter(&emptyGZipBuf)
zipper.Write([]byte{})
zipper.Close()
emptyGZipReq := httptest.NewRequest(http.MethodPost, path.Join("http://localhost", TraceEndpointV2), bytes.NewReader(emptyGZipBuf.Bytes()))
emptyGZipReq.Header.Set(contentTypeHeader, xprotobuf)
emptyGZipReq.Header.Set(contentEncodingHeader, gzipEncoding)
emptyGZipReq.Header.Set(ContentTypeHeaderName, ContentTypeHeaderValue)
emptyGZipReq.Header.Set(ContentEncodingHeaderName, GZipEncodingHeaderValue)

var invalidProtubfBuf bytes.Buffer
zipper = gzip.NewWriter(&invalidProtubfBuf)
zipper.Write([]byte("invalid protbuf body"))
zipper.Close()
invalidProtobufReq := httptest.NewRequest(http.MethodPost, path.Join("http://localhost", TraceEndpointV2), bytes.NewReader(invalidProtubfBuf.Bytes()))
invalidProtobufReq.Header.Set(contentTypeHeader, xprotobuf)
invalidProtobufReq.Header.Set(contentEncodingHeader, gzipEncoding)
invalidProtobufReq.Header.Set(ContentTypeHeaderName, ContentTypeHeaderValue)
invalidProtobufReq.Header.Set(ContentEncodingHeaderName, GZipEncodingHeaderValue)

type want struct {
wantErr bool
statusCode int
wantErr bool
sapm *splunksapm.PostSpansRequest
}
tests := []struct {
name string
req *http.Request
want want
}{
{
name: "valid protobuf returns a 200 status code",
name: "valid protobuf returns and valid sapm",
req: uncompressedValidProtobufReq,
want: want{
statusCode: http.StatusOK,
wantErr: false,
sapm: validSapm,
wantErr: false,
},
},
{
name: "a bad request body returns error and 400 status code",
name: "a bad request body returns error and nil sapm",
req: badBodyReq,
want: want{
statusCode: http.StatusBadRequest,
wantErr: true,
sapm: nil,
wantErr: true,
},
},
{
name: "valid gzipped protobuf returns a 200 status code",
name: "valid gzipped protobuf returns and valid sapm",
req: gzippedValidProtobufReq,
want: want{
statusCode: http.StatusOK,
wantErr: false,
sapm: validSapm,
wantErr: false,
},
},
{
name: "invalid content type returns error and 400 status code",
name: "invalid content type returns error and nil sapm",
req: badContentTypeReq,
want: want{
statusCode: http.StatusBadRequest,
wantErr: true,
sapm: nil,
wantErr: true,
},
},
{
name: "invalid gzip data returns error and 400 status code",
name: "invalid gzip data returns error and nil sapm",
req: badGzipReq,
want: want{
statusCode: http.StatusBadRequest,
wantErr: true,
sapm: nil,
wantErr: true,
},
},
{
name: "invalid protobuf payload returns error and 400 status code",
name: "invalid protobuf payload returns error and nil sapm",
req: invalidProtobufReq,
want: want{
statusCode: http.StatusBadRequest,
wantErr: true,
sapm: nil,
wantErr: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var returnedErr error
rw := httptest.NewRecorder()

receiver := func(ctx context.Context, sapm *splunksapm.PostSpansRequest, err error) error {
returnedErr = err
return err
}

handler := NewTraceHandlerV2(receiver)
handler(rw, tt.req)
if tt.want.wantErr != (returnedErr != nil) {
t.Errorf("NewTraceHandlerV2() returned err = %v, wantErr = %v", returnedErr, tt.want.wantErr)
sapm, err := ParseTraceV2Request(tt.req)
if tt.want.wantErr != (err != nil) {
t.Errorf("ParseTraceV2Request() returned err = %v, wantErr = %v", err, tt.want.wantErr)
return
}

if statusCode := rw.Code; tt.want.statusCode != statusCode {
t.Errorf("NewTraceHandlerV2() returned status code '%v', wanted '%v'", statusCode, tt.want.statusCode)
return
}

requestEncoding := tt.req.Header.Get(acceptEncodingHeader)
responseEncoding := rw.Header().Get(contentEncodingHeader)
if requestEncoding != responseEncoding {
t.Errorf("NewTraceHandlerV2() request encoding '%v' does not match response '%v'", requestEncoding, responseEncoding)
if !reflect.DeepEqual(sapm, tt.want.sapm) {
t.Errorf("ParseTraceV2Request() sapm returned = %v, wanted = %v", sapm, tt.want.sapm)
}
})
}
Expand Down

0 comments on commit 6957d7c

Please sign in to comment.