diff --git a/transmission/transmission.go b/transmission/transmission.go index 08a8203..e44fe74 100644 --- a/transmission/transmission.go +++ b/transmission/transmission.go @@ -388,6 +388,7 @@ func (b *batchAgg) fireBatch(events []*Event) { var req *http.Request reqBody, zipped := buildReqReader(encEvs, !b.disableCompression) req, err = http.NewRequest("POST", url.String(), reqBody) + req.ContentLength = int64(reqBody.Len()) req.Header.Set("Content-Type", contentType) if zipped { req.Header.Set("Content-Encoding", "zstd") @@ -630,14 +631,29 @@ func (b *batchAgg) enqueueErrResponses(err error, events []*Event, duration time var zstdBufferPool sync.Pool +type ReqReader interface { + io.ReadCloser + Len() int +} + type pooledReader struct { - *bytes.Reader + bytes.Reader buf []byte } +type SimpleReader struct { + bytes.Reader +} + +func (r SimpleReader) Close() error { + return nil +} + func (r *pooledReader) Close() error { + // Ensure further attempts to read will return io.EOF + r.Reset(nil) + // Then reset and give up ownership of the buffer. zstdBufferPool.Put(r.buf[:0]) - r.Reader = nil r.buf = nil return nil } @@ -661,9 +677,9 @@ func init() { } } -// buildReqReader returns an io.Reader and a boolean, indicating whether or not -// the io.Reader is compressed. -func buildReqReader(jsonEncoded []byte, compress bool) (io.ReadCloser, bool) { +// buildReqReader returns an io.ReadCloser and a boolean, indicating whether or not +// the underlying bytes.Reader is compressed. +func buildReqReader(jsonEncoded []byte, compress bool) (ReqReader, bool) { if compress { var buf []byte if found, ok := zstdBufferPool.Get().([]byte); ok { @@ -671,12 +687,15 @@ func buildReqReader(jsonEncoded []byte, compress bool) (io.ReadCloser, bool) { } buf = zstdEncoder.EncodeAll(jsonEncoded, buf) - return &pooledReader{ - Reader: bytes.NewReader(buf), - buf: buf, - }, true + reader := pooledReader{ + buf: buf, + } + reader.Reset(reader.buf) + return &reader, true } - return ioutil.NopCloser(bytes.NewReader(jsonEncoded)), false + var reader SimpleReader + reader.Reset(jsonEncoded) + return &reader, false } // nower to make testing easier diff --git a/transmission/transmission_test.go b/transmission/transmission_test.go index 6579063..a8f9b31 100644 --- a/transmission/transmission_test.go +++ b/transmission/transmission_test.go @@ -124,7 +124,13 @@ type FakeRoundTripper struct { func (f *FakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { f.req = r + if r.ContentLength == 0 { + panic("Expected a content length for all POST payloads.") + } bodyBytes, _ := ioutil.ReadAll(r.Body) + if r.ContentLength != int64(len(bodyBytes)) { + panic("Content length did not match number of read bytes.") + } f.reqBody = string(bodyBytes) // Honeycomb servers response to msgpack requests with msgpack responses, @@ -486,7 +492,13 @@ func (f *FancyFakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, erro headerKeys := strings.Split(reqHeader, ",") expectedURL, _ := url.Parse(fmt.Sprintf("%s/1/batch/%s", headerKeys[0], headerKeys[2])) if r.Header.Get("X-Honeycomb-Team") == headerKeys[1] && r.URL.String() == expectedURL.String() { + if r.ContentLength == 0 { + panic("Expected a content length for all POST payloads.") + } bodyBytes, _ := ioutil.ReadAll(r.Body) + if r.ContentLength != int64(len(bodyBytes)) { + panic("Content length did not match number of read bytes.") + } f.reqBody = string(bodyBytes) // make sure body is legitimately compressed json @@ -1085,7 +1097,7 @@ func TestHoneycombSenderAddingResponsesBlocking(t *testing.T) { } -func TestBuildReqReaderNoGzip(t *testing.T) { +func TestBuildReqReaderCompress(t *testing.T) { payload := []byte(`{"hello": "world"}`) // Ensure that if compress is false, we get expected values @@ -1095,7 +1107,7 @@ func TestBuildReqReaderNoGzip(t *testing.T) { testOK(t, err) testEquals(t, readBuffer, payload) - // Ensure that if useGzip is true, we get compressed values + // Ensure that if compress is true, we get compressed values reader, compressed = buildReqReader([]byte(`{"hello": "world"}`), true) testEquals(t, compressed, true) readBuffer, err = ioutil.ReadAll(reader) @@ -1104,6 +1116,14 @@ func TestBuildReqReaderNoGzip(t *testing.T) { decompressed, err := zstd.Decompress(nil, readBuffer) testOK(t, err) testEquals(t, decompressed, payload) + + // Ensure that calling Close() on the compressed buffer, then + // attempting to Read() returns io.EOF but no crash. + // Needed to support https://go-review.googlesource.com/c/net/+/355491 + reader, _ = buildReqReader([]byte(`{"hello": "world"}`), true) + reader.Close() + _, err = reader.Read(nil) + testEquals(t, err, io.EOF) } func TestMsgpackArrayEncoding(t *testing.T) {