Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for zstd compression #2021

Merged
merged 9 commits into from
Sep 7, 2020
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ addons:
apt:
packages:
- libsnappy-dev
- libzstd-dev
- openjdk-8-jdk

cache:
Expand Down
9 changes: 5 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,12 @@ multiprocessing is recommended.
Compression
***********

kafka-python supports gzip compression/decompression natively. To produce or
consume lz4 compressed messages, you should install python-lz4 (pip install lz4).
To enable snappy, install python-snappy (also requires snappy library).
See `Installation <install.html#optional-snappy-install>`_ for more information.
kafka-python supports multiple compression types:

- gzip : supported natively
- lz4 : requires `python-lz4 <https://pypi.org/project/lz4/>`_ installed
- snappy : requires the `python-snappy <https://pypi.org/project/python-snappy/>`_ package (which requires the snappy C library)
- zstd : requires the `python-zstandard <https://github.com/indygreg/python-zstandard>`_ package installed

Protocol
********
Expand Down
25 changes: 25 additions & 0 deletions kafka/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@

_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1)
_XERIAL_V1_FORMAT = 'bccccccBii'
ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024

try:
import snappy
except ImportError:
snappy = None

try:
import zstandard as zstd
except ImportError:
zstd = None

try:
import lz4.frame as lz4

Expand Down Expand Up @@ -58,6 +64,10 @@ def has_snappy():
return snappy is not None


def has_zstd():
return zstd is not None


def has_lz4():
if lz4 is not None:
return True
Expand Down Expand Up @@ -299,3 +309,18 @@ def lz4_decode_old_kafka(payload):
payload[header_size:]
])
return lz4_decode(munged_payload)


def zstd_encode(payload):
if not zstd:
raise NotImplementedError("Zstd codec is not available")
return zstd.ZstdCompressor().compress(payload)


def zstd_decode(payload):
if not zstd:
raise NotImplementedError("Zstd codec is not available")
try:
return zstd.ZstdDecompressor().decompress(payload)
except zstd.ZstdError:
return zstd.ZstdDecompressor().decompress(payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE)
gabriel-tincu marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 6 additions & 2 deletions kafka/producer/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import kafka.errors as Errors
from kafka.client_async import KafkaClient, selectors
from kafka.codec import has_gzip, has_snappy, has_lz4
from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd
from kafka.metrics import MetricConfig, Metrics
from kafka.partitioner.default import DefaultPartitioner
from kafka.producer.future import FutureRecordMetadata, FutureProduceResult
Expand Down Expand Up @@ -119,7 +119,7 @@ class KafkaProducer(object):
available guarantee.
If unset, defaults to acks=1.
compression_type (str): The compression type for all data generated by
the producer. Valid values are 'gzip', 'snappy', 'lz4', or None.
the producer. Valid values are 'gzip', 'snappy', 'lz4', 'zstd' or None.
Compression is of full batches of data, so the efficacy of batching
will also impact the compression ratio (more batching means better
compression). Default: None.
Expand Down Expand Up @@ -339,6 +339,7 @@ class KafkaProducer(object):
'gzip': (has_gzip, LegacyRecordBatchBuilder.CODEC_GZIP),
'snappy': (has_snappy, LegacyRecordBatchBuilder.CODEC_SNAPPY),
'lz4': (has_lz4, LegacyRecordBatchBuilder.CODEC_LZ4),
'zstd': (has_zstd, DefaultRecordBatchBuilder.CODEC_ZSTD),
None: (lambda: True, LegacyRecordBatchBuilder.CODEC_NONE),
}

Expand Down Expand Up @@ -388,6 +389,9 @@ def __init__(self, **configs):
if self.config['compression_type'] == 'lz4':
assert self.config['api_version'] >= (0, 8, 2), 'LZ4 Requires >= Kafka 0.8.2 Brokers'

if self.config['compression_type'] == 'zstd':
assert self.config['api_version'] >= (2, 1, 0), 'Zstd Requires >= Kafka 2.1.0 Brokers'

# Check compression_type for library support
ct = self.config['compression_type']
if ct not in self._COMPRESSORS:
Expand Down
10 changes: 7 additions & 3 deletions kafka/protocol/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import io
import time

from kafka.codec import (has_gzip, has_snappy, has_lz4,
gzip_decode, snappy_decode,
from kafka.codec import (has_gzip, has_snappy, has_lz4, has_zstd,
gzip_decode, snappy_decode, zstd_decode,
lz4_decode, lz4_decode_old_kafka)
from kafka.protocol.frame import KafkaBytes
from kafka.protocol.struct import Struct
Expand Down Expand Up @@ -35,6 +35,7 @@ class Message(Struct):
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
CODEC_LZ4 = 0x03
CODEC_ZSTD = 0x04
TIMESTAMP_TYPE_MASK = 0x08
HEADER_SIZE = 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2)

Expand Down Expand Up @@ -119,7 +120,7 @@ def is_compressed(self):

def decompress(self):
codec = self.attributes & self.CODEC_MASK
assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY, self.CODEC_LZ4)
assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY, self.CODEC_LZ4, self.CODEC_ZSTD)
if codec == self.CODEC_GZIP:
assert has_gzip(), 'Gzip decompression unsupported'
raw_bytes = gzip_decode(self.value)
Expand All @@ -132,6 +133,9 @@ def decompress(self):
raw_bytes = lz4_decode_old_kafka(self.value)
else:
raw_bytes = lz4_decode(self.value)
elif codec == self.CODEC_ZSTD:
assert has_zstd(), "ZSTD decompression unsupported"
raw_bytes = zstd_decode(self.value)
else:
raise Exception('This should be impossible')

