From 4a0f6a8fcd6f73c9373873447c3ae9e70e2ad6d5 Mon Sep 17 00:00:00 2001
From: Anton Agestam
Date: Wed, 3 Jul 2024 00:11:46 +0200
Subject: [PATCH] Add type hints to message_accumulator

---
 Makefile                                      |   1 +
 aiokafka/producer/message_accumulator.py      | 270 ++++++++++++------
 aiokafka/producer/producer.py                 |  19 +-
 aiokafka/producer/sender.py                   |   2 +-
 aiokafka/protocol/types.py                    |   4 +
 aiokafka/record/_crecords/default_records.pyi |   4 +-
 aiokafka/record/_protocols.py                 |   5 +-
 aiokafka/record/default_records.py            |  15 +-
 8 files changed, 227 insertions(+), 93 deletions(-)

diff --git a/Makefile b/Makefile
index f0214347..26c2a6ee 100644
--- a/Makefile
+++ b/Makefile
@@ -12,6 +12,7 @@ FORMATTED_AREAS=\
 	aiokafka/helpers.py \
 	aiokafka/structs.py \
 	aiokafka/util.py \
+	aiokafka/producer/message_accumulator.py \
 	aiokafka/protocol/ \
 	aiokafka/record/ \
 	tests/test_codec.py \
diff --git a/aiokafka/producer/message_accumulator.py b/aiokafka/producer/message_accumulator.py
index 2b2e3785..f894fbfc 100644
--- a/aiokafka/producer/message_accumulator.py
+++ b/aiokafka/producer/message_accumulator.py
@@ -1,40 +1,92 @@
+from __future__ import annotations
+
 import asyncio
 import collections
 import copy
 import time
 from collections.abc import Sequence
+from typing import (
+    TYPE_CHECKING,
+    AbstractSet,
+    Callable,
+    DefaultDict,
+    Deque,
+    Dict,
+    Generic,
+    List,
+    Protocol,
+    Tuple,
+    TypeVar,
+    Union,
+)
+
+from typing_extensions import Literal, TypeAlias
 
+from aiokafka.cluster import ClusterMetadata
 from aiokafka.errors import (
+    BrokerResponseError,
     KafkaTimeoutError,
     LeaderNotAvailableError,
     NotLeaderForPartitionError,
     ProducerClosed,
 )
+from aiokafka.producer.transaction_manager import TransactionManager
+from aiokafka.protocol.types import BrokerId
+from aiokafka.record._protocols import (
+    DefaultRecordBatchBuilderProtocol,
+    DefaultRecordMetadataProtocol,
+    LegacyRecordBatchBuilderProtocol,
+    LegacyRecordMetadataProtocol,
+)
 from aiokafka.record.default_records import DefaultRecordBatchBuilder
 from aiokafka.record.legacy_records import LegacyRecordBatchBuilder
-from aiokafka.structs import RecordMetadata
+from aiokafka.structs import RecordMetadata, TopicPartition
 from aiokafka.util import create_future, get_running_loop
 
+T = TypeVar("T")
+
+BytesSerializer: TypeAlias = Callable[
+    [Union[T, None]],
+    Union[bytes, None],
+]
+
+
+if TYPE_CHECKING:
+    KT = TypeVar("KT", default=bytes)
+    VT = TypeVar("VT", default=bytes)
+else:
+    KT = TypeVar("KT")
+    VT = TypeVar("VT")
+
 
-class BatchBuilder:
+_Metadata: TypeAlias = "DefaultRecordMetadataProtocol | LegacyRecordMetadataProtocol"
+
+
+class BatchBuilder(Generic[KT, VT]):
     def __init__(
         self,
-        magic,
-        batch_size,
-        compression_type,
+        magic: Literal[0, 1, 2],
+        batch_size: int,
+        compression_type: Literal[0, 1, 2, 3],
         *,
-        is_transactional,
-        key_serializer=None,
-        value_serializer=None,
-    ):
+        is_transactional: bool,
+        key_serializer: Union[BytesSerializer[KT], None] = None,
+        value_serializer: Union[BytesSerializer[VT], None] = None,
+    ) -> None:
+        self._builder: Union[
+            LegacyRecordBatchBuilderProtocol,
+            DefaultRecordBatchBuilderProtocol,
+        ]
         if magic < 2:
             assert not is_transactional
             self._builder = LegacyRecordBatchBuilder(
-                magic, compression_type, batch_size
+                magic,  # type: ignore[arg-type]
+                compression_type,
+                batch_size,
             )
         else:
             self._builder = DefaultRecordBatchBuilder(
-                magic,
+                magic,  # type: ignore[arg-type]
                 compression_type,
                 is_transactional=is_transactional,
                 producer_id=-1,
@@ -43,24 +95,38 @@ def __init__(
                 batch_size=batch_size,
             )
         self._relative_offset = 0
-        self._buffer = None
+        self._buffer: Union[bytearray, None] = None
         self._closed = False
         self._key_serializer = key_serializer
         self._value_serializer = value_serializer
 
-    def _serialize(self, key, value):
+    def _serialize(
+        self,
+        key: KT | None,
+        value: VT | None,
+    ) -> tuple[bytes | None, bytes | None]:
+        serialized_key: bytes | None
         if self._key_serializer is None:
-            serialized_key = key
+            serialized_key = key  # type: ignore[assignment]
         else:
            serialized_key = self._key_serializer(key)
+
+        serialized_value: bytes | None
         if self._value_serializer is None:
-            serialized_value = value
+            serialized_value = value  # type: ignore[assignment]
         else:
             serialized_value = self._value_serializer(value)
         return serialized_key, serialized_value
 
-    def append(self, *, timestamp, key, value, headers: Sequence = []):
+    def append(
+        self,
+        *,
+        timestamp: Union[int, None],
+        key: KT | None,
+        value: VT | None,
+        headers: Sequence[Tuple[str, bytes]] = [],
+    ) -> _Metadata | None:
         """Add a message to the batch.
 
         Arguments:
@@ -98,7 +164,7 @@ def append(self, *, timestamp, key, value, headers: Sequence = []):
         self._relative_offset += 1
         return metadata
 
-    def close(self):
+    def close(self) -> None:
         """Close the batch to further updates.
 
         Closing the batch before submitting to the producer ensures that no
@@ -114,33 +180,45 @@ def append(self, *, timestamp, key, value, headers: Sequence = []):
             return
         self._closed = True
 
-    def _set_producer_state(self, producer_id, producer_epoch, base_sequence):
+    def _set_producer_state(
+        self,
+        producer_id: int,
+        producer_epoch: int,
+        base_sequence: int,
+    ) -> None:
         assert type(self._builder) is DefaultRecordBatchBuilder
         self._builder.set_producer_state(producer_id, producer_epoch, base_sequence)
 
-    def _build(self):
+    def _build(self) -> bytearray:
         self.close()
         if self._buffer is None:
             self._buffer = self._builder.build()
             del self._builder  # We may only call self._builder.build() once!
         return self._buffer
 
-    def size(self):
+    def size(self) -> int:
         """Get the size of batch in bytes."""
         if self._buffer is not None:
             return len(self._buffer)
         else:
             return self._builder.size()
 
-    def record_count(self):
+    def record_count(self) -> int:
         """Get the number of records in the batch."""
         return self._relative_offset
 
 
+class _FutureCreator(Protocol):
+    def __call__(
+        self,
+        loop: Union[asyncio.AbstractEventLoop, None] = ...,
+    ) -> asyncio.Future[object]: ...
+
+
 class MessageBatch:
     """This class incapsulate operations with batch of produce messages"""
 
-    def __init__(self, tp, builder, ttl):
+    def __init__(self, tp: TopicPartition, builder: BatchBuilder, ttl: int) -> None:
         self._builder = builder
         self._tp = tp
         self._ttl = ttl
@@ -148,28 +226,28 @@ def __init__(self, tp, builder, ttl):
 
         # Waiters
         # Set when messages are delivered to Kafka based on ACK setting
-        self.future = create_future()
-        self._msg_futures = []
+        self.future: asyncio.Future[object] = create_future()
+        self._msg_futures: List[Tuple[asyncio.Future[object], _Metadata]] = []
         # Set when sender takes this batch
-        self._drain_waiter = create_future()
+        self._drain_waiter: asyncio.Future[object] = create_future()
         self._retry_count = 0
 
     @property
-    def tp(self):
+    def tp(self) -> TopicPartition:
         return self._tp
 
     @property
-    def record_count(self):
+    def record_count(self) -> int:
         return self._builder.record_count()
 
     def append(
         self,
-        key,
-        value,
-        timestamp_ms,
-        _create_future=create_future,
-        headers: Sequence = [],
-    ):
+        key: Union[bytes, None],
+        value: Union[bytes, None],
+        timestamp_ms: Union[int, None],
+        _create_future: _FutureCreator = create_future,
+        headers: Sequence[Tuple[str, bytes]] = [],
+    ) -> Union[asyncio.Future[object], None]:
         """Append message (key and value) to batch
 
         Returns:
@@ -189,11 +267,11 @@ def append(
 
     def done(
         self,
-        base_offset,
-        timestamp=None,
-        log_start_offset=None,
-        _record_metadata_class=RecordMetadata,
-    ):
+        base_offset: int,
+        timestamp: Union[int, None] = None,
+        log_start_offset: int | None = None,
+        _record_metadata_class: type[RecordMetadata] = RecordMetadata,
+    ) -> None:
         """Resolve all pending futures"""
         tp = self._tp
         topic = tp.topic
@@ -238,7 +316,7 @@ def done(
                 )
             )
 
-    def done_noack(self):
+    def done_noack(self) -> None:
         """Resolve all pending futures to None"""
         # Faster resolve for base_offset=None case.
         if not self.future.done():
@@ -248,7 +326,7 @@ def done_noack(self):
                 continue
             future.set_result(None)
 
-    def failure(self, exception):
+    def failure(self, exception: BaseException) -> None:
         if not self.future.done():
             self.future.set_exception(exception)
         for future, _ in self._msg_futures:
@@ -268,40 +346,45 @@ def failure(self, exception):
         if not self._drain_waiter.done():
             self._drain_waiter.set_exception(exception)
 
-    async def wait_drain(self, timeout=None):
+    async def wait_drain(self, timeout: Union[float, None] = None) -> None:
         """Wait until all message from this batch is processed"""
         waiter = self._drain_waiter
         await asyncio.wait([waiter], timeout=timeout)
         if waiter.done():
             waiter.result()  # Check for exception
 
-    def expired(self):
+    def expired(self) -> bool:
         """Check that batch is expired or not"""
         return (time.monotonic() - self._ctime) > self._ttl
 
-    def drain_ready(self):
+    def drain_ready(self) -> None:
         """Compress batch to be ready for send"""
         if not self._drain_waiter.done():
             self._drain_waiter.set_result(None)
         self._retry_count += 1
 
-    def reset_drain(self):
+    def reset_drain(self) -> None:
         """Reset drain waiter, until we will do another retry"""
         assert self._drain_waiter.done()
         self._drain_waiter = create_future()
 
-    def set_producer_state(self, producer_id, producer_epoch, base_sequence):
+    def set_producer_state(
+        self,
+        producer_id: int,
+        producer_epoch: int,
+        base_sequence: int,
+    ) -> None:
         assert not self._drain_waiter.done()
         self._builder._set_producer_state(producer_id, producer_epoch, base_sequence)
 
-    def get_data_buffer(self):
+    def get_data_buffer(self) -> bytearray:
         return self._builder._build()
 
-    def is_empty(self):
+    def is_empty(self) -> bool:
         return self._builder.record_count() == 0
 
     @property
-    def retry_count(self):
+    def retry_count(self) -> int:
         return self._retry_count
 
 
@@ -314,34 +397,36 @@ class MessageAccumulator:
 
     def __init__(
         self,
-        cluster,
-        batch_size,
-        compression_type,
-        batch_ttl,
+        cluster: ClusterMetadata,
+        batch_size: int,
+        compression_type: Literal[0, 1, 2, 3],
+        batch_ttl: int,
         *,
-        txn_manager=None,
-        loop=None,
-    ):
+        txn_manager: Union[TransactionManager, None] = None,
+        loop: Union[asyncio.AbstractEventLoop, None] = None,
+    ) -> None:
         if loop is None:
             loop = get_running_loop()
         self._loop = loop
-        self._batches = collections.defaultdict(collections.deque)
-        self._pending_batches = set()
+        self._batches: DefaultDict[TopicPartition, Deque[MessageBatch]] = (
+            collections.defaultdict(collections.deque)
+        )
+        self._pending_batches: set[MessageBatch] = set()
         self._cluster = cluster
         self._batch_size = batch_size
         self._compression_type = compression_type
         self._batch_ttl = batch_ttl
         self._wait_data_future = loop.create_future()
         self._closed = False
-        self._api_version = (0, 9)
+        self._api_version: Union[Tuple[int, int, int], Tuple[int, int]] = (0, 9)
         self._txn_manager = txn_manager
-        self._exception = None  # Critical exception
+        self._exception: BaseException | None = None  # Critical exception
 
-    def set_api_version(self, api_version):
+    def set_api_version(self, api_version: Tuple[int, int, int]) -> None:
         self._api_version = api_version
 
-    async def flush(self):
+    async def flush(self) -> None:
         waiters = [
             batch.future for batches in self._batches.values() for batch in batches
         ]
@@ -349,7 +434,7 @@ async def flush(self):
         if waiters:
             await asyncio.wait(waiters)
 
-    async def flush_for_commit(self):
+    async def flush_for_commit(self) -> None:
         waiters = []
         for batches in self._batches.values():
             for batch in batches:
@@ -364,7 +449,7 @@ async def flush_for_commit(self):
         if waiters:
             await asyncio.wait(waiters)
 
-    def fail_all(self, exception):
+    def fail_all(self, exception: BaseException) -> None:
         # Close all batches with this exception
         for batches in self._batches.values():
             for batch in batches:
@@ -373,19 +458,19 @@ def fail_all(self, exception):
                 batch.failure(exception)
         self._exception = exception
 
-    async def close(self):
+    async def close(self) -> None:
         self._closed = True
         await self.flush()
 
     async def add_message(
         self,
-        tp,
-        key,
-        value,
-        timeout,
-        timestamp_ms=None,
-        headers: Sequence = [],
-    ):
+        tp: TopicPartition,
+        key: Union[bytes, None],
+        value: Union[bytes, None],
+        timeout: float,
+        timestamp_ms: Union[int, None] = None,
+        headers: Sequence[Tuple[str, bytes]] = [],
+    ) -> asyncio.Future[object] | None:
         """Add message to batch by topic-partition
         If batch is already full this method waits (`timeout` seconds maximum)
         until batch is drained by send task
@@ -416,13 +501,13 @@ async def add_message(
         if timeout <= 0:
             raise KafkaTimeoutError()
 
-    def data_waiter(self):
+    def data_waiter(self) -> asyncio.Future[None]:
         """Return waiter future that will be resolved when accumulator contain
         some data for drain
         """
         return self._wait_data_future
 
-    def _pop_batch(self, tp):
+    def _pop_batch(self, tp: TopicPartition) -> MessageBatch:
         batch = self._batches[tp].popleft()
         not_retry = batch.retry_count == 0
         if self._txn_manager is not None and not_retry:
@@ -443,22 +528,36 @@ def _pop_batch(self, tp):
 
         if not_retry:
 
-            def cb(fut, batch=batch, self=self):
+            def cb(
+                fut: asyncio.Future[object],
+                batch: MessageBatch = batch,
+                self: MessageAccumulator = self,
+            ) -> None:
                 self._pending_batches.remove(batch)
 
             batch.future.add_done_callback(cb)
         return batch
 
-    def reenqueue(self, batch):
+    def reenqueue(self, batch: MessageBatch) -> None:
         tp = batch.tp
         self._batches[tp].appendleft(batch)
         self._pending_batches.remove(batch)
         batch.reset_drain()
 
-    def drain_by_nodes(self, ignore_nodes, muted_partitions=frozenset()):
+    def drain_by_nodes(
+        self,
+        ignore_nodes: AbstractSet[BrokerId],
+        muted_partitions: AbstractSet[TopicPartition] = frozenset(),
+    ) -> Tuple[
+        Dict[BrokerId, Dict[TopicPartition, MessageBatch]],
+        bool,
+    ]:
         """Group batches by leader to partition nodes."""
-        nodes = collections.defaultdict(dict)
+        nodes: DefaultDict[BrokerId, Dict[TopicPartition, MessageBatch]] = (
+            collections.defaultdict(dict)
+        )
         unknown_leaders_exist = False
+        err: BrokerResponseError
         for tp in list(self._batches.keys()):
             # Just ignoring by node is not enough, as leader can change during
             # the cycle
@@ -499,7 +598,13 @@ def drain_by_nodes(self, ignore_nodes, muted_partitions=frozenset()):
 
         return nodes, unknown_leaders_exist
 
-    def create_builder(self, key_serializer=None, value_serializer=None):
+    def create_builder(
+        self,
+        key_serializer: BytesSerializer[KT] | None = None,
+        value_serializer: BytesSerializer[VT] | None = None,
+    ) -> BatchBuilder[KT, VT]:
+        magic: Literal[0, 1, 2]
+
         if self._api_version >= (0, 11):
             magic = 2
         elif self._api_version >= (0, 10):
@@ -522,7 +627,7 @@ def create_builder(self, key_serializer=None, value_serializer=None):
             value_serializer=value_serializer,
         )
 
-    def _append_batch(self, builder, tp):
+    def _append_batch(self, builder: BatchBuilder, tp: TopicPartition) -> MessageBatch:
         # We must do this before actual add takes place to check for errors.
         if self._txn_manager is not None:
             self._txn_manager.maybe_add_partition_to_txn(tp)
@@ -533,7 +638,12 @@ def _append_batch(self, builder, tp):
             self._wait_data_future.set_result(None)
         return batch
 
-    async def add_batch(self, builder, tp, timeout):
+    async def add_batch(
+        self,
+        builder: BatchBuilder,
+        tp: TopicPartition,
+        timeout: Union[int, float],
+    ) -> asyncio.Future[object]:
         """Add BatchBuilder to queue by topic-partition.
 
         Arguments:
diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py
index 5606040b..9b66c4e8 100644
--- a/aiokafka/producer/producer.py
+++ b/aiokafka/producer/producer.py
@@ -3,6 +3,10 @@
 import sys
 import traceback
 import warnings
+from typing import Generic, Union
+
+# Import TypeVar from typing extensions to get support for defaults (PEP 696).
+from typing_extensions import TypeVar
 
 from aiokafka.client import AIOKafkaClient
 from aiokafka.codec import has_gzip, has_lz4, has_snappy, has_zstd
@@ -22,7 +26,7 @@
     get_running_loop,
 )
 
-from .message_accumulator import MessageAccumulator
+from .message_accumulator import BytesSerializer, MessageAccumulator
 from .sender import Sender
 from .transaction_manager import TransactionManager
 
@@ -33,8 +37,11 @@
 
 _DEFAULT_PARTITIONER = DefaultPartitioner()
 
 
+VT = TypeVar("VT", default=bytes)
+KT = TypeVar("KT", default=bytes)
+
-class AIOKafkaProducer:
+class AIOKafkaProducer(Generic[KT, VT]):
     """A Kafka client that publishes records to the Kafka cluster.
 
     The producer consists of a pool of buffer space that holds records that
@@ -204,8 +211,8 @@ def __init__(
         request_timeout_ms=40000,
         api_version="auto",
         acks=_missing,
-        key_serializer=None,
-        value_serializer=None,
+        key_serializer: Union[BytesSerializer[KT], None] = None,
+        value_serializer: Union[BytesSerializer[VT], None] = None,
         compression_type=None,
         max_batch_size=16384,
         partitioner=_DEFAULT_PARTITIONER,
@@ -441,8 +448,8 @@ def _partition(
     async def send(
         self,
         topic,
-        value=None,
-        key=None,
+        value: Union[KT, bytes, None] = None,
+        key: Union[KT, bytes, None] = None,
         partition=None,
         timestamp_ms=None,
         headers=None,
diff --git a/aiokafka/producer/sender.py b/aiokafka/producer/sender.py
index 519b5d34..05f06921 100644
--- a/aiokafka/producer/sender.py
+++ b/aiokafka/producer/sender.py
@@ -76,7 +76,7 @@ async def start(self):
         self._sender_task = create_task(self._sender_routine())
         self._sender_task.add_done_callback(self._fail_all)
 
-    def _fail_all(self, task):
+    def _fail_all(self, task: asyncio.Task) -> None:
         """Called when sender fails. Will fail all pending batches, as they
         will never be delivered as well as fail transaction
         """
diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py
index 944783c0..8b2199b9 100644
--- a/aiokafka/protocol/types.py
+++ b/aiokafka/protocol/types.py
@@ -6,6 +6,7 @@
     Callable,
     Dict,
     List,
+    NewType,
     Optional,
     Sequence,
     Tuple,
@@ -423,3 +424,6 @@ def decode(self, data: BytesIO) -> Optional[List[Union[Any, Tuple[Any, ...]]]]:
         if length == -1:
             return None
         return [self.array_of.decode(data) for _ in range(length)]
+
+
+BrokerId = NewType("BrokerId", int)
diff --git a/aiokafka/record/_crecords/default_records.pyi b/aiokafka/record/_crecords/default_records.pyi
index 0910f9de..243ceb34 100644
--- a/aiokafka/record/_crecords/default_records.pyi
+++ b/aiokafka/record/_crecords/default_records.pyi
@@ -1,4 +1,4 @@
-from typing import ClassVar, final
+from typing import ClassVar, Sequence, final
 
 from typing_extensions import Literal, Self
 
@@ -112,7 +112,7 @@ class DefaultRecordBatchBuilder(DefaultRecordBatchBuilderProtocol):
         timestamp: int | None,
         key: bytes | None,
         value: bytes | None,
-        headers: list[tuple[str, bytes | None]],
+        headers: Sequence[tuple[str, bytes | None]],
     ) -> DefaultRecordMetadata: ...
     def build(self) -> bytearray: ...
     def size(self) -> int: ...
diff --git a/aiokafka/record/_protocols.py b/aiokafka/record/_protocols.py
index 176932b1..0d71a02d 100644
--- a/aiokafka/record/_protocols.py
+++ b/aiokafka/record/_protocols.py
@@ -8,6 +8,7 @@
     List,
     Optional,
     Protocol,
+    Sequence,
     Tuple,
     Union,
     runtime_checkable,
@@ -30,7 +31,7 @@ class DefaultRecordBatchBuilderProtocol(Protocol):
     def __init__(
         self,
-        magic: int,
+        magic: Literal[2],
         compression_type: DefaultCompressionTypeT,
         is_transactional: int,
         producer_id: int,
@@ -44,7 +45,7 @@ def append(
         timestamp: Optional[int],
         key: Optional[bytes],
         value: Optional[bytes],
-        headers: List[Tuple[str, Optional[bytes]]],
+        headers: Sequence[Tuple[str, Optional[bytes]]],
     ) -> Optional[DefaultRecordMetadataProtocol]: ...
     def build(self) -> bytearray: ...
     def size(self) -> int: ...
diff --git a/aiokafka/record/default_records.py b/aiokafka/record/default_records.py
index 8b0a596d..4092fc00 100644
--- a/aiokafka/record/default_records.py
+++ b/aiokafka/record/default_records.py
@@ -57,7 +57,18 @@
 import struct
 import time
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Sized, Tuple, Type, Union, final
+from typing import (
+    Any,
+    Callable,
+    List,
+    Optional,
+    Sequence,
+    Sized,
+    Tuple,
+    Type,
+    Union,
+    final,
+)
 
 from typing_extensions import Self, TypeIs, assert_never
 
@@ -432,7 +443,7 @@ def append(
         timestamp: Optional[int],
         key: Optional[bytes],
         value: Optional[bytes],
-        headers: List[Tuple[str, Optional[bytes]]],
+        headers: Sequence[Tuple[str, Optional[bytes]]],
         # Cache for LOAD_FAST opcodes
         encode_varint: Callable[[int, Callable[[int], None]], int] = encode_varint,
         size_of_varint: Callable[[int], int] = size_of_varint,
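
Usage sketch (not part of the patch): a minimal illustration of how the BytesSerializer alias and the PEP 696 TypeVar defaults introduced above are meant to be consumed by callers of AIOKafkaProducer. It assumes a type checker that understands TypeVar defaults; the broker address and topic name are placeholders, and the inferred parameterization shown in the comments is illustrative rather than guaranteed by this change.

import asyncio
import json
from typing import Any, Optional

from aiokafka import AIOKafkaProducer


def serialize_json(value: Optional[Any]) -> Optional[bytes]:
    # Matches the BytesSerializer shape: Callable[[Union[VT, None]], Union[bytes, None]].
    return None if value is None else json.dumps(value).encode("utf-8")


async def main() -> None:
    # With no serializers the defaults apply and the producer is effectively
    # AIOKafkaProducer[bytes, bytes]; here a value serializer is passed, so
    # values may be any JSON-serializable object while keys stay bytes.
    producer: "AIOKafkaProducer[bytes, Any]" = AIOKafkaProducer(
        bootstrap_servers="localhost:9092",  # placeholder address
        value_serializer=serialize_json,
    )
    await producer.start()
    try:
        await producer.send("example-topic", {"requests": 42}, key=b"host-1")
        await producer.flush()
    finally:
        await producer.stop()


if __name__ == "__main__":
    asyncio.run(main())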