diff --git a/compress/benchmark_test.go b/compress/benchmark_test.go index 42c36ba8..28741844 100644 --- a/compress/benchmark_test.go +++ b/compress/benchmark_test.go @@ -4,14 +4,31 @@ import ( "testing" ) -// BenchmarkNew-24 55165 22851 ns/op 23884 B/op 2 allocs/op +/* +BenchmarkNew/zstd +BenchmarkNew/zstd-24 46011 21925 ns/op 28451 B/op 2 allocs/op +BenchmarkNew/zstd-cgo +BenchmarkNew/zstd-cgo-24 47054 25502 ns/op 960 B/op 2 allocs/op +*/ func BenchmarkNew(b *testing.B) { - b.ReportAllocs() - c, _ := New(CompressionAlgoZstd, CompressionLevelZstdBest) - defer func() { _ = c.Close() }() + b.Run("zstd", func(b *testing.B) { + b.ReportAllocs() + c, _ := New(CompressionAlgoZstd, CompressionLevelZstdBest) + defer func() { _ = c.Close() }() - for i := 0; i < b.N; i++ { - r, _ := c.Compress(loremIpsumDolor) - _, _ = c.Decompress(r) - } + for i := 0; i < b.N; i++ { + r, _ := c.Compress(loremIpsumDolor) + _, _ = c.Decompress(r) + } + }) + b.Run("zstd-cgo", func(b *testing.B) { + b.ReportAllocs() + c, _ := New(CompressionAlgoZstdCgo, CompressionLevelZstdCgoBest) + defer func() { _ = c.Close() }() + + for i := 0; i < b.N; i++ { + r, _ := c.Compress(loremIpsumDolor) + _, _ = c.Decompress(r) + } + }) } diff --git a/compress/compress.go b/compress/compress.go index a64a15aa..72d7ae08 100644 --- a/compress/compress.go +++ b/compress/compress.go @@ -3,6 +3,7 @@ package compress import ( "fmt" + zstdcgo "github.com/DataDog/zstd" "github.com/klauspost/compress/zstd" ) @@ -13,19 +14,19 @@ func (c CompressionAlgorithm) String() string { switch c { case CompressionAlgoZstd: return "zstd" + case CompressionAlgoZstdCgo: + return "zstd-cgo" default: return "" } } -func (c CompressionAlgorithm) isValid() bool { - return c == CompressionAlgoZstd -} - func NewCompressionAlgorithm(s string) (CompressionAlgorithm, error) { switch s { case "zstd": return CompressionAlgoZstd, nil + case "zstd-cgo": + return CompressionAlgoZstdCgo, nil default: return 0, fmt.Errorf("unknown compression algorithm: %s", s) } @@ -36,31 +37,19 @@ type CompressionLevel int func (c CompressionLevel) String() string { switch c { - case CompressionLevelZstdFastest: + case CompressionLevelZstdFastest, CompressionLevelZstdCgoFastest: return "fastest" - case CompressionLevelZstdDefault: + case CompressionLevelZstdDefault, CompressionLevelZstdCgoDefault: return "default" case CompressionLevelZstdBetter: return "better" - case CompressionLevelZstdBest: + case CompressionLevelZstdBest, CompressionLevelZstdCgoBest: return "best" default: return "" } } -func (c CompressionLevel) isValid() bool { - switch c { - case CompressionLevelZstdFastest, - CompressionLevelZstdDefault, - CompressionLevelZstdBetter, - CompressionLevelZstdBest: - return true - default: - return false - } -} - func NewCompressionLevel(s string) (CompressionLevel, error) { switch s { case "fastest": @@ -77,54 +66,89 @@ func NewCompressionLevel(s string) (CompressionLevel, error) { } var ( - CompressionAlgoZstd = CompressionAlgorithm(1) - + CompressionAlgoZstd = CompressionAlgorithm(1) CompressionLevelZstdFastest = CompressionLevel(zstd.SpeedFastest) CompressionLevelZstdDefault = CompressionLevel(zstd.SpeedDefault) // "pretty fast" compression CompressionLevelZstdBetter = CompressionLevel(zstd.SpeedBetterCompression) CompressionLevelZstdBest = CompressionLevel(zstd.SpeedBestCompression) + + CompressionAlgoZstdCgo = CompressionAlgorithm(2) + CompressionLevelZstdCgoFastest = CompressionLevel(zstdcgo.BestSpeed) // 1 + CompressionLevelZstdCgoDefault = CompressionLevel(zstdcgo.DefaultCompression) // 5 + CompressionLevelZstdCgoBest = CompressionLevel(zstdcgo.BestCompression) // 20 ) func New(algo CompressionAlgorithm, level CompressionLevel) (*Compressor, error) { - if !algo.isValid() { - return nil, fmt.Errorf("invalid compression algorithm: %d", algo) - } - if !level.isValid() { - return nil, fmt.Errorf("invalid compression level: %d", level) - } - - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevel(level))) - if err != nil { - return nil, fmt.Errorf("cannot create zstd encoder: %w", err) - } - - decoder, err := zstd.NewReader(nil) - if err != nil { - return nil, fmt.Errorf("cannot create zstd decoder: %w", err) + switch algo { + case CompressionAlgoZstd: + switch level { + case CompressionLevelZstdFastest, + CompressionLevelZstdDefault, + CompressionLevelZstdBetter, + CompressionLevelZstdBest: + default: + return nil, fmt.Errorf("invalid compression level for %q: %d", algo, level) + } + + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevel(level))) + if err != nil { + return nil, fmt.Errorf("cannot create zstd encoder: %w", err) + } + + decoder, err := zstd.NewReader(nil) + if err != nil { + return nil, fmt.Errorf("cannot create zstd decoder: %w", err) + } + + return &Compressor{compressorZstd: &compressorZstd{ + encoder: encoder, + decoder: decoder, + }}, nil + case CompressionAlgoZstdCgo: + var cgoLevel int + switch level { + case CompressionLevelZstdCgoFastest: + cgoLevel = zstdcgo.BestSpeed + case CompressionLevelZstdCgoDefault: + cgoLevel = zstdcgo.DefaultCompression + case CompressionLevelZstdCgoBest: + cgoLevel = zstdcgo.BestCompression + default: + return nil, fmt.Errorf("invalid compression level for %q: %d", algo, level) + } + + return &Compressor{ + compressorZstdCgo: &compressorZstdCgo{level: cgoLevel}, + }, nil + default: + return nil, fmt.Errorf("unknown compression algorithm: %d", algo) } - - return &Compressor{ - encoder: encoder, - decoder: decoder, - }, nil } type Compressor struct { - encoder *zstd.Encoder - decoder *zstd.Decoder + *compressorZstd + *compressorZstdCgo } func (c *Compressor) Compress(src []byte) ([]byte, error) { - return c.encoder.EncodeAll(src, nil), nil + if c.compressorZstdCgo != nil { + return c.compressorZstdCgo.Compress(src) + } + return c.compressorZstd.Compress(src) } func (c *Compressor) Decompress(src []byte) ([]byte, error) { - return c.decoder.DecodeAll(src, nil) + if c.compressorZstdCgo != nil { + return c.compressorZstdCgo.Decompress(src) + } + return c.compressorZstd.Decompress(src) } func (c *Compressor) Close() error { - c.decoder.Close() - return c.encoder.Close() + if c.compressorZstdCgo != nil { + return nil + } + return c.compressorZstd.Close() } // SerializeSettings serializes the compression settings. @@ -141,13 +165,27 @@ func DeserializeSettings(s string) (CompressionAlgorithm, CompressionLevel, erro } algo := CompressionAlgorithm(algoInt) - if !algo.isValid() { - return 0, 0, fmt.Errorf("invalid compression algorithm: %d", algoInt) - } - level := CompressionLevel(levelInt) - if !level.isValid() { - return 0, 0, fmt.Errorf("invalid compression level: %d", levelInt) + switch algo { + case CompressionAlgoZstd: + switch level { + case CompressionLevelZstdFastest, + CompressionLevelZstdDefault, + CompressionLevelZstdBetter, + CompressionLevelZstdBest: + default: + return 0, 0, fmt.Errorf("invalid compression level for %q: %d", algo, level) + } + case CompressionAlgoZstdCgo: + switch level { + case CompressionLevelZstdCgoFastest, + CompressionLevelZstdCgoDefault, + CompressionLevelZstdCgoBest: + default: + return 0, 0, fmt.Errorf("invalid compression level for %q: %d", algo, level) + } + default: + return 0, 0, fmt.Errorf("invalid compression algorithm: %d", algoInt) } return algo, level, nil diff --git a/compress/compress_test.go b/compress/compress_test.go index f18f7579..de3706f1 100644 --- a/compress/compress_test.go +++ b/compress/compress_test.go @@ -12,25 +12,36 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.`) func TestCompress(t *testing.T) { - compressionLevels := []CompressionLevel{ - CompressionLevelZstdFastest, - CompressionLevelZstdDefault, - CompressionLevelZstdBetter, - CompressionLevelZstdBest, + type testCase struct { + algo CompressionAlgorithm + level CompressionLevel } - for _, level := range compressionLevels { - c, err := New(CompressionAlgoZstd, level) - require.NoError(t, err) + testCases := []testCase{ + {CompressionAlgoZstd, CompressionLevelZstdFastest}, + {CompressionAlgoZstd, CompressionLevelZstdDefault}, + {CompressionAlgoZstd, CompressionLevelZstdBetter}, + {CompressionAlgoZstd, CompressionLevelZstdBest}, - t.Cleanup(func() { _ = c.Close() }) + {CompressionAlgoZstdCgo, CompressionLevelZstdCgoFastest}, + {CompressionAlgoZstdCgo, CompressionLevelZstdCgoDefault}, + {CompressionAlgoZstdCgo, CompressionLevelZstdCgoBest}, + } + + for _, tc := range testCases { + t.Run(tc.algo.String()+"-"+tc.level.String(), func(t *testing.T) { + c, err := New(tc.algo, tc.level) + require.NoError(t, err) - compressed, err := c.Compress(loremIpsumDolor) - require.NoError(t, err) - require.Less(t, len(compressed), len(loremIpsumDolor)) + t.Cleanup(func() { _ = c.Close() }) - decompressed, err := c.Decompress(compressed) - require.NoError(t, err) - require.Equal(t, string(loremIpsumDolor), string(decompressed)) + compressed, err := c.Compress(loremIpsumDolor) + require.NoError(t, err) + require.Less(t, len(compressed), len(loremIpsumDolor)) + + decompressed, err := c.Decompress(compressed) + require.NoError(t, err) + require.Equal(t, string(loremIpsumDolor), string(decompressed)) + }) } } @@ -53,10 +64,16 @@ func TestSerialization(t *testing.T) { } func TestDeserializationError(t *testing.T) { - // valid algo is 1 - // valid level is 1-4 + // valid algo is 1, 2 + // valid level is 1-4 for algo 1 + // valid level is 1, 5, 20 for algo 2 testCases := []string{ - "0:0", "0:1", "1:0", "2:1", "1:5", + "0:0", "0:1", "0:2", "0:3", "0:4", "0:5", "0:20", + + "1:0", "1:5", "1:20", + + "2:0", "2:2", "2:3", "2:4", "2:6", "2:7", "2:8", "2:9", "2:10", "2:11", + "2:12", "2:13", "2:14", "2:15", "2:16", "2:17", "2:18", "2:19", "2:21", } for _, tc := range testCases { _, _, err := DeserializeSettings(tc) diff --git a/compress/zstd.go b/compress/zstd.go new file mode 100644 index 00000000..ad0276ed --- /dev/null +++ b/compress/zstd.go @@ -0,0 +1,21 @@ +package compress + +import "github.com/klauspost/compress/zstd" + +type compressorZstd struct { + encoder *zstd.Encoder + decoder *zstd.Decoder +} + +func (c *compressorZstd) Compress(src []byte) ([]byte, error) { + return c.encoder.EncodeAll(src, nil), nil +} + +func (c *compressorZstd) Decompress(src []byte) ([]byte, error) { + return c.decoder.DecodeAll(src, nil) +} + +func (c *compressorZstd) Close() error { + c.decoder.Close() + return c.encoder.Close() +} diff --git a/compress/zstdcgo.go b/compress/zstdcgo.go new file mode 100644 index 00000000..b32fad5f --- /dev/null +++ b/compress/zstdcgo.go @@ -0,0 +1,17 @@ +package compress + +import ( + "github.com/DataDog/zstd" +) + +type compressorZstdCgo struct { + level int +} + +func (c *compressorZstdCgo) Compress(src []byte) ([]byte, error) { + return zstd.CompressLevel(nil, src, c.level) +} + +func (c *compressorZstdCgo) Decompress(src []byte) ([]byte, error) { + return zstd.Decompress(nil, src) +} diff --git a/go.mod b/go.mod index 7b7ec7bb..b8e2f5b9 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.2 require ( cloud.google.com/go/storage v1.43.0 github.com/Azure/azure-storage-blob-go v0.15.0 + github.com/DataDog/zstd v1.5.6 github.com/aws/aws-sdk-go v1.55.5 github.com/cenkalti/backoff/v4 v4.3.0 github.com/confluentinc/confluent-kafka-go/v2 v2.5.0 diff --git a/go.sum b/go.sum index cc0cd1d8..87591b09 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZ github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DataDog/zstd v1.5.6 h1:LbEglqepa/ipmmQJUDnSsfvA8e8IStVcGaFWDuxvGOY= +github.com/DataDog/zstd v1.5.6/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=