Type-annotation patch for aiokafka/protocol, reformatted one diff line per line.
Review fixes applied to the added ("+") lines only:
  * types.py: `_typeshed` is imported under `if TYPE_CHECKING:` (it is a stub-only
    package with no runtime module); new-side hunk offsets shift by one line.
  * types.py: CompactString.decode returns Optional[str]; TaggedFields.decode takes
    a BytesIO; Array.decode/repr use list[Any] for the decoded items.
  * abstract.py: classmethods keep `cls` as their first parameter.
The post-image blob ids on the "index" lines predate these fixes.

diff --git a/aiokafka/protocol/abstract.py b/aiokafka/protocol/abstract.py
index 117d058e..953a76c3 100644
--- a/aiokafka/protocol/abstract.py
+++ b/aiokafka/protocol/abstract.py
@@ -1,15 +1,22 @@
 import abc
+import io
+from typing import Generic, Optional, TypeVar
 
+from typing_extensions import TypeAlias
 
-class AbstractType(metaclass=abc.ABCMeta):
+T = TypeVar("T")
+BytesIO: TypeAlias = io.BytesIO
+
+
+class AbstractType(Generic[T], metaclass=abc.ABCMeta):
     @classmethod
     @abc.abstractmethod
-    def encode(cls, value): ...
+    def encode(cls, value: Optional[T]) -> bytes: ...
 
     @classmethod
     @abc.abstractmethod
-    def decode(cls, data): ...
+    def decode(cls, data: BytesIO) -> Optional[T]: ...
 
     @classmethod
-    def repr(cls, value):
+    def repr(cls, value: T) -> str:
         return repr(value)
diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py
index 31993fe6..d2e038e4 100644
--- a/aiokafka/protocol/message.py
+++ b/aiokafka/protocol/message.py
@@ -1,6 +1,7 @@
 import io
 import time
 from binascii import crc32
+from typing import Optional
 
 from aiokafka.codec import (
     gzip_decode,
@@ -47,7 +48,15 @@ class Message(Struct):
         22  # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2)
     )
 
-    def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None):
+    def __init__(
+        self,
+        value: Optional[bytes],
+        key: Optional[bytes] = None,
+        magic: int = 0,
+        attributes: int = 0,
+        crc: int = 0,
+        timestamp: Optional[int] = None,
+    ):
         assert value is None or isinstance(value, bytes), "value must be bytes"
         assert key is None or isinstance(key, bytes), "key must be bytes"
         assert magic > 0 or timestamp is None, "timestamp not supported in v0"
@@ -64,7 +73,7 @@ def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None
         self.value = value
 
     @property
