feat(internal): Add zstd to internal content_coding (#13423)
Zeyad Kenawi authored Jun 22, 2023
1 parent 14b0750 commit 577db89
Showing 10 changed files with 382 additions and 123 deletions.
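Only the diff of internal/content_coding.go is reproduced below. In short, the commit moves the decompression limit off each Decode call and onto the decoder itself via a new functional option, WithMaxDecompressionSize, and registers "zstd" alongside gzip, zlib and identity in both the encoder and decoder factories. As a rough, hypothetical illustration (not part of the commit), the sketch below shows how the reworked API could be exercised from inside the same package, written as if it were a test file; the test name, payload and 64 MB limit are invented, and the package name is assumed to follow the directory (internal).

package internal

import (
    "bytes"
    "testing"
)

// Hypothetical round trip through the new zstd content coding. The limit is
// now configured on the decoder via WithMaxDecompressionSize rather than
// being passed to every Decode call.
func TestZstdRoundTripSketch(t *testing.T) {
    enc, err := NewContentEncoder("zstd")
    if err != nil {
        t.Fatal(err)
    }

    // Invented sample payload.
    payload := []byte("cpu,host=example usage_idle=98.2")
    compressed, err := enc.Encode(payload)
    if err != nil {
        t.Fatal(err)
    }

    // Assumed limit of 64 MB, applied when the decoder is constructed.
    dec, err := NewContentDecoder("zstd", WithMaxDecompressionSize(64*1024*1024))
    if err != nil {
        t.Fatal(err)
    }

    decoded, err := dec.Decode(compressed)
    if err != nil {
        t.Fatal(err)
    }
    if !bytes.Equal(decoded, payload) {
        t.Fatalf("round trip mismatch: got %q", decoded)
    }
}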
189 changes: 138 additions & 51 deletions internal/content_coding.go
@@ -9,10 +9,25 @@ import (

"github.com/klauspost/compress/gzip"
"github.com/klauspost/compress/zlib"
"github.com/klauspost/compress/zstd"
"github.com/klauspost/pgzip"
)

- const DefaultMaxDecompressionSize = 500 * 1024 * 1024 //500MB
+ const defaultMaxDecompressionSize int64 = 500 * 1024 * 1024 //500MB

+ // DecodingOption provide methods to change the decoding from the standard
+ // configuration.
+ type DecodingOption func(*decoderConfig)
+
+ type decoderConfig struct {
+ maxDecompressionSize int64
+ }
+
+ func WithMaxDecompressionSize(maxDecompressionSize int64) DecodingOption {
+ return func(cfg *decoderConfig) {
+ cfg.maxDecompressionSize = maxDecompressionSize
+ }
+ }

type encoderConfig struct {
level int
@@ -92,10 +107,12 @@ func NewContentEncoder(encoding string, options ...EncodingOption) (ContentEncod
switch encoding {
case "gzip":
return NewGzipEncoder(options...)
case "zlib":
return NewZlibEncoder(options...)
case "identity", "":
return NewIdentityEncoder(options...)
case "zlib":
return NewZlibEncoder(options...)
case "zstd":
return NewZstdEncoder(options...)
default:
return nil, errors.New("invalid value for content_encoding")
}
@@ -111,32 +128,34 @@ func (a *AutoDecoder) SetEncoding(encoding string) {
a.encoding = encoding
}

- func (a *AutoDecoder) Decode(data []byte, maxDecompressionSize int64) ([]byte, error) {
+ func (a *AutoDecoder) Decode(data []byte) ([]byte, error) {
if a.encoding == "gzip" {
- return a.gzip.Decode(data, maxDecompressionSize)
+ return a.gzip.Decode(data)
}
- return a.identity.Decode(data, maxDecompressionSize)
+ return a.identity.Decode(data)
}

- func NewAutoContentDecoder() *AutoDecoder {
+ func NewAutoContentDecoder(options ...DecodingOption) *AutoDecoder {
var a AutoDecoder

- a.identity = NewIdentityDecoder()
- a.gzip = NewGzipDecoder()
+ a.identity = NewIdentityDecoder(options...)
+ a.gzip = NewGzipDecoder(options...)
return &a
}

// NewContentDecoder returns a ContentDecoder for the encoding type.
- func NewContentDecoder(encoding string) (ContentDecoder, error) {
+ func NewContentDecoder(encoding string, options ...DecodingOption) (ContentDecoder, error) {
switch encoding {
case "auto":
return NewAutoContentDecoder(options...), nil
case "gzip":
- return NewGzipDecoder(), nil
- case "zlib":
- return NewZlibDecoder(), nil
+ return NewGzipDecoder(options...), nil
case "identity", "":
- return NewIdentityDecoder(), nil
- case "auto":
- return NewAutoContentDecoder(), nil
+ return NewIdentityDecoder(options...), nil
+ case "zlib":
+ return NewZlibDecoder(options...), nil
+ case "zstd":
+ return NewZstdDecoder(options...)
default:
return nil, errors.New("invalid value for content_encoding")
}
@@ -165,7 +184,7 @@ func NewGzipEncoder(options ...EncodingOption) (*GzipEncoder, error) {
case gzip.NoCompression, gzip.DefaultCompression, gzip.BestSpeed, gzip.BestCompression:
// Do nothing as those are valid levels
default:
return nil, fmt.Errorf("invalid compression level, only 0, 1 and 9 are supported")
return nil, errors.New("invalid compression level, only 0, 1 and 9 are supported")
}

var buf bytes.Buffer
@@ -238,7 +257,7 @@ func NewZlibEncoder(options ...EncodingOption) (*ZlibEncoder, error) {
case zlib.NoCompression, zlib.DefaultCompression, zlib.BestSpeed, zlib.BestCompression:
// Do nothing as those are valid levels
default:
return nil, fmt.Errorf("invalid compression level, only 0, 1 and 9 are supported")
return nil, errors.New("invalid compression level, only 0, 1 and 9 are supported")
}

var buf bytes.Buffer
@@ -264,6 +283,41 @@ func (e *ZlibEncoder) Encode(data []byte) ([]byte, error) {
return e.buf.Bytes(), nil
}

+ type ZstdEncoder struct {
+ encoder *zstd.Encoder
+ }
+
+ func NewZstdEncoder(options ...EncodingOption) (*ZstdEncoder, error) {
+ cfg := encoderConfig{level: 3}
+ for _, o := range options {
+ o(&cfg)
+ }
+
+ // Map the levels
+ var level zstd.EncoderLevel
+ switch cfg.level {
+ case 1:
+ level = zstd.SpeedFastest
+ case 3:
+ level = zstd.SpeedDefault
+ case 7:
+ level = zstd.SpeedBetterCompression
+ case 11:
+ level = zstd.SpeedBestCompression
+ default:
+ return nil, errors.New("invalid compression level, only 1, 3, 7 and 11 are supported")
+ }
+
+ e, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
+ return &ZstdEncoder{
+ encoder: e,
+ }, err
+ }
+
+ func (e *ZstdEncoder) Encode(data []byte) ([]byte, error) {
+ return e.encoder.EncodeAll(data, make([]byte, 0, len(data))), nil
+ }

// IdentityEncoder is a null encoder that applies no transformation.
type IdentityEncoder struct{}

@@ -282,49 +336,56 @@ func (*IdentityEncoder) Encode(data []byte) ([]byte, error) {
// ContentDecoder removes a wrapper encoding from byte buffers.
type ContentDecoder interface {
SetEncoding(string)
- Decode([]byte, int64) ([]byte, error)
+ Decode([]byte) ([]byte, error)
}

// GzipDecoder decompresses buffers with gzip compression.
type GzipDecoder struct {
- preader *pgzip.Reader
- reader *gzip.Reader
- buf *bytes.Buffer
+ preader *pgzip.Reader
+ reader *gzip.Reader
+ buf *bytes.Buffer
+ maxDecompressionSize int64
}

- func NewGzipDecoder() *GzipDecoder {
+ func NewGzipDecoder(options ...DecodingOption) *GzipDecoder {
+ cfg := decoderConfig{maxDecompressionSize: defaultMaxDecompressionSize}
+ for _, o := range options {
+ o(&cfg)
+ }
+
return &GzipDecoder{
- preader: new(pgzip.Reader),
- reader: new(gzip.Reader),
- buf: new(bytes.Buffer),
+ preader: new(pgzip.Reader),
+ reader: new(gzip.Reader),
+ buf: new(bytes.Buffer),
+ maxDecompressionSize: cfg.maxDecompressionSize,
}
}

func (*GzipDecoder) SetEncoding(string) {}

- func (d *GzipDecoder) Decode(data []byte, maxDecompressionSize int64) ([]byte, error) {
+ func (d *GzipDecoder) Decode(data []byte) ([]byte, error) {
// Parallel Gzip is only faster for larger data chunks. According to the
// project's documentation the trade-off size is at about 1MB, so we switch
// to parallel Gzip if the data is larger and run the built-in version
// otherwise.
if len(data) > 1024*1024 {
- return d.decodeBig(data, maxDecompressionSize)
+ return d.decodeBig(data)
}
- return d.decodeSmall(data, maxDecompressionSize)
+ return d.decodeSmall(data)
}

- func (d *GzipDecoder) decodeSmall(data []byte, maxDecompressionSize int64) ([]byte, error) {
+ func (d *GzipDecoder) decodeSmall(data []byte) ([]byte, error) {
err := d.reader.Reset(bytes.NewBuffer(data))
if err != nil {
return nil, err
}
d.buf.Reset()

- n, err := io.CopyN(d.buf, d.reader, maxDecompressionSize)
+ n, err := io.CopyN(d.buf, d.reader, d.maxDecompressionSize)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
- } else if n == maxDecompressionSize {
- return nil, fmt.Errorf("size of decoded data exceeds allowed size %d", maxDecompressionSize)
+ } else if n == d.maxDecompressionSize {
+ return nil, fmt.Errorf("size of decoded data exceeds allowed size %d", d.maxDecompressionSize)
}

err = d.reader.Close()
@@ -334,18 +395,18 @@ func (d *GzipDecoder) decodeSmall(data []byte, maxDecompressionSize int64) ([]by
return d.buf.Bytes(), nil
}

- func (d *GzipDecoder) decodeBig(data []byte, maxDecompressionSize int64) ([]byte, error) {
+ func (d *GzipDecoder) decodeBig(data []byte) ([]byte, error) {
err := d.preader.Reset(bytes.NewBuffer(data))
if err != nil {
return nil, err
}
d.buf.Reset()

- n, err := io.CopyN(d.buf, d.preader, maxDecompressionSize)
+ n, err := io.CopyN(d.buf, d.preader, d.maxDecompressionSize)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
- } else if n == maxDecompressionSize {
- return nil, fmt.Errorf("size of decoded data exceeds allowed size %d", maxDecompressionSize)
+ } else if n == d.maxDecompressionSize {
+ return nil, fmt.Errorf("size of decoded data exceeds allowed size %d", d.maxDecompressionSize)
}

err = d.preader.Close()
@@ -356,18 +417,25 @@ func (d *GzipDecoder) decodeBig(data []byte, maxDecompressionSize int64) ([]byte
}

type ZlibDecoder struct {
- buf *bytes.Buffer
+ buf *bytes.Buffer
+ maxDecompressionSize int64
}

- func NewZlibDecoder() *ZlibDecoder {
+ func NewZlibDecoder(options ...DecodingOption) *ZlibDecoder {
+ cfg := decoderConfig{maxDecompressionSize: defaultMaxDecompressionSize}
+ for _, o := range options {
+ o(&cfg)
+ }
+
return &ZlibDecoder{
- buf: new(bytes.Buffer),
+ buf: new(bytes.Buffer),
+ maxDecompressionSize: cfg.maxDecompressionSize,
}
}

func (*ZlibDecoder) SetEncoding(string) {}

- func (d *ZlibDecoder) Decode(data []byte, maxDecompressionSize int64) ([]byte, error) {
+ func (d *ZlibDecoder) Decode(data []byte) ([]byte, error) {
d.buf.Reset()

b := bytes.NewBuffer(data)
@@ -376,11 +444,11 @@ func (d *ZlibDecoder) Decode(data []byte, maxDecompressionSize int64) ([]byte, e
return nil, err
}

- n, err := io.CopyN(d.buf, r, maxDecompressionSize)
+ n, err := io.CopyN(d.buf, r, d.maxDecompressionSize)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
- } else if n == maxDecompressionSize {
- return nil, fmt.Errorf("size of decoded data exceeds allowed size %d", maxDecompressionSize)
+ } else if n == d.maxDecompressionSize {
+ return nil, fmt.Errorf("size of decoded data exceeds allowed size %d", d.maxDecompressionSize)
}

err = r.Close()
@@ -390,19 +458,38 @@ func (d *ZlibDecoder) Decode(data []byte, maxDecompressionSize int64) ([]byte, e
return d.buf.Bytes(), nil
}

+ type ZstdDecoder struct {
+ decoder *zstd.Decoder
+ }
+
+ func NewZstdDecoder(options ...DecodingOption) (*ZstdDecoder, error) {
+ cfg := decoderConfig{maxDecompressionSize: defaultMaxDecompressionSize}
+ for _, o := range options {
+ o(&cfg)
+ }
+
+ d, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(0), zstd.WithDecoderMaxWindow(uint64(cfg.maxDecompressionSize)))
+ return &ZstdDecoder{
+ decoder: d,
+ }, err
+ }
+
+ func (*ZstdDecoder) SetEncoding(string) {}
+
+ func (d *ZstdDecoder) Decode(data []byte) ([]byte, error) {
+ return d.decoder.DecodeAll(data, nil)
+ }

// IdentityDecoder is a null decoder that returns the input.
- type IdentityDecoder struct{}
+ type IdentityDecoder struct {
+ }

- func NewIdentityDecoder() *IdentityDecoder {
+ func NewIdentityDecoder(_ ...DecodingOption) *IdentityDecoder {
return &IdentityDecoder{}
}

func (*IdentityDecoder) SetEncoding(string) {}

- func (*IdentityDecoder) Decode(data []byte, maxDecompressionSize int64) ([]byte, error) {
- size := int64(len(data))
- if size > maxDecompressionSize {
- return nil, fmt.Errorf("size of decoded data: %d exceeds allowed size %d", size, maxDecompressionSize)
- }
+ func (*IdentityDecoder) Decode(data []byte) ([]byte, error) {
return data, nil
}
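A note on levels: the new ZstdEncoder defaults to level 3 and accepts only the integer levels 1, 3, 7 and 11, mapping them onto the klauspost/compress presets SpeedFastest, SpeedDefault, SpeedBetterCompression and SpeedBestCompression. The standalone sketch below is not part of the commit; it replays that mapping directly against github.com/klauspost/compress/zstd (using the same EncodeAll pattern as ZstdEncoder.Encode) with an invented payload, to make the size/speed trade-off visible.

package main

import (
    "fmt"
    "strings"

    "github.com/klauspost/compress/zstd"
)

func main() {
    // Same level mapping as NewZstdEncoder in the diff above.
    levels := []struct {
        level  int
        preset zstd.EncoderLevel
    }{
        {1, zstd.SpeedFastest},
        {3, zstd.SpeedDefault},
        {7, zstd.SpeedBetterCompression},
        {11, zstd.SpeedBestCompression},
    }

    // Invented, repetitive payload so the size differences are visible.
    payload := []byte(strings.Repeat("cpu,host=example usage_idle=98.2\n", 1000))

    for _, l := range levels {
        enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(l.preset))
        if err != nil {
            panic(err)
        }
        // EncodeAll follows the same pattern as ZstdEncoder.Encode.
        compressed := enc.EncodeAll(payload, nil)
        enc.Close()
        fmt.Printf("level %2d: %d -> %d bytes\n", l.level, len(payload), len(compressed))
    }
}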