diff --git a/Makefile b/Makefile index ca394ed2..9882c87a 100644 --- a/Makefile +++ b/Makefile @@ -12,8 +12,10 @@ FORMATTED_AREAS=\ aiokafka/structs.py \ aiokafka/util.py \ aiokafka/protocol/ \ + aiokafka/record/ \ tests/test_codec.py \ - tests/test_helpers.py + tests/test_helpers.py \ + tests/record/ .PHONY: setup setup: diff --git a/aiokafka/record/_crc32c.py b/aiokafka/record/_crc32c.py index 0c1209af..22158249 100644 --- a/aiokafka/record/_crc32c.py +++ b/aiokafka/record/_crc32c.py @@ -23,6 +23,7 @@ """ import array +from typing import Iterable # fmt: off CRC_TABLE = ( @@ -97,7 +98,7 @@ _MASK = 0xFFFFFFFF -def crc_update(crc, data): +def crc_update(crc: int, data: Iterable[int]) -> int: """Update CRC-32C checksum with data. Args: crc: 32-bit checksum to update as long. @@ -116,7 +117,7 @@ def crc_update(crc, data): return crc ^ _MASK -def crc_finalize(crc): +def crc_finalize(crc: int) -> int: """Finalize CRC-32C checksum. This function should be called as last step of crc calculation. Args: @@ -127,7 +128,7 @@ def crc_finalize(crc): return crc & _MASK -def crc(data): +def crc(data: Iterable[int]) -> int: """Compute CRC-32C checksum of the data. Args: data: byte array, string or iterable over bytes. diff --git a/aiokafka/record/_crecords/cutil.pyi b/aiokafka/record/_crecords/cutil.pyi new file mode 100644 index 00000000..d8615b05 --- /dev/null +++ b/aiokafka/record/_crecords/cutil.pyi @@ -0,0 +1,8 @@ +from typing import Callable + +from typing_extensions import Buffer + +def crc32c_cython(data: Buffer) -> int: ... +def decode_varint_cython(buffer: bytearray, pos: int) -> tuple[int, int]: ... +def encode_varint_cython(value: int, write: Callable[[int], None]) -> int: ... +def size_of_varint_cython(value: int) -> int: ... diff --git a/aiokafka/record/_crecords/default_records.pxd b/aiokafka/record/_crecords/default_records.pxd index 7da6fc98..d16007ab 100644 --- a/aiokafka/record/_crecords/default_records.pxd +++ b/aiokafka/record/_crecords/default_records.pxd @@ -43,8 +43,8 @@ cdef class DefaultRecord: cdef: readonly int64_t offset - int64_t timestamp - char timestamp_type + readonly int64_t timestamp + readonly char timestamp_type readonly object key readonly object value readonly object headers diff --git a/aiokafka/record/_crecords/default_records.pyi b/aiokafka/record/_crecords/default_records.pyi new file mode 100644 index 00000000..0910f9de --- /dev/null +++ b/aiokafka/record/_crecords/default_records.pyi @@ -0,0 +1,148 @@ +from typing import ClassVar, final + +from typing_extensions import Literal, Self + +from aiokafka.record._protocols import ( + DefaultRecordBatchBuilderProtocol, + DefaultRecordBatchProtocol, + DefaultRecordMetadataProtocol, + DefaultRecordProtocol, +) +from aiokafka.record._types import ( + CodecGzipT, + CodecLz4T, + CodecMaskT, + CodecNoneT, + CodecSnappyT, + CodecZstdT, + DefaultCompressionTypeT, +) + +@final +class DefaultRecord(DefaultRecordProtocol): + def __init__( + self, + offset: int, + timestamp: int, + timestamp_type: int, + key: bytes | None, + value: bytes | None, + headers: list[tuple[str, bytes | None]], + ) -> None: ... + @property + def offset(self) -> int: ... + @property + def timestamp(self) -> int: ... + @property + def timestamp_type(self) -> int: ... + @property + def key(self) -> bytes | None: ... + @property + def value(self) -> bytes | None: ... + @property + def headers(self) -> list[tuple[str, bytes | None]]: ... + @property + def checksum(self) -> None: ... 
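The pure-Python CRC-32C helpers annotated above are designed to compose: crc_update threads a running checksum through successive chunks, crc_finalize masks the result, and crc does both in one call. A minimal sketch of that usage, assuming the module's CRC_INIT starting value (defined in the file but not visible in this hunk):

from aiokafka.record._crc32c import CRC_INIT, crc, crc_finalize, crc_update

# One-shot: 0xE3069283 is the standard CRC-32C check value for b"123456789".
assert crc(b"123456789") == 0xE3069283

# Incremental: feeding the same bytes in two chunks yields the same checksum.
state = crc_update(CRC_INIT, b"12345")
state = crc_update(state, b"6789")
assert crc_finalize(state) == 0xE3069283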
+ +@final +class DefaultRecordBatch(DefaultRecordBatchProtocol): + CODEC_MASK: ClassVar[CodecMaskT] + CODEC_NONE: ClassVar[CodecNoneT] + CODEC_GZIP: ClassVar[CodecGzipT] + CODEC_SNAPPY: ClassVar[CodecSnappyT] + CODEC_LZ4: ClassVar[CodecLz4T] + CODEC_ZSTD: ClassVar[CodecZstdT] + + def __init__(self, buffer: bytes): ... + @property + def compression_type(self) -> int: ... + @property + def is_transactional(self) -> bool: ... + @property + def is_control_batch(self) -> bool: ... + @property + def next_offset(self) -> int: ... + def __iter__(self) -> Self: ... + def __next__(self) -> DefaultRecord: ... + def validate_crc(self) -> bool: ... + @property + def base_offset(self) -> int: ... + @property + def magic(self) -> int: ... + @property + def crc(self) -> int: ... + @property + def attributes(self) -> int: ... + @property + def last_offset_delta(self) -> int: ... + @property + def first_timestamp(self) -> int: ... + @property + def max_timestamp(self) -> int: ... + @property + def producer_id(self) -> int: ... + @property + def producer_epoch(self) -> int: ... + @property + def base_sequence(self) -> int: ... + @property + def timestamp_type(self) -> Literal[0, 1]: ... + +@final +class DefaultRecordBatchBuilder(DefaultRecordBatchBuilderProtocol): + producer_id: int + producer_epoch: int + base_sequence: int + def __init__( + self, + magic: int, + compression_type: DefaultCompressionTypeT, + is_transactional: int, + producer_id: int, + producer_epoch: int, + base_sequence: int, + batch_size: int, + ) -> None: ... + def set_producer_state( + self, producer_id: int, producer_epoch: int, base_sequence: int + ) -> None: ... + def append( + self, + offset: int, + timestamp: int | None, + key: bytes | None, + value: bytes | None, + headers: list[tuple[str, bytes | None]], + ) -> DefaultRecordMetadata: ... + def build(self) -> bytearray: ... + def size(self) -> int: ... + def size_in_bytes( + self, + offset: int, + timestamp: int, + key: bytes | None, + value: bytes | None, + headers: list[tuple[str, bytes | None]], + ) -> int: ... + @classmethod + def size_of( + cls, + key: bytes | None, + value: bytes | None, + headers: list[tuple[str, bytes | None]], + ) -> int: ... + @classmethod + def estimate_size_in_bytes( + cls, + key: bytes | None, + value: bytes | None, + headers: list[tuple[str, bytes | None]], + ) -> int: ... + +@final +class DefaultRecordMetadata(DefaultRecordMetadataProtocol): + offset: int + size: int + timestamp: int + crc: None + def __init__(self, offset: int, size: int, timestamp: int): ... 
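The stub above describes the v2 builder/reader pair end to end. A rough round-trip sketch of that API (values are illustrative; the same calls appear in tests/record/test_default_records.py):

from aiokafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder

builder = DefaultRecordBatchBuilder(
    magic=2, compression_type=0, is_transactional=0,
    producer_id=-1, producer_epoch=-1, base_sequence=-1,
    batch_size=1024 * 1024,
)
meta = builder.append(0, timestamp=9999999, key=b"key", value=b"value", headers=[])
assert meta is not None  # append() returns None once the batch is full
buffer = builder.build()

# Feed the built bytes back through the reader and iterate the records.
batch = DefaultRecordBatch(bytes(buffer))
for record in batch:
    assert record.offset == 0
    assert record.key == b"key" and record.value == b"value"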
diff --git a/aiokafka/record/_crecords/default_records.pyx b/aiokafka/record/_crecords/default_records.pyx index ba49411d..40a945e2 100644 --- a/aiokafka/record/_crecords/default_records.pyx +++ b/aiokafka/record/_crecords/default_records.pyx @@ -448,20 +448,6 @@ cdef class DefaultRecord: record.headers = headers return record - @property - def timestamp(self): - if self.timestamp != -1: - return self.timestamp - else: - return None - - @property - def timestamp_type(self): - if self.timestamp != -1: - return self.timestamp_type - else: - return None - def __repr__(self): return ( "DefaultRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," diff --git a/aiokafka/record/_crecords/legacy_records.pyi b/aiokafka/record/_crecords/legacy_records.pyi new file mode 100644 index 00000000..b160d99b --- /dev/null +++ b/aiokafka/record/_crecords/legacy_records.pyi @@ -0,0 +1,95 @@ +from typing import Any, ClassVar, Generator, final + +from typing_extensions import Buffer, Literal, Never + +from aiokafka.record._protocols import ( + LegacyRecordBatchBuilderProtocol, + LegacyRecordBatchProtocol, + LegacyRecordMetadataProtocol, + LegacyRecordProtocol, +) +from aiokafka.record._types import ( + CodecGzipT, + CodecLz4T, + CodecMaskT, + CodecSnappyT, + LegacyCompressionTypeT, +) + +@final +class LegacyRecord(LegacyRecordProtocol): + def __init__( + self, + offset: int, + timestamp: int, + attributes: int, + key: bytes | None, + value: bytes | None, + crc: int, + ) -> None: ... + @property + def offset(self) -> int: ... + @property + def key(self) -> bytes | None: ... + @property + def value(self) -> bytes | None: ... + @property + def headers(self) -> list[Never]: ... + @property + def timestamp(self) -> int | None: ... + @property + def timestamp_type(self) -> Literal[0, 1] | None: ... + @property + def checksum(self) -> int: ... + +@final +class LegacyRecordBatch(LegacyRecordBatchProtocol): + RECORD_OVERHEAD_V0: ClassVar[int] + RECORD_OVERHEAD_V1: ClassVar[int] + CODEC_MASK: ClassVar[CodecMaskT] + CODEC_GZIP: ClassVar[CodecGzipT] + CODEC_SNAPPY: ClassVar[CodecSnappyT] + CODEC_LZ4: ClassVar[CodecLz4T] + + is_control_batch: bool + is_transactional: bool + producer_id: int | None + def __init__(self, buffer: Buffer, magic: int) -> None: ... + @property + def next_offset(self) -> int: ... + def validate_crc(self) -> bool: ... + def __iter__(self) -> Generator[LegacyRecord, None, None]: ... + +@final +class LegacyRecordBatchBuilder(LegacyRecordBatchBuilderProtocol): + CODEC_MASK: ClassVar[CodecMaskT] + CODEC_GZIP: ClassVar[CodecGzipT] + CODEC_SNAPPY: ClassVar[CodecSnappyT] + CODEC_LZ4: ClassVar[CodecLz4T] + + def __init__( + self, magic: int, compression_type: LegacyCompressionTypeT, batch_size: int + ) -> None: ... + def append( + self, + offset: int, + timestamp: int | None, + key: bytes | None, + value: bytes | None, + headers: Any = None, + ) -> LegacyRecordMetadata: ... + def size(self) -> int: ... + def size_in_bytes( + self, offset: Any, timestamp: Any, key: Buffer | None, value: Buffer | None + ) -> int: ... + @staticmethod + def record_overhead(magic: int) -> int: ... + def build(self) -> bytearray: ... + +@final +class LegacyRecordMetadata(LegacyRecordMetadataProtocol): + offset: int + crc: int + size: int + timestamp: int + def __init__(self, offset: int, crc: int, size: int, timestamp: int) -> None: ... 
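The legacy (v0/v1) stubs follow the same shape, except that headers are unsupported and the reader needs the magic byte alongside the buffer. An illustrative sketch:

from aiokafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder

builder = LegacyRecordBatchBuilder(magic=1, compression_type=0, batch_size=1024 * 1024)
meta = builder.append(0, timestamp=9999999, key=b"key", value=b"value")
assert meta is not None
buffer = builder.build()

batch = LegacyRecordBatch(bytes(buffer), 1)  # the reader takes magic explicitly, as MemoryRecords does
records = list(batch)
assert records[0].offset == 0 and records[0].value == b"value"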
diff --git a/aiokafka/record/_crecords/memory_records.pyi b/aiokafka/record/_crecords/memory_records.pyi new file mode 100644 index 00000000..d8dfa409 --- /dev/null +++ b/aiokafka/record/_crecords/memory_records.pyi @@ -0,0 +1,13 @@ +from typing import final + +from aiokafka.record._protocols import MemoryRecordsProtocol + +from .default_records import DefaultRecordBatch +from .legacy_records import LegacyRecordBatch + +@final +class MemoryRecords(MemoryRecordsProtocol): + def __init__(self, bytes_data: bytes) -> None: ... + def size_in_bytes(self) -> int: ... + def has_next(self) -> bool: ... + def next_batch(self) -> DefaultRecordBatch | LegacyRecordBatch | None: ... diff --git a/aiokafka/record/_protocols.py b/aiokafka/record/_protocols.py new file mode 100644 index 00000000..176932b1 --- /dev/null +++ b/aiokafka/record/_protocols.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +from typing import ( + Any, + ClassVar, + Iterable, + Iterator, + List, + Optional, + Protocol, + Tuple, + Union, + runtime_checkable, +) + +from typing_extensions import Literal, Never + +from ._types import ( + CodecGzipT, + CodecLz4T, + CodecMaskT, + CodecNoneT, + CodecSnappyT, + CodecZstdT, + DefaultCompressionTypeT, + LegacyCompressionTypeT, +) + + +class DefaultRecordBatchBuilderProtocol(Protocol): + def __init__( + self, + magic: int, + compression_type: DefaultCompressionTypeT, + is_transactional: int, + producer_id: int, + producer_epoch: int, + base_sequence: int, + batch_size: int, + ): ... + def append( + self, + offset: int, + timestamp: Optional[int], + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> Optional[DefaultRecordMetadataProtocol]: ... + def build(self) -> bytearray: ... + def size(self) -> int: ... + def size_in_bytes( + self, + offset: int, + timestamp: int, + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> int: ... + @classmethod + def size_of( + cls, + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> int: ... + @classmethod + def estimate_size_in_bytes( + cls, + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> int: ... + def set_producer_state( + self, producer_id: int, producer_epoch: int, base_sequence: int + ) -> None: ... + @property + def producer_id(self) -> int: ... + @property + def producer_epoch(self) -> int: ... + @property + def base_sequence(self) -> int: ... + + +class DefaultRecordMetadataProtocol(Protocol): + def __init__(self, offset: int, size: int, timestamp: int) -> None: ... + @property + def offset(self) -> int: ... + @property + def crc(self) -> None: ... + @property + def size(self) -> int: ... + @property + def timestamp(self) -> int: ... + + +class DefaultRecordBatchProtocol(Iterator["DefaultRecordProtocol"], Protocol): + CODEC_MASK: ClassVar[CodecMaskT] + CODEC_NONE: ClassVar[CodecNoneT] + CODEC_GZIP: ClassVar[CodecGzipT] + CODEC_SNAPPY: ClassVar[CodecSnappyT] + CODEC_LZ4: ClassVar[CodecLz4T] + CODEC_ZSTD: ClassVar[CodecZstdT] + + def __init__(self, buffer: Union[bytes, bytearray, memoryview]) -> None: ... + @property + def base_offset(self) -> int: ... + @property + def magic(self) -> int: ... + @property + def crc(self) -> int: ... + @property + def attributes(self) -> int: ... + @property + def compression_type(self) -> int: ... + @property + def timestamp_type(self) -> int: ... + @property + def is_transactional(self) -> bool: ... 
+ @property + def is_control_batch(self) -> bool: ... + @property + def last_offset_delta(self) -> int: ... + @property + def first_timestamp(self) -> int: ... + @property + def max_timestamp(self) -> int: ... + @property + def producer_id(self) -> int: ... + @property + def producer_epoch(self) -> int: ... + @property + def base_sequence(self) -> int: ... + @property + def next_offset(self) -> int: ... + def validate_crc(self) -> bool: ... + + +@runtime_checkable +class DefaultRecordProtocol(Protocol): + def __init__( + self, + offset: int, + timestamp: int, + timestamp_type: int, + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> None: ... + @property + def offset(self) -> int: ... + @property + def timestamp(self) -> int: + """Epoch milliseconds""" + + @property + def timestamp_type(self) -> int: + """CREATE_TIME(0) or APPEND_TIME(1)""" + + @property + def key(self) -> Optional[bytes]: + """Bytes key or None""" + + @property + def value(self) -> Optional[bytes]: + """Bytes value or None""" + + @property + def headers(self) -> List[Tuple[str, Optional[bytes]]]: ... + @property + def checksum(self) -> None: ... + + +class LegacyRecordBatchBuilderProtocol(Protocol): + def __init__( + self, + magic: Literal[0, 1], + compression_type: LegacyCompressionTypeT, + batch_size: int, + ) -> None: ... + def append( + self, + offset: int, + timestamp: Optional[int], + key: Optional[bytes], + value: Optional[bytes], + headers: Any = None, + ) -> Optional[LegacyRecordMetadataProtocol]: ... + def build(self) -> bytearray: + """Compress batch to be ready for send""" + + def size(self) -> int: + """Return current size of data written to buffer""" + + def size_in_bytes( + self, + offset: int, + timestamp: int, + key: Optional[bytes], + value: Optional[bytes], + ) -> int: + """Actual size of message to add""" + + @classmethod + def record_overhead(cls, magic: int) -> int: ... + + +class LegacyRecordMetadataProtocol(Protocol): + def __init__(self, offset: int, crc: int, size: int, timestamp: int) -> None: ... + @property + def offset(self) -> int: ... + @property + def crc(self) -> int: ... + @property + def size(self) -> int: ... + @property + def timestamp(self) -> int: ... + + +class LegacyRecordBatchProtocol(Iterable["LegacyRecordProtocol"], Protocol): + CODEC_MASK: ClassVar[CodecMaskT] + CODEC_GZIP: ClassVar[CodecGzipT] + CODEC_SNAPPY: ClassVar[CodecSnappyT] + CODEC_LZ4: ClassVar[CodecLz4T] + + is_control_batch: bool + is_transactional: bool + producer_id: Optional[int] + + def __init__(self, buffer: Union[bytes, bytearray, memoryview], magic: int): ... + @property + def next_offset(self) -> int: ... + def validate_crc(self) -> bool: ... + + +@runtime_checkable +class LegacyRecordProtocol(Protocol): + def __init__( + self, + offset: int, + timestamp: Optional[int], + timestamp_type: Optional[Literal[0, 1]], + key: Optional[bytes], + value: Optional[bytes], + crc: int, + ) -> None: ... + @property + def offset(self) -> int: ... + @property + def timestamp(self) -> Optional[int]: + """Epoch milliseconds""" + + @property + def timestamp_type(self) -> Optional[Literal[0, 1]]: + """CREATE_TIME(0) or APPEND_TIME(1)""" + + @property + def key(self) -> Optional[bytes]: + """Bytes key or None""" + + @property + def value(self) -> Optional[bytes]: + """Bytes value or None""" + + @property + def headers(self) -> List[Never]: ... + @property + def checksum(self) -> int: ... 
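These protocols are what the module-level aliases (DefaultRecordBatch, LegacyRecordBatch, and friends) are annotated with, so helper code can be written once against a protocol and accept either the Cython or the pure-Python implementation. A hypothetical helper as a sketch:

from typing import List

from aiokafka.errors import CorruptRecordException
from aiokafka.record._protocols import (
    DefaultRecordBatchProtocol,
    DefaultRecordProtocol,
)

def checked_records(batch: DefaultRecordBatchProtocol) -> List[DefaultRecordProtocol]:
    # validate_crc() has to run before iteration starts (the pure-Python batch
    # asserts this), so check first and only then materialize the records.
    if not batch.validate_crc():
        raise CorruptRecordException("record batch failed CRC check")
    return list(batch)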
+ + +class MemoryRecordsProtocol(Protocol): + def __init__(self, bytes_data: bytes) -> None: ... + def size_in_bytes(self) -> int: ... + def has_next(self) -> bool: ... + def next_batch( + self, + ) -> Optional[Union[DefaultRecordBatchProtocol, LegacyRecordBatchProtocol]]: ... diff --git a/aiokafka/record/_types.py b/aiokafka/record/_types.py new file mode 100644 index 00000000..9a3c121f --- /dev/null +++ b/aiokafka/record/_types.py @@ -0,0 +1,14 @@ +from typing import Union + +from typing_extensions import Literal + +CodecNoneT = Literal[0x00] +CodecGzipT = Literal[0x01] +CodecSnappyT = Literal[0x02] +CodecLz4T = Literal[0x03] +CodecZstdT = Literal[0x04] +CodecMaskT = Literal[0x07] +DefaultCompressionTypeT = Union[ + CodecGzipT, CodecLz4T, CodecNoneT, CodecSnappyT, CodecZstdT +] +LegacyCompressionTypeT = Union[CodecGzipT, CodecLz4T, CodecSnappyT, CodecNoneT] diff --git a/aiokafka/record/control_record.py b/aiokafka/record/control_record.py index cd09369e..ed92105e 100644 --- a/aiokafka/record/control_record.py +++ b/aiokafka/record/control_record.py @@ -1,33 +1,30 @@ import struct +from dataclasses import dataclass + +from typing_extensions import Self _SCHEMA = struct.Struct(">HH") +@dataclass(frozen=True) class ControlRecord: - def __init__(self, version, type_): - self._version = version - self._type = type_ - - @property - def version(self): - return self._version + __slots__ = ("version", "type_") - @property - def type_(self): - return self._type + version: int + type_: int - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, ControlRecord): - return other._version == self._version and other._type == self._type + return other.version == self.version and other.type_ == self.type_ return False @classmethod - def parse(cls, data: bytes): + def parse(cls, data: bytes) -> Self: version, type_ = _SCHEMA.unpack_from(data) return cls(version, type_) - def __repr__(self): - return f"ControlRecord(version={self._version}, type_={self._type})" + def __repr__(self) -> str: + return f"ControlRecord(version={self.version}, type_={self.type_})" ABORT_MARKER = ControlRecord(0, 0) diff --git a/aiokafka/record/default_records.py b/aiokafka/record/default_records.py index aed9d3a2..8b0a596d 100644 --- a/aiokafka/record/default_records.py +++ b/aiokafka/record/default_records.py @@ -56,6 +56,10 @@ import struct import time +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, Sized, Tuple, Type, Union, final + +from typing_extensions import Self, TypeIs, assert_never import aiokafka.codec as codecs from aiokafka.codec import ( @@ -71,6 +75,20 @@ from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from aiokafka.util import NO_EXTENSIONS +from ._protocols import ( + DefaultRecordBatchBuilderProtocol, + DefaultRecordBatchProtocol, + DefaultRecordMetadataProtocol, + DefaultRecordProtocol, +) +from ._types import ( + CodecGzipT, + CodecLz4T, + CodecMaskT, + CodecNoneT, + CodecSnappyT, + CodecZstdT, +) from .util import calc_crc32c, decode_varint, encode_varint, size_of_varint @@ -97,12 +115,12 @@ class DefaultRecordBase: CRC_OFFSET = struct.calcsize(">qiib") AFTER_LEN_OFFSET = struct.calcsize(">qi") - CODEC_MASK = 0x07 - CODEC_NONE = 0x00 - CODEC_GZIP = 0x01 - CODEC_SNAPPY = 0x02 - CODEC_LZ4 = 0x03 - CODEC_ZSTD = 0x04 + CODEC_MASK: CodecMaskT = 0x07 + CODEC_NONE: CodecNoneT = 0x00 + CODEC_GZIP: CodecGzipT = 0x01 + CODEC_SNAPPY: CodecSnappyT = 0x02 + CODEC_LZ4: CodecLz4T = 0x03 + CODEC_ZSTD: CodecZstdT = 0x04 
TIMESTAMP_TYPE_MASK = 0x08 TRANSACTIONAL_MASK = 0x10 CONTROL_MASK = 0x20 @@ -112,7 +130,9 @@ class DefaultRecordBase: NO_PARTITION_LEADER_EPOCH = -1 - def _assert_has_codec(self, compression_type): + def _assert_has_codec( + self, compression_type: int + ) -> TypeIs[Union[CodecGzipT, CodecSnappyT, CodecLz4T, CodecZstdT]]: if compression_type == self.CODEC_GZIP: checker, name = codecs.has_gzip, "gzip" elif compression_type == self.CODEC_SNAPPY: @@ -129,82 +149,86 @@ def _assert_has_codec(self, compression_type): raise UnsupportedCodecError( f"Libraries for {name} compression codec not found" ) + return True -class _DefaultRecordBatchPy(DefaultRecordBase): - def __init__(self, buffer): +@final +class _DefaultRecordBatchPy(DefaultRecordBase, DefaultRecordBatchProtocol): + def __init__(self, buffer: Union[bytes, bytearray, memoryview]) -> None: self._buffer = bytearray(buffer) - self._header_data = self.HEADER_STRUCT.unpack_from(self._buffer) + self._header_data: Tuple[ + int, int, int, int, int, int, int, int, int, int, int, int, int + ] = self.HEADER_STRUCT.unpack_from(self._buffer) self._pos = self.HEADER_STRUCT.size self._num_records = self._header_data[12] self._next_record_index = 0 self._decompressed = False @property - def base_offset(self): + def base_offset(self) -> int: return self._header_data[0] @property - def magic(self): + def magic(self) -> int: return self._header_data[3] @property - def crc(self): + def crc(self) -> int: return self._header_data[4] @property - def attributes(self): + def attributes(self) -> int: return self._header_data[5] @property - def compression_type(self): + def compression_type(self) -> int: return self.attributes & self.CODEC_MASK @property - def timestamp_type(self): + def timestamp_type(self) -> int: return int(bool(self.attributes & self.TIMESTAMP_TYPE_MASK)) @property - def is_transactional(self): + def is_transactional(self) -> bool: return bool(self.attributes & self.TRANSACTIONAL_MASK) @property - def is_control_batch(self): + def is_control_batch(self) -> bool: return bool(self.attributes & self.CONTROL_MASK) @property - def last_offset_delta(self): + def last_offset_delta(self) -> int: return self._header_data[6] @property - def first_timestamp(self): + def first_timestamp(self) -> int: return self._header_data[7] @property - def max_timestamp(self): + def max_timestamp(self) -> int: return self._header_data[8] @property - def producer_id(self): + def producer_id(self) -> int: return self._header_data[9] @property - def producer_epoch(self): + def producer_epoch(self) -> int: return self._header_data[10] @property - def base_sequence(self): + def base_sequence(self) -> int: return self._header_data[11] @property - def next_offset(self): + def next_offset(self) -> int: return self.base_offset + self.last_offset_delta + 1 - def _maybe_uncompress(self): + def _maybe_uncompress(self) -> None: if not self._decompressed: compression_type = self.compression_type if compression_type != self.CODEC_NONE: - self._assert_has_codec(compression_type) + assert self._assert_has_codec(compression_type) data = memoryview(self._buffer)[self._pos :] if compression_type == self.CODEC_GZIP: uncompressed = gzip_decode(data) @@ -212,13 +236,17 @@ def _maybe_uncompress(self): uncompressed = snappy_decode(data.tobytes()) elif compression_type == self.CODEC_LZ4: uncompressed = lz4_decode(data.tobytes()) - if compression_type == self.CODEC_ZSTD: + elif compression_type == self.CODEC_ZSTD: uncompressed = zstd_decode(data.tobytes()) + else: + 
assert_never(compression_type) self._buffer = bytearray(uncompressed) self._pos = 0 self._decompressed = True - def _read_msg(self, decode_varint=decode_varint): + def _read_msg( + self, decode_varint: Callable[[bytearray, int], Tuple[int, int]] = decode_varint + ) -> "_DefaultRecordPy": # Record => # Length => Varint # Attributes => Int8 @@ -264,7 +292,7 @@ def _read_msg(self, decode_varint=decode_varint): raise CorruptRecordException( f"Found invalid number of record headers {header_count}" ) - headers = [] + headers: List[Tuple[str, Optional[bytes]]] = [] while header_count: # Header key is of type String, that can't be None h_key_len, pos = decode_varint(buffer, pos) @@ -294,15 +322,15 @@ def _read_msg(self, decode_varint=decode_varint): ) self._pos = pos - return DefaultRecord( + return _DefaultRecordPy( offset, timestamp, self.timestamp_type, key, value, headers ) - def __iter__(self): + def __iter__(self) -> Self: self._maybe_uncompress() return self - def __next__(self): + def __next__(self) -> "_DefaultRecordPy": if self._next_record_index >= self._num_records: if self._pos != len(self._buffer): raise CorruptRecordException( @@ -320,9 +348,7 @@ def __next__(self): self._next_record_index += 1 return msg - next = __next__ - - def validate_crc(self): + def validate_crc(self) -> bool: assert self._decompressed is False, "Validate should be called before iteration" crc = self.crc @@ -331,78 +357,47 @@ def validate_crc(self): return crc == verify_crc -class _DefaultRecordPy: - __slots__ = ( - "_offset", - "_timestamp", - "_timestamp_type", - "_key", - "_value", - "_headers", - ) - - def __init__(self, offset, timestamp, timestamp_type, key, value, headers): - self._offset = offset - self._timestamp = timestamp - self._timestamp_type = timestamp_type - self._key = key - self._value = value - self._headers = headers - - @property - def offset(self): - return self._offset - - @property - def timestamp(self): - """Epoch milliseconds""" - return self._timestamp +@final +@dataclass(frozen=True) +class _DefaultRecordPy(DefaultRecordProtocol): + __slots__ = ("offset", "timestamp", "timestamp_type", "key", "value", "headers") - @property - def timestamp_type(self): - """CREATE_TIME(0) or APPEND_TIME(1)""" - return self._timestamp_type + offset: int + timestamp: int + timestamp_type: int + key: Optional[bytes] + value: Optional[bytes] + headers: List[Tuple[str, Optional[bytes]]] @property - def key(self): - """Bytes key or None""" - return self._key - - @property - def value(self): - """Bytes value or None""" - return self._value - - @property - def headers(self): - return self._headers - - @property - def checksum(self): + def checksum(self) -> None: return None - def __repr__(self): + def __repr__(self) -> str: return ( - f"DefaultRecord(offset={self._offset!r}, timestamp={self._timestamp!r}," - f" timestamp_type={self._timestamp_type!r}, key={self._key!r}," - f" value={self._value!r}, headers={self._headers!r})" + f"DefaultRecord(offset={self.offset!r}, timestamp={self.timestamp!r}," + f" timestamp_type={self.timestamp_type!r}, key={self.key!r}," + f" value={self.value!r}, headers={self.headers!r})" ) -class _DefaultRecordBatchBuilderPy(DefaultRecordBase): +@final +class _DefaultRecordBatchBuilderPy( + DefaultRecordBase, DefaultRecordBatchBuilderProtocol +): # excluding key, value and headers: # 5 bytes length + 10 bytes timestamp + 5 bytes offset + 1 byte attributes MAX_RECORD_OVERHEAD = 21 def __init__( self, - magic, - compression_type, - is_transactional, - producer_id, - producer_epoch, 
- base_sequence, - batch_size, + magic: int, + compression_type: int, + is_transactional: int, + producer_id: int, + producer_epoch: int, + base_sequence: int, + batch_size: int, ): assert magic >= 2 self._magic = magic @@ -414,14 +409,14 @@ def __init__( self._producer_epoch = producer_epoch self._base_sequence = base_sequence - self._first_timestamp = None - self._max_timestamp = None + self._first_timestamp: Optional[int] = None + self._max_timestamp: Optional[int] = None self._last_offset = 0 self._num_records = 0 self._buffer = bytearray(self.HEADER_STRUCT.size) - def _get_attributes(self, include_compression_type=True): + def _get_attributes(self, include_compression_type: bool = True) -> int: attrs = 0 if include_compression_type: attrs |= self._compression_type @@ -433,22 +428,26 @@ def _get_attributes(self, include_compression_type=True): def append( self, - offset, - timestamp, - key, - value, - headers, + offset: int, + timestamp: Optional[int], + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], # Cache for LOAD_FAST opcodes - encode_varint=encode_varint, - size_of_varint=size_of_varint, - get_type=type, - type_int=int, - time_time=time.time, - byte_like=(bytes, bytearray, memoryview), - bytearray_type=bytearray, - len_func=len, - zero_len_varint=1, - ): + encode_varint: Callable[[int, Callable[[int], None]], int] = encode_varint, + size_of_varint: Callable[[int], int] = size_of_varint, + get_type: Callable[[Any], type] = type, + type_int: Type[int] = int, + time_time: Callable[[], float] = time.time, + byte_like: Tuple[Type[bytes], Type[bytearray], Type[memoryview]] = ( + bytes, + bytearray, + memoryview, + ), + bytearray_type: Type[bytearray] = bytearray, + len_func: Callable[[Sized], int] = len, + zero_len_varint: int = 1, + ) -> Optional["_DefaultRecordMetadataPy"]: """Write message to messageset buffer with MsgVersion 2""" # Check types if get_type(offset) != type_int: @@ -497,9 +496,9 @@ def append( encode_varint(len_func(headers), write_byte) for h_key, h_value in headers: - h_key = h_key.encode("utf-8") - encode_varint(len_func(h_key), write_byte) - write(h_key) + h_key_bytes = h_key.encode("utf-8") + encode_varint(len_func(h_key_bytes), write_byte) + write(h_key_bytes) if h_value is not None: encode_varint(len_func(h_value), write_byte) write(h_value) @@ -518,6 +517,7 @@ def append( return None # Those should be updated after the length check + assert self._max_timestamp is not None if self._max_timestamp < timestamp: self._max_timestamp = timestamp self._num_records += 1 @@ -528,7 +528,7 @@ def append( return _DefaultRecordMetadataPy(offset, required_size, timestamp) - def write_header(self, use_compression_type=True): + def _write_header(self, use_compression_type: bool = True) -> None: batch_len = len(self._buffer) self.HEADER_STRUCT.pack_into( self._buffer, @@ -550,9 +550,9 @@ def write_header(self, use_compression_type=True): crc = calc_crc32c(self._buffer[self.ATTRIBUTES_OFFSET :]) struct.pack_into(">I", self._buffer, self.CRC_OFFSET, crc) - def _maybe_compress(self): + def _maybe_compress(self) -> bool: if self._compression_type != self.CODEC_NONE: - self._assert_has_codec(self._compression_type) + assert self._assert_has_codec(self._compression_type) header_size = self.HEADER_STRUCT.size data = bytes(self._buffer[header_size:]) if self._compression_type == self.CODEC_GZIP: @@ -563,6 +563,8 @@ def _maybe_compress(self): compressed = lz4_encode(data) elif self._compression_type == self.CODEC_ZSTD: compressed = 
zstd_encode(data) + else: + assert_never(self._compression_type) compressed_size = len(compressed) if len(data) <= compressed_size: # We did not get any benefit from compression, lets send @@ -576,16 +578,23 @@ def _maybe_compress(self): return True return False - def build(self): + def build(self) -> bytearray: send_compressed = self._maybe_compress() - self.write_header(send_compressed) + self._write_header(send_compressed) return self._buffer - def size(self): + def size(self) -> int: """Return current size of data written to buffer""" return len(self._buffer) - def size_in_bytes(self, offset, timestamp, key, value, headers): + def size_in_bytes( + self, + offset: int, + timestamp: int, + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> int: if self._first_timestamp is not None: timestamp_delta = timestamp - self._first_timestamp else: @@ -599,7 +608,12 @@ def size_in_bytes(self, offset, timestamp, key, value, headers): return size_of_body + size_of_varint(size_of_body) @classmethod - def size_of(cls, key, value, headers): + def size_of( + cls, + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> int: size = 0 # Key size if key is None: @@ -627,7 +641,12 @@ def size_of(cls, key, value, headers): return size @classmethod - def estimate_size_in_bytes(cls, key, value, headers): + def estimate_size_in_bytes( + cls, + key: Optional[bytes], + value: Optional[bytes], + headers: List[Tuple[str, Optional[bytes]]], + ) -> int: """Get the upper bound estimate on the size of record""" return ( cls.HEADER_STRUCT.size @@ -635,55 +654,51 @@ def estimate_size_in_bytes(cls, key, value, headers): + cls.size_of(key, value, headers) ) - def set_producer_state(self, producer_id, producer_epoch, base_sequence): + def set_producer_state( + self, producer_id: int, producer_epoch: int, base_sequence: int + ) -> None: self._producer_id = producer_id self._producer_epoch = producer_epoch self._base_sequence = base_sequence @property - def producer_id(self): + def producer_id(self) -> int: return self._producer_id @property - def producer_epoch(self): + def producer_epoch(self) -> int: return self._producer_epoch @property - def base_sequence(self): + def base_sequence(self) -> int: return self._base_sequence -class _DefaultRecordMetadataPy: - __slots__ = ("_size", "_timestamp", "_offset") +@final +@dataclass(frozen=True) +class _DefaultRecordMetadataPy(DefaultRecordMetadataProtocol): + __slots__ = ("size", "timestamp", "offset") - def __init__(self, offset, size, timestamp): - self._offset = offset - self._size = size - self._timestamp = timestamp + offset: int + size: int + timestamp: int @property - def offset(self): - return self._offset - - @property - def crc(self): + def crc(self) -> None: return None - @property - def size(self): - return self._size - - @property - def timestamp(self): - return self._timestamp - - def __repr__(self): + def __repr__(self) -> str: return ( - f"DefaultRecordMetadata(offset={self._offset!r}," - f" size={self._size!r}, timestamp={self._timestamp!r})" + f"DefaultRecordMetadata(offset={self.offset!r}," + f" size={self.size!r}, timestamp={self.timestamp!r})" ) +DefaultRecordBatchBuilder: Type[DefaultRecordBatchBuilderProtocol] +DefaultRecordMetadata: Type[DefaultRecordMetadataProtocol] +DefaultRecordBatch: Type[DefaultRecordBatchProtocol] +DefaultRecord: Type[DefaultRecordProtocol] + if NO_EXTENSIONS: DefaultRecordBatchBuilder = _DefaultRecordBatchBuilderPy DefaultRecordMetadata = 
_DefaultRecordMetadataPy diff --git a/aiokafka/record/legacy_records.py b/aiokafka/record/legacy_records.py index e2d9190f..32878229 100644 --- a/aiokafka/record/legacy_records.py +++ b/aiokafka/record/legacy_records.py @@ -1,6 +1,12 @@ +from __future__ import annotations + import struct import time from binascii import crc32 +from dataclasses import dataclass +from typing import Any, Generator, List, Optional, Tuple, Type, Union, final + +from typing_extensions import Literal, Never, TypeIs, assert_never import aiokafka.codec as codecs from aiokafka.codec import ( @@ -14,6 +20,20 @@ from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from aiokafka.util import NO_EXTENSIONS +from ._protocols import ( + LegacyRecordBatchBuilderProtocol, + LegacyRecordBatchProtocol, + LegacyRecordMetadataProtocol, + LegacyRecordProtocol, +) +from ._types import ( + CodecGzipT, + CodecLz4T, + CodecMaskT, + CodecSnappyT, + LegacyCompressionTypeT, +) + NoneType = type(None) @@ -40,9 +60,7 @@ class LegacyRecordBase: ">q" # Offset "i" # Size ) - MAGIC_OFFSET = LOG_OVERHEAD + struct.calcsize( - ">I" # CRC - ) + MAGIC_OFFSET = LOG_OVERHEAD + struct.calcsize(">I") # CRC # Those are used for fast size calculations RECORD_OVERHEAD_V0 = struct.calcsize( ">I" # CRC @@ -68,16 +86,18 @@ class LegacyRecordBase: KEY_OFFSET_V1 = HEADER_STRUCT_V1.size KEY_LENGTH = VALUE_LENGTH = struct.calcsize(">i") # Bytes length is Int32 - CODEC_MASK = 0x07 - CODEC_GZIP = 0x01 - CODEC_SNAPPY = 0x02 - CODEC_LZ4 = 0x03 + CODEC_MASK: CodecMaskT = 0x07 + CODEC_GZIP: CodecGzipT = 0x01 + CODEC_SNAPPY: CodecSnappyT = 0x02 + CODEC_LZ4: CodecLz4T = 0x03 TIMESTAMP_TYPE_MASK = 0x08 LOG_APPEND_TIME = 1 CREATE_TIME = 0 - def _assert_has_codec(self, compression_type): + def _assert_has_codec( + self, compression_type: int + ) -> TypeIs[Union[CodecGzipT, CodecSnappyT, CodecLz4T]]: if compression_type == self.CODEC_GZIP: checker, name = codecs.has_gzip, "gzip" elif compression_type == self.CODEC_SNAPPY: @@ -92,14 +112,16 @@ def _assert_has_codec(self, compression_type): raise UnsupportedCodecError( f"Libraries for {name} compression codec not found" ) + return True -class _LegacyRecordBatchPy(LegacyRecordBase): - is_control_batch = False - is_transactional = False - producer_id = None +@final +class _LegacyRecordBatchPy(LegacyRecordBase, LegacyRecordBatchProtocol): + is_control_batch: bool = False + is_transactional: bool = False + producer_id: Optional[int] = None - def __init__(self, buffer, magic): + def __init__(self, buffer: Union[bytes, bytearray, memoryview], magic: int): self._buffer = memoryview(buffer) self._magic = magic @@ -114,7 +136,7 @@ def __init__(self, buffer, magic): self._decompressed = False @property - def timestamp_type(self): + def _timestamp_type(self) -> Optional[Literal[0, 1]]: """0 for CreateTime; 1 for LogAppendTime; None if unsupported. 
Value is determined by broker; produced messages should always set to 0 @@ -128,18 +150,18 @@ def timestamp_type(self): return 0 @property - def compression_type(self): + def _compression_type(self) -> int: return self._attributes & self.CODEC_MASK @property - def next_offset(self): + def next_offset(self) -> int: return self._offset + 1 - def validate_crc(self): + def validate_crc(self) -> bool: crc = crc32(self._buffer[self.MAGIC_OFFSET :]) return self._crc == crc - def _decompress(self, key_offset): + def _decompress(self, key_offset: int) -> bytes: # Copy of `_read_key_value`, but uses memoryview pos = key_offset key_size = struct.unpack_from(">i", self._buffer, pos)[0] @@ -153,8 +175,8 @@ def _decompress(self, key_offset): else: data = self._buffer[pos : pos + value_size] - compression_type = self.compression_type - self._assert_has_codec(compression_type) + compression_type = self._compression_type + assert self._assert_has_codec(compression_type) if compression_type == self.CODEC_GZIP: uncompressed = gzip_decode(data) elif compression_type == self.CODEC_SNAPPY: @@ -167,9 +189,11 @@ def _decompress(self, key_offset): ) else: uncompressed = lz4_decode(data.tobytes()) + else: + assert_never(compression_type) return uncompressed - def _read_header(self, pos): + def _read_header(self, pos: int) -> Tuple[int, int, int, int, int, Optional[int]]: if self._magic == 0: offset, length, crc, magic_read, attrs = self.HEADER_STRUCT_V0.unpack_from( self._buffer, pos @@ -186,9 +210,11 @@ def _read_header(self, pos): ) = self.HEADER_STRUCT_V1.unpack_from(self._buffer, pos) return offset, length, crc, magic_read, attrs, timestamp - def _read_all_headers(self): + def _read_all_headers( + self, + ) -> List[Tuple[Tuple[int, int, int, int, int, Optional[int]], int]]: pos = 0 - msgs = [] + msgs: List[Tuple[Tuple[int, int, int, int, int, Optional[int]], int]] = [] buffer_len = len(self._buffer) while pos < buffer_len: header = self._read_header(pos) @@ -196,8 +222,8 @@ def _read_all_headers(self): pos += self.LOG_OVERHEAD + header[1] # length return msgs - def _read_key_value(self, pos): - key_size = struct.unpack_from(">i", self._buffer, pos)[0] + def _read_key_value(self, pos: int) -> Tuple[Optional[bytes], Optional[bytes]]: + key_size: int = struct.unpack_from(">i", self._buffer, pos)[0] pos += self.KEY_LENGTH if key_size == -1: key = None @@ -205,7 +231,7 @@ def _read_key_value(self, pos): key = self._buffer[pos : pos + key_size].tobytes() pos += key_size - value_size = struct.unpack_from(">i", self._buffer, pos)[0] + value_size: int = struct.unpack_from(">i", self._buffer, pos)[0] pos += self.VALUE_LENGTH if value_size == -1: value = None @@ -213,14 +239,14 @@ def _read_key_value(self, pos): value = self._buffer[pos : pos + value_size].tobytes() return key, value - def __iter__(self): + def __iter__(self) -> Generator[_LegacyRecordPy, None, None]: if self._magic == 1: key_offset = self.KEY_OFFSET_V1 else: key_offset = self.KEY_OFFSET_V0 - timestamp_type = self.timestamp_type + timestamp_type = self._timestamp_type - if self.compression_type: + if self._compression_type: # In case we will call iter again if not self._decompressed: self._buffer = memoryview(self._decompress(key_offset)) @@ -263,67 +289,59 @@ def __iter__(self): ) -class _LegacyRecordPy: - __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", "_crc") +@final +@dataclass(frozen=True) +class _LegacyRecordPy(LegacyRecordProtocol): + __slots__ = ("offset", "timestamp", "timestamp_type", "key", "value", "crc") - def 
__init__(self, offset, timestamp, timestamp_type, key, value, crc): - self._offset = offset - self._timestamp = timestamp - self._timestamp_type = timestamp_type - self._key = key - self._value = value - self._crc = crc + offset: int + timestamp: Optional[int] + timestamp_type: Optional[Literal[0, 1]] + key: Optional[bytes] + value: Optional[bytes] + crc: int @property - def offset(self): - return self._offset - - @property - def timestamp(self): - """Epoch milliseconds""" - return self._timestamp - - @property - def timestamp_type(self): - """CREATE_TIME(0) or APPEND_TIME(1)""" - return self._timestamp_type - - @property - def key(self): - """Bytes key or None""" - return self._key - - @property - def value(self): - """Bytes value or None""" - return self._value - - @property - def headers(self): + def headers(self) -> List[Never]: return [] @property - def checksum(self): - return self._crc + def checksum(self) -> int: + return self.crc - def __repr__(self): + def __repr__(self) -> str: return ( - f"LegacyRecord(offset={self._offset!r}, timestamp={self._timestamp!r}," - f" timestamp_type={self._timestamp_type!r}," - f" key={self._key!r}, value={self._value!r}, crc={self._crc!r})" + f"LegacyRecord(offset={self.offset!r}, timestamp={self.timestamp!r}," + f" timestamp_type={self.timestamp_type!r}," + f" key={self.key!r}, value={self.value!r}, crc={self.crc!r})" ) -class _LegacyRecordBatchBuilderPy(LegacyRecordBase): - def __init__(self, magic, compression_type, batch_size): +@final +class _LegacyRecordBatchBuilderPy(LegacyRecordBase, LegacyRecordBatchBuilderProtocol): + _buffer: Optional[bytearray] = None + + def __init__( + self, + magic: Literal[0, 1], + compression_type: LegacyCompressionTypeT, + batch_size: int, + ) -> None: assert magic in [0, 1] self._magic = magic self._compression_type = compression_type self._batch_size = batch_size - self._msg_buffers = [] + self._msg_buffers: List[bytearray] = [] self._pos = 0 - def append(self, offset, timestamp, key, value, headers=None): + def append( + self, + offset: int, + timestamp: Optional[int], + key: Optional[bytes], + value: Optional[bytes], + headers: Any = None, + ) -> Optional[_LegacyRecordMetadataPy]: """Append message to batch.""" if self._magic == 0: timestamp = -1 @@ -364,8 +382,16 @@ def append(self, offset, timestamp, key, value, headers=None): raise def _encode_msg( - self, buf, offset, timestamp, key_size, key, value_size, value, attributes=0 - ): + self, + buf: bytearray, + offset: int, + timestamp: int, + key_size: int, + key: Optional[bytes], + value_size: int, + value: Optional[bytes], + attributes: int = 0, + ) -> int: """Encode msg data into the `msg_buffer`, which should be allocated to at least the size of this message. 
""" @@ -433,9 +459,10 @@ def _encode_msg( struct.pack_into(">I", buf, self.CRC_OFFSET, crc) return crc - def _maybe_compress(self): + def _maybe_compress(self) -> bool: if self._compression_type: - self._assert_has_codec(self._compression_type) + assert self._buffer is not None + assert self._assert_has_codec(self._compression_type) buf = self._buffer if self._compression_type == self.CODEC_GZIP: compressed = gzip_encode(buf) @@ -449,6 +476,9 @@ def _maybe_compress(self): ) else: compressed = lz4_encode(bytes(buf)) + + else: + assert_never(self._compression_type) compressed_size = len(compressed) size = self._size_in_bytes(key_size=0, value_size=compressed_size) if size > len(self._buffer): @@ -469,24 +499,29 @@ def _maybe_compress(self): return True return False - def build(self): + def build(self) -> bytearray: """Compress batch to be ready for send""" self._buffer = bytearray().join(self._msg_buffers) self._maybe_compress() return self._buffer - def size(self): + def size(self) -> int: """Return current size of data written to buffer""" return self._pos - def size_in_bytes(self, offset, timestamp, key, value, headers=None): + def size_in_bytes( + self, + offset: Any, + timestamp: Any, + key: Optional[bytes], + value: Optional[bytes], + ) -> int: """Actual size of message to add""" - assert not headers, "Headers not supported in v0/v1" key_size = len(key) if key is not None else 0 value_size = len(value) if value is not None else 0 return self._size_in_bytes(key_size, value_size) - def _size_in_bytes(self, key_size, value_size): + def _size_in_bytes(self, key_size: int, value_size: int) -> int: return ( self.LOG_OVERHEAD + self.RECORD_OVERHEAD[self._magic] @@ -495,39 +530,40 @@ def _size_in_bytes(self, key_size, value_size): ) @classmethod - def record_overhead(cls, magic): + def record_overhead(cls, magic: int) -> int: try: return cls.RECORD_OVERHEAD[magic] except KeyError: raise ValueError(f"Unsupported magic: {magic}") from None -class _LegacyRecordMetadataPy: +@final +class _LegacyRecordMetadataPy(LegacyRecordMetadataProtocol): __slots__ = ("_crc", "_size", "_timestamp", "_offset") - def __init__(self, offset, crc, size, timestamp): + def __init__(self, offset: int, crc: int, size: int, timestamp: int) -> None: self._offset = offset self._crc = crc self._size = size self._timestamp = timestamp @property - def offset(self): + def offset(self) -> int: return self._offset @property - def crc(self): + def crc(self) -> int: return self._crc @property - def size(self): + def size(self) -> int: return self._size @property - def timestamp(self): + def timestamp(self) -> int: return self._timestamp - def __repr__(self): + def __repr__(self) -> str: return ( f"LegacyRecordMetadata(offset={self._offset!r}," f" crc={self._crc!r}, size={self._size!r}," @@ -535,6 +571,11 @@ def __repr__(self): ) +LegacyRecordBatchBuilder: Type[LegacyRecordBatchBuilderProtocol] +LegacyRecordMetadata: Type[LegacyRecordMetadataProtocol] +LegacyRecordBatch: Type[LegacyRecordBatchProtocol] +LegacyRecord: Type[LegacyRecordProtocol] + if NO_EXTENSIONS: LegacyRecordBatchBuilder = _LegacyRecordBatchBuilderPy LegacyRecordMetadata = _LegacyRecordMetadataPy diff --git a/aiokafka/record/memory_records.py b/aiokafka/record/memory_records.py index 814190a1..b618d4a8 100644 --- a/aiokafka/record/memory_records.py +++ b/aiokafka/record/memory_records.py @@ -20,36 +20,45 @@ # used to construct the correct class for Batch itself. 
import struct +from typing import Optional, Type, Union, final from aiokafka.errors import CorruptRecordException from aiokafka.util import NO_EXTENSIONS +from ._protocols import ( + DefaultRecordBatchProtocol, + LegacyRecordBatchProtocol, + MemoryRecordsProtocol, +) from .default_records import DefaultRecordBatch -from .legacy_records import LegacyRecordBatch +from .legacy_records import LegacyRecordBatch, _LegacyRecordBatchPy -class _MemoryRecordsPy: +@final +class _MemoryRecordsPy(MemoryRecordsProtocol): LENGTH_OFFSET = struct.calcsize(">q") LOG_OVERHEAD = struct.calcsize(">qi") MAGIC_OFFSET = struct.calcsize(">qii") # Minimum space requirements for Record V0 - MIN_SLICE = LOG_OVERHEAD + LegacyRecordBatch.RECORD_OVERHEAD_V0 + MIN_SLICE = LOG_OVERHEAD + _LegacyRecordBatchPy.RECORD_OVERHEAD_V0 - def __init__(self, bytes_data): + def __init__(self, bytes_data: bytes) -> None: self._buffer = bytes_data - self._pos = 0 + self._pos: int = 0 # We keep one slice ahead so `has_next` will return very fast - self._next_slice = None + self._next_slice: Optional[memoryview] = None self._remaining_bytes = 0 self._cache_next() - def size_in_bytes(self): + def size_in_bytes(self) -> int: return len(self._buffer) # NOTE: we cache offsets here as kwargs for a bit more speed, as cPython # will use LOAD_FAST opcode in this case - def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): + def _cache_next( + self, len_offset: int = LENGTH_OFFSET, log_overhead: int = LOG_OVERHEAD + ) -> None: buffer = self._buffer buffer_len = len(buffer) pos = self._pos @@ -60,7 +69,7 @@ def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): self._next_slice = None return - (length,) = struct.unpack_from(">i", buffer, pos + len_offset) + length: int = struct.unpack_from(">i", buffer, pos + len_offset)[0] slice_end = pos + log_overhead + length if slice_end > buffer_len: @@ -72,11 +81,13 @@ def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): self._next_slice = memoryview(buffer)[pos:slice_end] self._pos = slice_end - def has_next(self): + def has_next(self) -> bool: return self._next_slice is not None # NOTE: same cache for LOAD_FAST as above - def next_batch(self, _min_slice=MIN_SLICE, _magic_offset=MAGIC_OFFSET): + def next_batch( + self, _min_slice: int = MIN_SLICE, _magic_offset: int = MAGIC_OFFSET + ) -> Optional[Union[DefaultRecordBatchProtocol, LegacyRecordBatchProtocol]]: next_slice = self._next_slice if next_slice is None: return None @@ -93,6 +104,8 @@ def next_batch(self, _min_slice=MIN_SLICE, _magic_offset=MAGIC_OFFSET): return LegacyRecordBatch(next_slice, magic) +MemoryRecords: Type[MemoryRecordsProtocol] + if NO_EXTENSIONS: MemoryRecords = _MemoryRecordsPy else: diff --git a/aiokafka/record/util.py b/aiokafka/record/util.py index 56381eb0..5133bc8d 100644 --- a/aiokafka/record/util.py +++ b/aiokafka/record/util.py @@ -1,9 +1,11 @@ +from typing import Callable, Iterable, Tuple, Union + from aiokafka.util import NO_EXTENSIONS from ._crc32c import crc as crc32c_py -def encode_varint_py(value, write): +def encode_varint_py(value: int, write: Callable[[int], None]) -> int: """Encode an integer to a varint presentation. See https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints on how those can be produced. 
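The varint helpers typed in this file implement the protobuf zigzag scheme the docstring links to: encode_varint pushes bytes through a write callback, and decode_varint reads a value back out of a buffer along with the new position. A small round-trip sketch using the public aliases defined at the bottom of the module:

from aiokafka.record.util import decode_varint, encode_varint, size_of_varint

buf = bytearray()
encode_varint(-7, buf.append)  # zigzag maps -7 to 13, so it encodes in a single byte
assert len(buf) == size_of_varint(-7) == 1

value, pos = decode_varint(buf, 0)
assert value == -7 and pos == 1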
@@ -56,7 +58,7 @@ def encode_varint_py(value, write): return i -def size_of_varint_py(value): +def size_of_varint_py(value: int) -> int: """Number of bytes needed to encode an integer in variable-length format.""" value = (value << 1) ^ (value >> 63) if value <= 0x7F: @@ -80,7 +82,7 @@ def size_of_varint_py(value): return 10 -def decode_varint_py(buffer, pos=0): +def decode_varint_py(buffer: bytearray, pos: int = 0) -> Tuple[int, int]: """Decode an integer from a varint presentation. See https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints on how those can be produced. @@ -112,12 +114,17 @@ def decode_varint_py(buffer, pos=0): raise ValueError("Out of int64 range") -def calc_crc32c_py(memview): +def calc_crc32c_py(memview: Iterable[int]) -> int: """Calculate CRC-32C (Castagnoli) checksum over a memoryview of data""" crc = crc32c_py(memview) return crc +calc_crc32c: Callable[[Union[bytes, bytearray]], int] +decode_varint: Callable[[bytearray, int], Tuple[int, int]] +size_of_varint: Callable[[int], int] +encode_varint: Callable[[int, Callable[[int], None]], int] + if NO_EXTENSIONS: calc_crc32c = calc_crc32c_py decode_varint = decode_varint_py diff --git a/pyproject.toml b/pyproject.toml index 534987eb..8121a35b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dynamic = ["version"] dependencies = [ "async-timeout", "packaging", - "typing_extensions >=4.6.0", + "typing_extensions >=4.10.0", ] [project.optional-dependencies] diff --git a/requirements-ci.txt b/requirements-ci.txt index 69136120..08592f03 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ -r requirements-cython.txt ruff==0.3.4 -mypy==1.9.0 +mypy==1.10.0 pytest==7.4.3 pytest-cov==4.1.0 pytest-asyncio==0.21.1 diff --git a/requirements-win-test.txt b/requirements-win-test.txt index fab457ce..6a781e65 100644 --- a/requirements-win-test.txt +++ b/requirements-win-test.txt @@ -1,6 +1,6 @@ -r requirements-cython.txt ruff==0.3.2 -mypy==1.9.0 +mypy==1.10.0 pytest==7.4.3 pytest-cov==4.1.0 pytest-asyncio==0.21.1 diff --git a/tests/record/test_control_record.py b/tests/record/test_control_record.py index 951deefd..25ee26b3 100644 --- a/tests/record/test_control_record.py +++ b/tests/record/test_control_record.py @@ -10,11 +10,11 @@ (b"\x00\x00\x00\x01", COMMIT_MARKER), ], ) -def test_control_record_serde(data, marker): +def test_control_record_serde(data: bytes, marker: ControlRecord) -> None: assert ControlRecord.parse(data) == marker -def test_control_record_parse(): +def test_control_record_parse() -> None: record = ControlRecord.parse(b"\x00\x01\x00\x01") assert record.version == 1 assert record.type_ == 1 @@ -28,7 +28,7 @@ def test_control_record_parse(): assert record.type_ == 0 -def test_control_record_other(): +def test_control_record_other() -> None: record = ControlRecord.parse(b"\x00\x00\x00\x01") assert record != 1 assert record != object() diff --git a/tests/record/test_default_records.py b/tests/record/test_default_records.py index 0e0dfda0..74d893d0 100644 --- a/tests/record/test_default_records.py +++ b/tests/record/test_default_records.py @@ -1,6 +1,8 @@ +from typing import List, Optional, Tuple from unittest import mock import pytest +from typing_extensions import Literal import aiokafka.codec from aiokafka.errors import UnsupportedCodecError @@ -9,6 +11,8 @@ DefaultRecordBatchBuilder, ) +HeadersT = List[Tuple[str, Optional[bytes]]] + @pytest.mark.parametrize( "compression_type,crc", @@ -24,7 +28,9 @@ pytest.param(DefaultRecordBatch.CODEC_ZSTD, 1714138923, 
id="zstd"), ], ) -def test_read_write_serde_v2(compression_type, crc): +def test_read_write_serde_v2( + compression_type: Literal[0x00, 0x01, 0x02, 0x03, 0x04], crc: int +) -> None: builder = DefaultRecordBatchBuilder( magic=2, compression_type=compression_type, @@ -34,7 +40,7 @@ def test_read_write_serde_v2(compression_type, crc): base_sequence=9999, batch_size=999999, ) - headers = [("header1", b"aaa"), ("header2", b"bbb")] + headers: HeadersT = [("header1", b"aaa"), ("header2", b"bbb")] for offset in range(10): builder.append( offset, @@ -68,10 +74,10 @@ def test_read_write_serde_v2(compression_type, crc): assert msg.headers == headers -def test_written_bytes_equals_size_in_bytes_v2(): +def test_written_bytes_equals_size_in_bytes_v2() -> None: key = b"test" value = b"Super" - headers = [("header1", b"aaa"), ("header2", b"bbb"), ("xx", None)] + headers: HeadersT = [("header1", b"aaa"), ("header2", b"bbb"), ("xx", None)] builder = DefaultRecordBatchBuilder( magic=2, compression_type=0, @@ -90,13 +96,14 @@ def test_written_bytes_equals_size_in_bytes_v2(): meta = builder.append(0, timestamp=9999999, key=key, value=value, headers=headers) assert builder.size() - pos == size_in_bytes + assert meta is not None assert meta.size == size_in_bytes -def test_estimate_size_in_bytes_bigger_than_batch_v2(): +def test_estimate_size_in_bytes_bigger_than_batch_v2() -> None: key = b"Super Key" value = b"1" * 100 - headers = [("header1", b"aaa"), ("header2", b"bbb")] + headers: HeadersT = [("header1", b"aaa"), ("header2", b"bbb")] estimate_size = DefaultRecordBatchBuilder.estimate_size_in_bytes( key, value, headers ) @@ -115,7 +122,7 @@ def test_estimate_size_in_bytes_bigger_than_batch_v2(): assert len(buf) <= estimate_size, "Estimate should always be upper bound" -def test_default_batch_builder_validates_arguments(): +def test_default_batch_builder_validates_arguments() -> None: builder = DefaultRecordBatchBuilder( magic=2, compression_type=0, @@ -128,22 +135,30 @@ def test_default_batch_builder_validates_arguments(): # Key should not be str with pytest.raises(TypeError): - builder.append(0, timestamp=9999999, key="some string", value=None, headers=[]) + builder.append(0, timestamp=9999999, key="some string", value=None, headers=[]) # type: ignore[arg-type] # Value should not be str with pytest.raises(TypeError): - builder.append(0, timestamp=9999999, key=None, value="some string", headers=[]) + builder.append(0, timestamp=9999999, key=None, value="some string", headers=[]) # type: ignore[arg-type] # Timestamp should be of proper type with pytest.raises(TypeError): builder.append( - 0, timestamp="1243812793", key=None, value=b"some string", headers=[] + 0, + timestamp="1243812793", # type: ignore[arg-type] + key=None, + value=b"some string", + headers=[], ) # Offset of invalid type with pytest.raises(TypeError): builder.append( - "0", timestamp=9999999, key=None, value=b"some string", headers=[] + "0", # type: ignore[arg-type] + timestamp=9999999, + key=None, + value=b"some string", + headers=[], ) # Ok to pass value as None @@ -159,7 +174,7 @@ def test_default_batch_builder_validates_arguments(): assert len(builder.build()) == 104 -def test_default_correct_metadata_response(): +def test_default_correct_metadata_response() -> None: builder = DefaultRecordBatchBuilder( magic=2, compression_type=0, @@ -171,6 +186,7 @@ def test_default_correct_metadata_response(): ) meta = builder.append(0, timestamp=9999999, key=b"test", value=b"Super", headers=[]) + assert meta is not None assert meta.offset == 0 assert 
meta.timestamp == 9999999 assert meta.crc is None @@ -180,7 +196,7 @@ def test_default_correct_metadata_response(): ) -def test_default_batch_size_limit(): +def test_default_batch_size_limit() -> None: # First message can be added even if it's too big builder = DefaultRecordBatchBuilder( magic=2, @@ -193,6 +209,7 @@ def test_default_batch_size_limit(): ) meta = builder.append(0, timestamp=None, key=None, value=b"M" * 2000, headers=[]) + assert meta is not None assert meta.size > 0 assert meta.crc is None assert meta.offset == 0 @@ -226,7 +243,9 @@ def test_default_batch_size_limit(): (DefaultRecordBatch.CODEC_ZSTD, "zstd", "has_zstd"), ], ) -def test_unavailable_codec(compression_type, name, checker_name): +def test_unavailable_codec( + compression_type: Literal[0x01, 0x02, 0x03, 0x04], name: str, checker_name: str +) -> None: builder = DefaultRecordBatchBuilder( magic=2, compression_type=compression_type, @@ -261,11 +280,11 @@ def test_unavailable_codec(compression_type, name, checker_name): list(batch) -def test_unsupported_yet_codec(): +def test_unsupported_yet_codec() -> None: compression_type = DefaultRecordBatch.CODEC_MASK # It doesn't exist builder = DefaultRecordBatchBuilder( magic=2, - compression_type=compression_type, + compression_type=compression_type, # type: ignore[arg-type] is_transactional=0, producer_id=-1, producer_epoch=-1, @@ -277,7 +296,7 @@ def test_unsupported_yet_codec(): builder.build() -def test_build_without_append(): +def test_build_without_append() -> None: builder = DefaultRecordBatchBuilder( magic=2, compression_type=0, @@ -294,7 +313,7 @@ def test_build_without_append(): assert not msgs -def test_set_producer_state(): +def test_set_producer_state() -> None: builder = DefaultRecordBatchBuilder( magic=2, compression_type=0, diff --git a/tests/record/test_legacy.py b/tests/record/test_legacy.py index 90d88b02..9faa3bc0 100644 --- a/tests/record/test_legacy.py +++ b/tests/record/test_legacy.py @@ -1,7 +1,9 @@ import struct +from typing import Optional, Tuple from unittest import mock import pytest +from typing_extensions import Literal import aiokafka.codec from aiokafka.errors import CorruptRecordException, UnsupportedCodecError @@ -19,7 +21,12 @@ (b"test", b"", [4230475139, 3614888862]), ], ) -def test_read_write_serde_v0_v1_no_compression(magic, key, value, checksum): +def test_read_write_serde_v0_v1_no_compression( + magic: Literal[0, 1], + key: Optional[bytes], + value: Optional[bytes], + checksum: Tuple[int, int], +) -> None: builder = LegacyRecordBatchBuilder( magic=magic, compression_type=0, batch_size=1024 * 1024 ) @@ -57,7 +64,9 @@ def test_read_write_serde_v0_v1_no_compression(magic, key, value, checksum): (LegacyRecordBatch.CODEC_LZ4, 1), ], ) -def test_read_write_serde_v0_v1_with_compression(compression_type, magic): +def test_read_write_serde_v0_v1_with_compression( + compression_type: Literal[0x01, 0x02, 0x03], magic: Literal[0, 1] +) -> None: builder = LegacyRecordBatchBuilder( magic=magic, compression_type=compression_type, batch_size=1024 * 1024 ) @@ -88,7 +97,7 @@ def test_read_write_serde_v0_v1_with_compression(compression_type, magic): @pytest.mark.parametrize("magic", [0, 1]) -def test_written_bytes_equals_size_in_bytes(magic): +def test_written_bytes_equals_size_in_bytes(magic: Literal[0, 1]) -> None: key = b"test" value = b"Super" builder = LegacyRecordBatchBuilder( @@ -104,27 +113,47 @@ def test_written_bytes_equals_size_in_bytes(magic): @pytest.mark.parametrize("magic", [0, 1]) -def test_legacy_batch_builder_validates_arguments(magic): 
+def test_legacy_batch_builder_validates_arguments(magic: Literal[0, 1]) -> None: builder = LegacyRecordBatchBuilder( magic=magic, compression_type=0, batch_size=1024 * 1024 ) # Key should not be str with pytest.raises(TypeError): - builder.append(0, timestamp=9999999, key="some string", value=None) + builder.append( + 0, + timestamp=9999999, + key="some string", # type: ignore[arg-type] + value=None, + ) # Value should not be str with pytest.raises(TypeError): - builder.append(0, timestamp=9999999, key=None, value="some string") + builder.append( + 0, + timestamp=9999999, + key=None, + value="some string", # type: ignore[arg-type] + ) # Timestamp should be of proper type (timestamp is ignored for magic == 0) if magic != 0: with pytest.raises(TypeError): - builder.append(0, timestamp="1243812793", key=None, value=b"some string") + builder.append( + 0, + timestamp="1243812793", # type: ignore[arg-type] + key=None, + value=b"some string", + ) # Offset of invalid type with pytest.raises(TypeError): - builder.append("0", timestamp=9999999, key=None, value=b"some string") + builder.append( + "0", # type: ignore[arg-type] + timestamp=9999999, + key=None, + value=b"some string", + ) # Unknown struct errors are passed through. These are theoretical and # indicate a bug in the implementation. The C implementation locates @@ -152,12 +181,13 @@ def test_legacy_batch_builder_validates_arguments(magic): @pytest.mark.parametrize("magic", [0, 1]) -def test_legacy_correct_metadata_response(magic): +def test_legacy_correct_metadata_response(magic: Literal[0, 1]) -> None: builder = LegacyRecordBatchBuilder( magic=magic, compression_type=0, batch_size=1024 * 1024 ) meta = builder.append(0, timestamp=9999999, key=b"test", value=b"Super") + assert meta is not None assert meta.offset == 0 assert meta.timestamp == (9999999 if magic else -1) assert meta.crc == (-2095076219 if magic else 278251978) & 0xFFFFFFFF @@ -168,10 +198,11 @@ def test_legacy_correct_metadata_response(magic): @pytest.mark.parametrize("magic", [0, 1]) -def test_legacy_batch_size_limit(magic): +def test_legacy_batch_size_limit(magic: Literal[0, 1]) -> None: # First message can be added even if it's too big builder = LegacyRecordBatchBuilder(magic=magic, compression_type=0, batch_size=1024) meta = builder.append(0, timestamp=None, key=None, value=b"M" * 2000) + assert meta is not None assert meta.size > 0 assert meta.crc is not None assert meta.offset == 0 @@ -195,7 +226,9 @@ def test_legacy_batch_size_limit(magic): (LegacyRecordBatch.CODEC_SNAPPY, "snappy", "has_snappy"), ], ) -def test_unavailable_codec(compression_type, name, checker_name): +def test_unavailable_codec( + compression_type: Literal[0x01, 0x02], name: str, checker_name: str +) -> None: builder = LegacyRecordBatchBuilder( magic=0, compression_type=compression_type, batch_size=1024 ) @@ -219,10 +252,12 @@ def test_unavailable_codec(compression_type, name, checker_name): list(batch) -def test_unsupported_yet_codec(): +def test_unsupported_yet_codec() -> None: compression_type = LegacyRecordBatch.CODEC_MASK # It doesn't exist builder = LegacyRecordBatchBuilder( - magic=0, compression_type=compression_type, batch_size=1024 + magic=0, + compression_type=compression_type, # type: ignore[arg-type] + batch_size=1024, ) with pytest.raises(UnsupportedCodecError): builder.append(0, timestamp=None, key=None, value=b"M") @@ -234,7 +269,7 @@ def test_unsupported_yet_codec(): TIMESTAMP_TYPE_MASK = 0x08 -def _make_compressed_batch(magic): +def _make_compressed_batch(magic: Literal[0, 1]) -> 
bytearray: builder = LegacyRecordBatchBuilder( magic=magic, compression_type=LegacyRecordBatch.CODEC_GZIP, @@ -245,7 +280,7 @@ def _make_compressed_batch(magic): return builder.build() -def test_read_log_append_time_v1(): +def test_read_log_append_time_v1() -> None: buffer = _make_compressed_batch(1) # As Builder does not support creating data with `timestamp_type==1` we @@ -265,7 +300,7 @@ def test_read_log_append_time_v1(): @pytest.mark.parametrize("magic", [0, 1]) -def test_reader_corrupt_record_v0_v1(magic): +def test_reader_corrupt_record_v0_v1(magic: Literal[0, 1]) -> None: buffer = _make_compressed_batch(magic) len_offset = 8 @@ -301,7 +336,7 @@ def test_reader_corrupt_record_v0_v1(magic): list(batch) -def test_record_overhead(): +def test_record_overhead() -> None: known = { 0: 14, 1: 22, diff --git a/tests/record/test_records.py b/tests/record/test_records.py index 756c43cf..3323d159 100644 --- a/tests/record/test_records.py +++ b/tests/record/test_records.py @@ -2,6 +2,7 @@ from aiokafka.errors import CorruptRecordException from aiokafka.record import MemoryRecords +from aiokafka.record._protocols import DefaultRecordProtocol, LegacyRecordProtocol # This is real live data from Kafka 11 broker record_batch_data_v2 = [ @@ -56,7 +57,7 @@ ] -def test_memory_records_v2(): +def test_memory_records_v2() -> None: data_bytes = b"".join(record_batch_data_v2) + b"\x00" * 4 records = MemoryRecords(data_bytes) @@ -64,8 +65,10 @@ def test_memory_records_v2(): assert records.has_next() is True batch = records.next_batch() - recs = list(batch) + assert batch is not None + recs = tuple(batch) assert len(recs) == 1 + assert isinstance(recs[0], DefaultRecordProtocol) assert recs[0].value == b"123" assert recs[0].key is None assert recs[0].timestamp == 1503229838908 @@ -81,7 +84,7 @@ def test_memory_records_v2(): assert records.next_batch() is None -def test_memory_records_v1(): +def test_memory_records_v1() -> None: data_bytes = b"".join(record_batch_data_v1) + b"\x00" * 4 records = MemoryRecords(data_bytes) @@ -89,8 +92,10 @@ def test_memory_records_v1(): assert records.has_next() is True batch = records.next_batch() - recs = list(batch) + assert batch is not None + recs = tuple(batch) assert len(recs) == 1 + assert isinstance(recs[0], LegacyRecordProtocol) assert recs[0].value == b"123" assert recs[0].key is None assert recs[0].timestamp == 1503648000942 @@ -106,7 +111,7 @@ def test_memory_records_v1(): assert records.next_batch() is None -def test_memory_records_v0(): +def test_memory_records_v0() -> None: data_bytes = b"".join(record_batch_data_v0) records = MemoryRecords(data_bytes + b"\x00" * 4) @@ -116,8 +121,10 @@ def test_memory_records_v0(): assert records.has_next() is True batch = records.next_batch() - recs = list(batch) + assert batch is not None + recs = tuple(batch) assert len(recs) == 1 + assert isinstance(recs[0], LegacyRecordProtocol) assert recs[0].value == b"123" assert recs[0].key is None assert recs[0].timestamp is None @@ -133,7 +140,7 @@ def test_memory_records_v0(): assert records.next_batch() is None -def test_memory_records_corrupt(): +def test_memory_records_corrupt() -> None: records = MemoryRecords(b"") assert records.size_in_bytes() == 0 assert records.has_next() is False diff --git a/tests/record/test_util.py b/tests/record/test_util.py index d630c5f8..139043fc 100644 --- a/tests/record/test_util.py +++ b/tests/record/test_util.py @@ -1,10 +1,11 @@ import struct +from typing import List, Tuple import pytest from aiokafka.record import util -varint_data = [ 
+varint_data: List[Tuple[bytes, int]] = [ (b"\x00", 0), (b"\x01", -1), (b"\x02", 1), @@ -48,14 +49,14 @@ @pytest.mark.parametrize("encoded, decoded", varint_data) -def test_encode_varint(encoded, decoded): +def test_encode_varint(encoded: bytes, decoded: int) -> None: res = bytearray() util.encode_varint(decoded, res.append) assert res == encoded @pytest.mark.parametrize("encoded, decoded", varint_data) -def test_decode_varint(encoded, decoded): +def test_decode_varint(encoded: bytes, decoded: int) -> None: # We add a bit of bytes around just to check position is calculated # correctly value, pos = util.decode_varint(bytearray(b"\x01\xf0" + encoded + b"\xff\x01"), 2) @@ -64,12 +65,12 @@ def test_decode_varint(encoded, decoded): @pytest.mark.parametrize("encoded, decoded", varint_data) -def test_size_of_varint(encoded, decoded): +def test_size_of_varint(encoded: bytes, decoded: int) -> None: assert util.size_of_varint(decoded) == len(encoded) -def test_crc32c(): - def make_crc(data): +def test_crc32c() -> None: + def make_crc(data: bytes) -> bytes: crc = util.calc_crc32c(data) return struct.pack(">I", crc)
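
For quick reference, a minimal usage sketch (not part of the diff) of the typed helpers exercised by tests/record/test_util.py above. It is a sketch only: the zigzag pair comes from the varint_data table in that test, and the CRC-32C value for b"123456789" is the standard Castagnoli check vector rather than anything asserted in this change.

    import struct

    from aiokafka.record import util

    # Zigzag varint round-trip: 1 encodes to b"\x02" (see varint_data above).
    buf = bytearray()
    util.encode_varint(1, buf.append)
    assert bytes(buf) == b"\x02"
    assert util.size_of_varint(1) == len(buf)

    # decode_varint returns (value, next_position).
    value, pos = util.decode_varint(bytearray(b"\x02"), 0)
    assert (value, pos) == (1, 1)

    # CRC-32C (Castagnoli) over raw bytes, packed big-endian as in the test.
    assert util.calc_crc32c(b"123456789") == 0xE3069283
    checksum_bytes = struct.pack(">I", util.calc_crc32c(b"123456789"))

With the annotations added in this diff, mypy resolves calc_crc32c, decode_varint, size_of_varint, and encode_varint to the Callable declarations in aiokafka/record/util.py regardless of whether the Cython or pure-Python implementation is selected at import time.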