Expand Down
11 changes: 9 additions & 2 deletions kafka/record/default_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
)
from kafka.errors import CorruptRecordException, UnsupportedCodecError
from kafka.codec import (
gzip_encode, snappy_encode, lz4_encode,
gzip_decode, snappy_decode, lz4_decode
gzip_encode, snappy_encode, lz4_encode, zstd_encode,
gzip_decode, snappy_decode, lz4_decode, zstd_decode
)
import kafka.codec as codecs

Expand Down Expand Up @@ -97,6 +97,7 @@ class DefaultRecordBase(object):
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
CODEC_LZ4 = 0x03
CODEC_ZSTD = 0x04
TIMESTAMP_TYPE_MASK = 0x08
TRANSACTIONAL_MASK = 0x10
CONTROL_MASK = 0x20
Expand All @@ -111,6 +112,8 @@ def _assert_has_codec(self, compression_type):
checker, name = codecs.has_snappy, "snappy"
elif compression_type == self.CODEC_LZ4:
checker, name = codecs.has_lz4, "lz4"
elif compression_type == self.CODEC_ZSTD:
checker, name = codecs.has_zstd, "zstd"
if not checker():
raise UnsupportedCodecError(
"Libraries for {} compression codec not found".format(name))
Expand Down Expand Up @@ -185,6 +188,8 @@ def _maybe_uncompress(self):
uncompressed = snappy_decode(data.tobytes())
if compression_type == self.CODEC_LZ4:
uncompressed = lz4_decode(data.tobytes())
if compression_type == self.CODEC_ZSTD:
uncompressed = zstd_decode(data.tobytes())
self._buffer = bytearray(uncompressed)
self._pos = 0
self._decompressed = True
Expand Down Expand Up @@ -517,6 +522,8 @@ def _maybe_compress(self):
compressed = snappy_encode(data)
elif self._compression_type == self.CODEC_LZ4:
compressed = lz4_encode(data)
elif self._compression_type == self.CODEC_ZSTD:
compressed = zstd_encode(data)
compressed_size = len(compressed)
if len(data) <= compressed_size:
# We did not get any benefit from compression, lets send
Expand Down
2 changes: 1 addition & 1 deletion kafka/record/memory_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class MemoryRecordsBuilder(object):

def __init__(self, magic, compression_type, batch_size):
assert magic in [0, 1, 2], "Not supported magic"
assert compression_type in [0, 1, 2, 3], "Not valid compression type"
assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type"
if magic >= 2:
self._builder = DefaultRecordBatchBuilder(
magic=magic, compression_type=compression_type,
Expand Down
11 changes: 10 additions & 1 deletion test/test_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from kafka.vendor.six.moves import range

from kafka.codec import (
has_snappy, has_lz4,
has_snappy, has_lz4, has_zstd,
gzip_encode, gzip_decode,
snappy_encode, snappy_decode,
lz4_encode, lz4_decode,
lz4_encode_old_kafka, lz4_decode_old_kafka,
zstd_encode, zstd_decode,
)

from test.testutil import random_string
Expand Down Expand Up @@ -113,3 +114,11 @@ def test_lz4_incremental():
b2 = lz4_decode(lz4_encode(b1))
assert len(b1) == len(b2)
assert b1 == b2


@pytest.mark.skipif(not has_zstd(), reason="Zstd not available")
def test_zstd():
for _ in range(1000):
b1 = random_string(100).encode('utf-8')
b2 = zstd_decode(zstd_encode(b1))
assert b1 == b2
20 changes: 10 additions & 10 deletions test/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ def test_buffer_pool():


@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4'])
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd'])
def test_end_to_end(kafka_broker, compression):

if compression == 'lz4':
# LZ4 requires 0.8.2
if env_kafka_version() < (0, 8, 2):
return
# python-lz4 crashes on older versions of pypy
pytest.skip('LZ4 requires 0.8.2')
gabriel-tincu marked this conversation as resolved.
Show resolved Hide resolved
elif platform.python_implementation() == 'PyPy':
return
pytest.skip('python-lz4 crashes on older versions of pypy')

if compression == 'zstd' and env_kafka_version() < (2, 1, 0):
pytest.skip('zstd requires kafka 2.1.0 or newer')

connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)])
producer = KafkaProducer(bootstrap_servers=connect_str,
Expand Down Expand Up @@ -81,8 +81,10 @@ def test_kafka_producer_gc_cleanup():


@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4'])
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd'])
def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
if compression == 'zstd' and env_kafka_version() < (2, 1, 0):
pytest.skip('zstd requires 2.1.0 or more')
connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)])
producer = KafkaProducer(bootstrap_servers=connect_str,
retries=5,
Expand Down Expand Up @@ -124,10 +126,8 @@ def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
if headers:
assert record.serialized_header_size == 22

# generated timestamp case is skipped for broker 0.9 and below
if magic == 0:
return

pytest.skip('generated timestamp case is skipped for broker 0.9 and below')
send_time = time.time() * 1000
future = producer.send(
topic,
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ deps =
pytest-mock
mock
python-snappy
zstandard
lz4
xxhash
crc32c
Expand Down