diff --git a/compress.go b/compress.go
new file mode 100644
index 000000000..94b716e4b
--- /dev/null
+++ b/compress.go
@@ -0,0 +1,75 @@
+package sarama
+
+import (
+	"bytes"
+	"compress/gzip"
+	"fmt"
+	"sync"
+
+	"github.com/eapache/go-xerial-snappy"
+	"github.com/pierrec/lz4"
+)
+
+var (
+	lz4WriterPool = sync.Pool{
+		New: func() interface{} {
+			return lz4.NewWriter(nil)
+		},
+	}
+
+	gzipWriterPool = sync.Pool{
+		New: func() interface{} {
+			return gzip.NewWriter(nil)
+		},
+	}
+)
+
+func compress(cc CompressionCodec, level int, data []byte) ([]byte, error) {
+	switch cc {
+	case CompressionNone:
+		return data, nil
+	case CompressionGZIP:
+		var (
+			err    error
+			buf    bytes.Buffer
+			writer *gzip.Writer
+		)
+		if level != CompressionLevelDefault {
+			writer, err = gzip.NewWriterLevel(&buf, level)
+			if err != nil {
+				return nil, err
+			}
+		} else {
+			writer = gzipWriterPool.Get().(*gzip.Writer)
+			defer gzipWriterPool.Put(writer)
+			writer.Reset(&buf)
+		}
+		if _, err := writer.Write(data); err != nil {
+			return nil, err
+		}
+		if err := writer.Close(); err != nil {
+			return nil, err
+		}
+		return buf.Bytes(), nil
+	case CompressionSnappy:
+		return snappy.Encode(data), nil
+	case CompressionLZ4:
+		writer := lz4WriterPool.Get().(*lz4.Writer)
+		defer lz4WriterPool.Put(writer)
+
+		var buf bytes.Buffer
+		writer.Reset(&buf)
+
+		if _, err := writer.Write(data); err != nil {
+			return nil, err
+		}
+		if err := writer.Close(); err != nil {
+			return nil, err
+		}
+		return buf.Bytes(), nil
+	case CompressionZSTD:
+		return zstdCompressLevel(nil, data, level)
+	default:
+		return nil, PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", cc)}
+	}
+}
diff --git a/decompress.go b/decompress.go
new file mode 100644
index 000000000..eaccbfc26
--- /dev/null
+++ b/decompress.go
@@ -0,0 +1,63 @@
+package sarama
+
+import (
+	"bytes"
+	"compress/gzip"
+	"fmt"
+	"io/ioutil"
+	"sync"
+
+	"github.com/eapache/go-xerial-snappy"
+	"github.com/pierrec/lz4"
+)
+
+var (
+	lz4ReaderPool = sync.Pool{
+		New: func() interface{} {
+			return lz4.NewReader(nil)
+		},
+	}
+
+	gzipReaderPool sync.Pool
+)
+
+func decompress(cc CompressionCodec, data []byte) ([]byte, error) {
+	switch cc {
+	case CompressionNone:
+		return data, nil
+	case CompressionGZIP:
+		var (
+			err        error
+			reader     *gzip.Reader
+			readerIntf = gzipReaderPool.Get()
+		)
+		if readerIntf != nil {
+			reader = readerIntf.(*gzip.Reader)
+		} else {
+			reader, err = gzip.NewReader(bytes.NewReader(data))
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		defer gzipReaderPool.Put(reader)
+
+		if err := reader.Reset(bytes.NewReader(data)); err != nil {
+			return nil, err
+		}
+
+		return ioutil.ReadAll(reader)
+	case CompressionSnappy:
+		return snappy.Decode(data)
+	case CompressionLZ4:
+		reader := lz4ReaderPool.Get().(*lz4.Reader)
+		defer lz4ReaderPool.Put(reader)
+
+		reader.Reset(bytes.NewReader(data))
+		return ioutil.ReadAll(reader)
+	case CompressionZSTD:
+		return zstdDecompress(nil, data)
+	default:
+		return nil, PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", cc)}
+	}
+}
diff --git a/message.go b/message.go
index 44d5cc91b..51d3309c0 100644
--- a/message.go
+++ b/message.go
@@ -1,14 +1,8 @@
 package sarama
 
 import (
-	"bytes"
-	"compress/gzip"
 	"fmt"
-	"io/ioutil"
 	"time"
-
-	"github.com/eapache/go-xerial-snappy"
-	"github.com/pierrec/lz4"
 )
 
 // CompressionCodec represents the various compression codecs recognized by Kafka in messages.
@@ -77,53 +71,12 @@ func (m *Message) encode(pe packetEncoder) error {
 		payload = m.compressedCache
 		m.compressedCache = nil
 	} else if m.Value != nil {
-		switch m.Codec {
-		case CompressionNone:
-			payload = m.Value
-		case CompressionGZIP:
-			var buf bytes.Buffer
-			var writer *gzip.Writer
-			if m.CompressionLevel != CompressionLevelDefault {
-				writer, err = gzip.NewWriterLevel(&buf, m.CompressionLevel)
-				if err != nil {
-					return err
-				}
-			} else {
-				writer = gzip.NewWriter(&buf)
-			}
-			if _, err = writer.Write(m.Value); err != nil {
-				return err
-			}
-			if err = writer.Close(); err != nil {
-				return err
-			}
-			m.compressedCache = buf.Bytes()
-			payload = m.compressedCache
-		case CompressionSnappy:
-			tmp := snappy.Encode(m.Value)
-			m.compressedCache = tmp
-			payload = m.compressedCache
-		case CompressionLZ4:
-			var buf bytes.Buffer
-			writer := lz4.NewWriter(&buf)
-			if _, err = writer.Write(m.Value); err != nil {
-				return err
-			}
-			if err = writer.Close(); err != nil {
-				return err
-			}
-			m.compressedCache = buf.Bytes()
-			payload = m.compressedCache
-		case CompressionZSTD:
-			c, err := zstdCompressLevel(nil, m.Value, m.CompressionLevel)
-			if err != nil {
-				return err
-			}
-			m.compressedCache = c
-			payload = m.compressedCache
-		default:
-			return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", m.Codec)}
+
+		payload, err = compress(m.Codec, m.CompressionLevel, m.Value)
+		if err != nil {
+			return err
 		}
+		m.compressedCache = payload
 		// Keep in mind the compressed payload size for metric gathering
 		m.compressedSize = len(payload)
 	}
@@ -179,53 +132,18 @@ func (m *Message) decode(pd packetDecoder) (err error) {
 	switch m.Codec {
 	case CompressionNone:
 		// nothing to do
-	case CompressionGZIP:
+	default:
 		if m.Value == nil {
 			break
 		}
-		reader, err := gzip.NewReader(bytes.NewReader(m.Value))
+
+		m.Value, err = decompress(m.Codec, m.Value)
 		if err != nil {
 			return err
 		}
-		if m.Value, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
 		if err := m.decodeSet(); err != nil {
 			return err
 		}
-	case CompressionSnappy:
-		if m.Value == nil {
-			break
-		}
-		if m.Value, err = snappy.Decode(m.Value); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	case CompressionLZ4:
-		if m.Value == nil {
-			break
-		}
-		reader := lz4.NewReader(bytes.NewReader(m.Value))
-		if m.Value, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	case CompressionZSTD:
-		if m.Value == nil {
-			break
-		}
-		if m.Value, err = zstdDecompress(nil, m.Value); err != nil {
-			return err
-		}
-		if err := m.decodeSet(); err != nil {
-			return err
-		}
-	default:
-		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", m.Codec)}
 	}
 
 	return pd.pop()
diff --git a/record_batch.go b/record_batch.go
index 5444557f1..e0f183f7a 100644
--- a/record_batch.go
+++ b/record_batch.go
@@ -1,14 +1,8 @@
 package sarama
 
 import (
-	"bytes"
-	"compress/gzip"
 	"fmt"
-	"io/ioutil"
 	"time"
-
-	"github.com/eapache/go-xerial-snappy"
-	"github.com/pierrec/lz4"
 )
 
 const recordBatchOverhead = 49
@@ -174,31 +168,9 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 		return err
 	}
 
-	switch b.Codec {
-	case CompressionNone:
-	case CompressionGZIP:
-		reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
-		if err != nil {
-			return err
-		}
-		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-	case CompressionSnappy:
-		if recBuffer, err = snappy.Decode(recBuffer); err != nil {
-			return err
-		}
-	case CompressionLZ4:
-		reader := lz4.NewReader(bytes.NewReader(recBuffer))
-		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
-			return err
-		}
-	case CompressionZSTD:
-		if recBuffer, err = zstdDecompress(nil, recBuffer); err != nil {
-			return err
-		}
-	default:
-		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
+	recBuffer, err = decompress(b.Codec, recBuffer)
+	if err != nil {
+		return err
 	}
 
 	b.recordsLen = len(recBuffer)
@@ -219,50 +191,8 @@ func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
 	}
 	b.recordsLen = len(raw)
 
-	switch b.Codec {
-	case CompressionNone:
-		b.compressedRecords = raw
-	case CompressionGZIP:
-		var buf bytes.Buffer
-		var writer *gzip.Writer
-		if b.CompressionLevel != CompressionLevelDefault {
-			writer, err = gzip.NewWriterLevel(&buf, b.CompressionLevel)
-			if err != nil {
-				return err
-			}
-		} else {
-			writer = gzip.NewWriter(&buf)
-		}
-		if _, err := writer.Write(raw); err != nil {
-			return err
-		}
-		if err := writer.Close(); err != nil {
-			return err
-		}
-		b.compressedRecords = buf.Bytes()
-	case CompressionSnappy:
-		b.compressedRecords = snappy.Encode(raw)
-	case CompressionLZ4:
-		var buf bytes.Buffer
-		writer := lz4.NewWriter(&buf)
-		if _, err := writer.Write(raw); err != nil {
-			return err
-		}
-		if err := writer.Close(); err != nil {
-			return err
-		}
-		b.compressedRecords = buf.Bytes()
-	case CompressionZSTD:
-		c, err := zstdCompressLevel(nil, raw, b.CompressionLevel)
-		if err != nil {
-			return err
-		}
-		b.compressedRecords = c
-	default:
-		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
-	}
-
-	return nil
+	b.compressedRecords, err = compress(b.Codec, b.CompressionLevel, raw)
+	return err
 }
 
 func (b *RecordBatch) computeAttributes() int16 {
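Not part of the diff above: a minimal, hypothetical in-package round-trip sketch showing how the extracted `compress`/`decompress` helpers compose. Both functions are unexported, so this would have to live in package `sarama` (for example as a test); the test name and sample payload are illustrative, and `CompressionZSTD` is omitted here because it goes through the separate zstd wrappers.

```go
// Hypothetical sketch (not part of this change): round-trip the extracted
// helpers the same way the encode and decode paths now use them.
package sarama

import (
	"bytes"
	"testing"
)

func TestCompressDecompressRoundTrip(t *testing.T) {
	payload := []byte("sample record payload")

	codecs := []CompressionCodec{
		CompressionNone,
		CompressionGZIP,
		CompressionSnappy,
		CompressionLZ4,
	}

	for _, codec := range codecs {
		// compress is what Message.encode and RecordBatch.encodeRecords call now.
		compressed, err := compress(codec, CompressionLevelDefault, payload)
		if err != nil {
			t.Fatalf("compress(%d): %v", codec, err)
		}

		// decompress is the shared path behind Message.decode and RecordBatch.decode.
		decompressed, err := decompress(codec, compressed)
		if err != nil {
			t.Fatalf("decompress(%d): %v", codec, err)
		}

		if !bytes.Equal(decompressed, payload) {
			t.Fatalf("round trip mismatch for codec %d", codec)
		}
	}
}
```

The `sync.Pool`s in compress.go and decompress.go are what let these helpers be called once per message set or record batch without allocating a fresh gzip/lz4 writer or reader on every call; only the default-level gzip writer is pooled, since writers created with `gzip.NewWriterLevel` carry a caller-specific level.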