-    def timestamp_type(self):
+    def timestamp_type(self) -> Optional[int]:
         """0 for CreateTime; 1 for LogAppendTime; None if unsupported.
 
         Value is determined by broker; produced messages should always set to 0
@@ -77,7 +86,7 @@ def timestamp_type(self):
         else:
             return 0
 
-    def encode(self, recalc_crc=True):
+    def encode(self, recalc_crc: bool = True):
         version = self.magic
         if version == 1:
             fields = (
@@ -125,7 +134,7 @@ def decode(cls, data):
         msg._validated_crc = _validated_crc
         return msg
 
-    def validate_crc(self):
+    def validate_crc(self) -> bool:
         if self._validated_crc is None:
             raw_msg = self.encode(recalc_crc=False)
             self._validated_crc = crc32(raw_msg[4:])
@@ -133,7 +142,7 @@ def validate_crc(self):
             return True
         return False
 
-    def is_compressed(self):
+    def is_compressed(self) -> bool:
         return self.attributes & self.CODEC_MASK != 0
 
     def decompress(self):
diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py
index 7eadf7fb..03aaf554 100644
--- a/aiokafka/protocol/types.py
+++ b/aiokafka/protocol/types.py
@@ -1,10 +1,18 @@
+from __future__ import annotations
+
 import struct
 from struct import error
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
+
+from .abstract import AbstractType, BytesIO
+
+if TYPE_CHECKING:
+    from _typeshed import ReadableBuffer
 
-from .abstract import AbstractType
+T = TypeVar("T")
 
 
-def _pack(f, value):
+def _pack(f: Callable[[T], bytes], value: T) -> bytes:
     try:
         return f(value)
     except error as e:
@@ -14,7 +22,7 @@ def _pack(f, value):
         ) from e
 
 
-def _unpack(f, data):
+def _unpack(f: Callable[[ReadableBuffer], tuple[T, ...]], data: ReadableBuffer) -> T:
     try:
         (value,) = f(data)
     except error as e:
@@ -26,95 +34,95 @@ def _unpack(f, data):
     return value
 
 
-class Int8(AbstractType):
+class Int8(AbstractType[int]):
     _pack = struct.Struct(">b").pack
     _unpack = struct.Struct(">b").unpack
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: int) -> bytes:
         return _pack(cls._pack, value)
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> int:
         return _unpack(cls._unpack, data.read(1))
 
 
-class Int16(AbstractType):
+class Int16(AbstractType[int]):
     _pack = struct.Struct(">h").pack
     _unpack = struct.Struct(">h").unpack
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: int) -> bytes:
         return _pack(cls._pack, value)
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> int:
         return _unpack(cls._unpack, data.read(2))
 
 
-class Int32(AbstractType):
+class Int32(AbstractType[int]):
     _pack = struct.Struct(">i").pack
     _unpack = struct.Struct(">i").unpack
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: int) -> bytes:
         return _pack(cls._pack, value)
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> int:
         return _unpack(cls._unpack, data.read(4))
 
 
-class UInt32(AbstractType):
+class UInt32(AbstractType[int]):
     _pack = struct.Struct(">I").pack
     _unpack = struct.Struct(">I").unpack
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: int) -> bytes:
         return _pack(cls._pack, value)
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> int:
         return _unpack(cls._unpack, data.read(4))
 
 
-class Int64(AbstractType):
+class Int64(AbstractType[int]):
     _pack = struct.Struct(">q").pack
     _unpack = struct.Struct(">q").unpack
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: int) -> bytes:
         return _pack(cls._pack, value)
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> int:
         return _unpack(cls._unpack, data.read(8))
 
 
-class Float64(AbstractType):
+class Float64(AbstractType[float]):
     _pack = struct.Struct(">d").pack
     _unpack = struct.Struct(">d").unpack
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: float) -> bytes:
         return _pack(cls._pack, value)
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> float:
         return _unpack(cls._unpack, data.read(8))
 
 
-class String(AbstractType):
-    def __init__(self, encoding="utf-8"):
+class String(AbstractType[str]):
+    def __init__(self, encoding="utf-8") -> None:
         self.encoding = encoding
 
-    def encode(self, value):
+    def encode(self, value: Optional[str]) -> bytes:
         if value is None:
             return Int16.encode(-1)
         value = str(value).encode(self.encoding)
         return Int16.encode(len(value)) + value
 
-    def decode(self, data):
+    def decode(self, data: BytesIO) -> Optional[str]:
         length = Int16.decode(data)
         if length < 0:
             return None
@@ -124,16 +132,16 @@ def decode(self, data):
         return value.decode(self.encoding)
 
 
-class Bytes(AbstractType):
+class Bytes(AbstractType[bytes]):
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: Optional[bytes]) -> bytes:
         if value is None:
             return Int32.encode(-1)
         else:
             return Int32.encode(len(value)) + value
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> Optional[bytes]:
         length = Int32.decode(data)
         if length < 0:
             return None
@@ -143,33 +151,34 @@ def decode(cls, data):
         return value
 
     @classmethod
-    def repr(cls, value):
+    def repr(cls, value: Optional[bytes]) -> str:
         return repr(
             value[:100] + b"..." if value is not None and len(value) > 100 else value
         )
 
 
-class Boolean(AbstractType):
+class Boolean(AbstractType[bool]):
     _pack = struct.Struct(">?").pack
     _unpack = struct.Struct(">?").unpack
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: bool) -> bytes:
         return _pack(cls._pack, value)
 
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> bool:
         return _unpack(cls._unpack, data.read(1))
 
 
 class Schema(AbstractType):
+
     def __init__(self, *fields):
         if fields:
             self.names, self.fields = zip(*fields)
         else:
             self.names, self.fields = (), ()
 
-    def encode(self, item):
+    def encode(self, item) -> bytes:
         if len(item) != len(self.fields):
             raise ValueError("Item field count does not match Schema")
         return b"".join(field.encode(item[i]) for i, field in enumerate(self.fields))
@@ -177,10 +186,10 @@ def encode(self, item):
     def decode(self, data):
         return tuple(field.decode(data) for field in self.fields)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.fields)
 
-    def repr(self, value):
+    def repr(self, value) -> str:
         key_vals = []
         try:
             for i in range(len(self)):
@@ -195,6 +204,7 @@ def repr(self, value):
 
 
 class Array(AbstractType):
+
     def __init__(self, *array_of):
         if len(array_of) > 1:
             self.array_of = Schema(*array_of)
@@ -206,7 +216,7 @@ def __init__(self, *array_of):
         else:
             raise ValueError("Array instantiated with no array_of type")
 
-    def encode(self, items):
+    def encode(self, items) -> bytes:
         if items is None:
             return Int32.encode(-1)
         encoded_items = (self.array_of.encode(item) for item in items)
@@ -214,13 +224,13 @@ def encode(self, items):
             (Int32.encode(len(items)), *encoded_items),
         )
 
-    def decode(self, data):
+    def decode(self, data: BytesIO) -> Optional[list[Any]]:
         length = Int32.decode(data)
         if length == -1:
             return None
         return [self.array_of.decode(data) for _ in range(length)]
 
-    def repr(self, list_of_items):
+    def repr(self, list_of_items: Optional[list[Any]]) -> str:
         if list_of_items is None:
             return "NULL"
         return "[" + ", ".join(self.array_of.repr(item) for item in list_of_items) + "]"
@@ -242,7 +252,7 @@ def decode(cls, data):
         return value
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value) -> bytes:
         value &= 0xFFFFFFFF
         ret = b""
         while (value & 0xFFFFFF80) != 0:
@@ -260,7 +270,7 @@ def decode(cls, data):
         return (value >> 1) ^ -(value & 1)
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value) -> bytes:
         # bring it in line with the java binary repr
         value &= 0xFFFFFFFF
         return UnsignedVarInt32.encode((value << 1) ^ (value >> 31))
@@ -282,7 +292,7 @@ def decode(cls, data):
         return (value >> 1) ^ -(value & 1)
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value) -> bytes:
         # bring it in line with the java binary repr
         value &= 0xFFFFFFFFFFFFFFFF
         v = (value << 1) ^ (value >> 63)
@@ -296,7 +306,7 @@ def encode(cls, value):
 
 
 class CompactString(String):
-    def decode(self, data):
+    def decode(self, data: BytesIO) -> Optional[str]:
         length = UnsignedVarInt32.decode(data) - 1
         if length < 0:
             return None
@@ -305,7 +315,7 @@ def decode(self, data):
             raise ValueError("Buffer underrun decoding string")
         return value.decode(self.encoding)
 
-    def encode(self, value):
+    def encode(self, value: Optional[str]) -> bytes:
         if value is None:
             return UnsignedVarInt32.encode(0)
         value = str(value).encode(self.encoding)
@@ -314,7 +324,7 @@ def encode(self, value):
 
 class TaggedFields(AbstractType):
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO):
         num_fields = UnsignedVarInt32.decode(data)
         ret = {}
         if not num_fields:
@@ -331,7 +341,7 @@ def decode(cls, data):
         return ret
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value) -> bytes:
         ret = UnsignedVarInt32.encode(len(value))
         for k, v in value.items():
             # do we allow for other data types ?? It could get complicated really fast
@@ -344,7 +354,7 @@ def encode(cls, value):
 
 class CompactBytes(AbstractType):
     @classmethod
-    def decode(cls, data):
+    def decode(cls, data: BytesIO) -> Optional[bytes]:
         length = UnsignedVarInt32.decode(data) - 1
         if length < 0:
             return None
@@ -354,7 +364,7 @@ def decode(cls, data):
         return value
 
     @classmethod
-    def encode(cls, value):
+    def encode(cls, value: Optional[bytes]) -> bytes:
         if value is None:
             return UnsignedVarInt32.encode(0)
         else:
@@ -362,7 +372,7 @@ def encode(cls, value):
 
 
 class CompactArray(Array):
-    def encode(self, items):
+    def encode(self, items) -> bytes:
         if items is None:
             return UnsignedVarInt32.encode(0)
         encoded_items = (self.array_of.encode(item) for item in items)
@@ -370,7 +380,7 @@ def encode(self, items):
             (UnsignedVarInt32.encode(len(items) + 1), *encoded_items),
         )
 
-    def decode(self, data):
+    def decode(self, data: BytesIO):
         length = UnsignedVarInt32.decode(data) - 1
         if length == -1:
             return None