Skip to content

Commit

Permalink
fix: don't crash on stream aborts, always add content length (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
lizthegrey authored Jan 4, 2022
1 parent b46f1e7 commit 85568f3
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
39 changes: 29 additions & 10 deletions transmission/transmission.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ func (b *batchAgg) fireBatch(events []*Event) {
var req *http.Request
reqBody, zipped := buildReqReader(encEvs, !b.disableCompression)
req, err = http.NewRequest("POST", url.String(), reqBody)
req.ContentLength = int64(reqBody.Len())
req.Header.Set("Content-Type", contentType)
if zipped {
req.Header.Set("Content-Encoding", "zstd")
Expand Down Expand Up @@ -630,14 +631,29 @@ func (b *batchAgg) enqueueErrResponses(err error, events []*Event, duration time

var zstdBufferPool sync.Pool

type ReqReader interface {
io.ReadCloser
Len() int
}

type pooledReader struct {
*bytes.Reader
bytes.Reader
buf []byte
}

type SimpleReader struct {
bytes.Reader
}

func (r SimpleReader) Close() error {
return nil
}

func (r *pooledReader) Close() error {
// Ensure further attempts to read will return io.EOF
r.Reset(nil)
// Then reset and give up ownership of the buffer.
zstdBufferPool.Put(r.buf[:0])
r.Reader = nil
r.buf = nil
return nil
}
Expand All @@ -661,22 +677,25 @@ func init() {
}
}

// buildReqReader returns an io.Reader and a boolean, indicating whether or not
// the io.Reader is compressed.
func buildReqReader(jsonEncoded []byte, compress bool) (io.ReadCloser, bool) {
// buildReqReader returns an io.ReadCloser and a boolean, indicating whether or not
// the underlying bytes.Reader is compressed.
func buildReqReader(jsonEncoded []byte, compress bool) (ReqReader, bool) {
if compress {
var buf []byte
if found, ok := zstdBufferPool.Get().([]byte); ok {
buf = found[:0]
}

buf = zstdEncoder.EncodeAll(jsonEncoded, buf)
return &pooledReader{
Reader: bytes.NewReader(buf),
buf: buf,
}, true
reader := pooledReader{
buf: buf,
}
reader.Reset(reader.buf)
return &reader, true
}
return ioutil.NopCloser(bytes.NewReader(jsonEncoded)), false
var reader SimpleReader
reader.Reset(jsonEncoded)
return &reader, false
}

// nower to make testing easier
Expand Down
24 changes: 22 additions & 2 deletions transmission/transmission_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ type FakeRoundTripper struct {

func (f *FakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
f.req = r
if r.ContentLength == 0 {
panic("Expected a content length for all POST payloads.")
}
bodyBytes, _ := ioutil.ReadAll(r.Body)
if r.ContentLength != int64(len(bodyBytes)) {
panic("Content length did not match number of read bytes.")
}
f.reqBody = string(bodyBytes)

// Honeycomb servers response to msgpack requests with msgpack responses,
Expand Down Expand Up @@ -486,7 +492,13 @@ func (f *FancyFakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, erro
headerKeys := strings.Split(reqHeader, ",")
expectedURL, _ := url.Parse(fmt.Sprintf("%s/1/batch/%s", headerKeys[0], headerKeys[2]))
if r.Header.Get("X-Honeycomb-Team") == headerKeys[1] && r.URL.String() == expectedURL.String() {
if r.ContentLength == 0 {
panic("Expected a content length for all POST payloads.")
}
bodyBytes, _ := ioutil.ReadAll(r.Body)
if r.ContentLength != int64(len(bodyBytes)) {
panic("Content length did not match number of read bytes.")
}
f.reqBody = string(bodyBytes)

// make sure body is legitimately compressed json
Expand Down Expand Up @@ -1085,7 +1097,7 @@ func TestHoneycombSenderAddingResponsesBlocking(t *testing.T) {

}

func TestBuildReqReaderNoGzip(t *testing.T) {
func TestBuildReqReaderCompress(t *testing.T) {
payload := []byte(`{"hello": "world"}`)

// Ensure that if compress is false, we get expected values
Expand All @@ -1095,7 +1107,7 @@ func TestBuildReqReaderNoGzip(t *testing.T) {
testOK(t, err)
testEquals(t, readBuffer, payload)

// Ensure that if useGzip is true, we get compressed values
// Ensure that if compress is true, we get compressed values
reader, compressed = buildReqReader([]byte(`{"hello": "world"}`), true)
testEquals(t, compressed, true)
readBuffer, err = ioutil.ReadAll(reader)
Expand All @@ -1104,6 +1116,14 @@ func TestBuildReqReaderNoGzip(t *testing.T) {
decompressed, err := zstd.Decompress(nil, readBuffer)
testOK(t, err)
testEquals(t, decompressed, payload)

// Ensure that calling Close() on the compressed buffer, then
// attempting to Read() returns io.EOF but no crash.
// Needed to support https://go-review.googlesource.com/c/net/+/355491
reader, _ = buildReqReader([]byte(`{"hello": "world"}`), true)
reader.Close()
_, err = reader.Read(nil)
testEquals(t, err, io.EOF)
}

func TestMsgpackArrayEncoding(t *testing.T) {
Expand Down

0 comments on commit 85568f3

Please sign in to comment.