Skip to content

Commit

Permalink
Type annotations in aiokafka/codec.py (#984)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ods authored Feb 24, 2024
1 parent bb15ecf commit dd7dcb0
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SCALA_VERSION?=2.13
KAFKA_VERSION?=2.8.1
DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION)
DIFF_BRANCH=origin/master
FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py
FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py aiokafka/codec.py tests/test_codec.py

.PHONY: setup
setup:
Expand Down
44 changes: 24 additions & 20 deletions aiokafka/codec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import gzip
import io
import struct

from typing_extensions import Buffer

_XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1)
_XERIAL_V1_FORMAT = "bccccccBii"
ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024
Expand All @@ -12,23 +16,23 @@
cramjam = None


def has_gzip():
def has_gzip() -> bool:
return True


def has_snappy():
def has_snappy() -> bool:
return cramjam is not None


def has_zstd():
def has_zstd() -> bool:
return cramjam is not None


def has_lz4():
def has_lz4() -> bool:
return cramjam is not None


def gzip_encode(payload, compresslevel=None):
def gzip_encode(payload: Buffer, compresslevel: int | None = None) -> bytes:
if not compresslevel:
compresslevel = 9

Expand All @@ -45,7 +49,7 @@ def gzip_encode(payload, compresslevel=None):
return buf.getvalue()


def gzip_decode(payload):
def gzip_decode(payload: Buffer) -> bytes:
buf = io.BytesIO(payload)

# Gzip context manager introduced in python 2.7
Expand All @@ -57,7 +61,9 @@ def gzip_decode(payload):
gzipper.close()


def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
def snappy_encode(
payload: Buffer, xerial_compatible: bool = True, xerial_blocksize: int = 32 * 1024
) -> bytes:
"""Encodes the given data with snappy compression.
If xerial_compatible is set then the stream is encoded in a fashion
Expand Down Expand Up @@ -93,12 +99,9 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER):
out.write(struct.pack("!" + fmt, dat))

# Chunk through buffers to avoid creating intermediate slice copies
def chunker(payload, i, size):
return memoryview(payload)[i : size + i]

payload = memoryview(payload)
for chunk in (
chunker(payload, i, xerial_blocksize)
payload[i : i + xerial_blocksize]
for i in range(0, len(payload), xerial_blocksize)
):
block = cramjam.snappy.compress_raw(chunk)
Expand All @@ -109,7 +112,7 @@ def chunker(payload, i, size):
return out.getvalue()


def _detect_xerial_stream(payload):
def _detect_xerial_stream(payload: Buffer) -> bool:
"""Detects if the data given might have been encoded with the blocking mode
of the xerial snappy library.
Expand All @@ -131,20 +134,21 @@ def _detect_xerial_stream(payload):
1.
"""

payload = memoryview(payload)
if len(payload) > 16:
header = struct.unpack("!" + _XERIAL_V1_FORMAT, memoryview(payload)[:16])
header = struct.unpack("!" + _XERIAL_V1_FORMAT, payload[:16])
return header == _XERIAL_V1_HEADER
return False


def snappy_decode(payload):
def snappy_decode(payload: Buffer) -> bytes:
if not has_snappy():
raise NotImplementedError("Snappy codec is not available")

if _detect_xerial_stream(payload):
# TODO ? Should become a fileobj ?
out = io.BytesIO()
byt = payload[16:]
byt = memoryview(payload)[16:]
length = len(byt)
cursor = 0

Expand All @@ -162,7 +166,7 @@ def snappy_decode(payload):
return bytes(cramjam.snappy.decompress_raw(payload))


def lz4_encode(payload, level=9):
def lz4_encode(payload: Buffer, level: int = 9) -> bytes:
# level=9 is used by default by broker itself
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-390%3A+Support+Compression+Level
if not has_lz4():
Expand All @@ -177,14 +181,14 @@ def lz4_encode(payload, level=9):
return bytes(compressor.finish())


def lz4_decode(payload):
def lz4_decode(payload: Buffer) -> bytes:
if not has_lz4():
raise NotImplementedError("LZ4 codec is not available")

return bytes(cramjam.lz4.decompress(payload))


def zstd_encode(payload, level=None):
def zstd_encode(payload: Buffer, level: int | None = None) -> bytes:
if not has_zstd():
raise NotImplementedError("Zstd codec is not available")

Expand All @@ -196,7 +200,7 @@ def zstd_encode(payload, level=None):
return bytes(cramjam.zstd.compress(payload, level=level))


def zstd_decode(payload):
def zstd_decode(payload: Buffer) -> bytes:
if not has_zstd():
raise NotImplementedError("Zstd codec is not available")

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dynamic = ["version"]
dependencies = [
"async-timeout",
"packaging",
"typing_extensions >=4.6.0",
]

[project.optional-dependencies]
Expand Down
16 changes: 8 additions & 8 deletions tests/test_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@
from ._testutil import random_string


def test_gzip():
def test_gzip() -> None:
for i in range(1000):
b1 = random_string(100)
b2 = gzip_decode(gzip_encode(b1))
assert b1 == b2


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy():
def test_snappy() -> None:
for i in range(1000):
b1 = random_string(100)
b2 = snappy_decode(snappy_encode(b1))
assert b1 == b2


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy_detect_xerial():
def test_snappy_detect_xerial() -> None:
_detect_xerial_stream = codecs._detect_xerial_stream

header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes"
Expand All @@ -55,7 +55,7 @@ def test_snappy_detect_xerial():


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy_decode_xerial():
def test_snappy_decode_xerial() -> None:
header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01"
random_snappy = snappy_encode(b"SNAPPY" * 50, xerial_compatible=False)
block_len = len(random_snappy)
Expand All @@ -73,7 +73,7 @@ def test_snappy_decode_xerial():


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy_encode_xerial():
def test_snappy_encode_xerial() -> None:
to_ensure = (
b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01"
b"\x00\x00\x00\x18\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00"
Expand All @@ -88,7 +88,7 @@ def test_snappy_encode_xerial():


@pytest.mark.skipif(not has_lz4(), reason="LZ4 not available")
def test_lz4():
def test_lz4() -> None:
for i in range(1000):
b1 = random_string(100)
b2 = lz4_decode(lz4_encode(b1))
Expand All @@ -97,7 +97,7 @@ def test_lz4():


@pytest.mark.skipif(not has_lz4(), reason="LZ4 not available")
def test_lz4_incremental():
def test_lz4_incremental() -> None:
for i in range(1000):
# lz4 max single block size is 4MB
# make sure we test with multiple-blocks
Expand All @@ -108,7 +108,7 @@ def test_lz4_incremental():


@pytest.mark.skipif(not has_zstd(), reason="Zstd not available")
def test_zstd():
def test_zstd() -> None:
for _ in range(1000):
b1 = random_string(100)
b2 = zstd_decode(zstd_encode(b1))
Expand Down

0 comments on commit dd7dcb0

Please sign in to comment.