Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add typing to tests/test_protocol* #1005

Merged
merged 4 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ FORMATTED_AREAS=\
aiokafka/record/ \
tests/test_codec.py \
tests/test_helpers.py \
tests/test_protocol.py \
tests/test_protocol_object_conversion.py \
tests/record/

.PHONY: setup
Expand Down
6 changes: 6 additions & 0 deletions aiokafka/protocol/fetch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Optional, Tuple

from .api import Request, Response
from .types import Array, Bytes, Int8, Int16, Int32, Int64, Schema, String

Expand All @@ -23,6 +25,8 @@ class FetchResponse_v0(Response):
)
)

topics: Optional[List[Tuple[str, List[Tuple[int, int, int, bytes]]]]]


class FetchResponse_v1(Response):
API_KEY = 1
Expand Down Expand Up @@ -235,6 +239,8 @@ class FetchRequest_v0(Request):
),
)

min_bytes: Optional[int]


class FetchRequest_v1(Request):
API_KEY = 1
Expand Down
48 changes: 25 additions & 23 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import io
import struct
from typing import Type

import pytest

from aiokafka.protocol.api import Request, RequestHeader_v0, Response
from aiokafka.protocol.commit import GroupCoordinatorRequest
from aiokafka.protocol.fetch import FetchRequest, FetchResponse
from aiokafka.protocol.commit import GroupCoordinatorRequest_v0
from aiokafka.protocol.fetch import FetchRequest_v0, FetchResponse_v0
from aiokafka.protocol.message import Message, MessageSet, PartialMessage
from aiokafka.protocol.metadata import MetadataRequest
from aiokafka.protocol.metadata import MetadataRequest_v0
from aiokafka.protocol.types import (
CompactArray,
CompactBytes,
Expand All @@ -20,7 +21,7 @@
)


def test_create_message():
def test_create_message() -> None:
payload = b"test"
key = b"key"
msg = Message(value=payload, key=key, magic=0, attributes=0, crc=0)
Expand All @@ -30,7 +31,7 @@ def test_create_message():
assert msg.value == payload


def test_encode_message_v0():
def test_encode_message_v0() -> None:
message = Message(value=b"test", key=b"key", magic=0, attributes=0, crc=0)
encoded = message.encode()
expect = b"".join(
Expand All @@ -46,7 +47,7 @@ def test_encode_message_v0():
assert encoded == expect


def test_encode_message_v1():
def test_encode_message_v1() -> None:
message = Message(
value=b"test", key=b"key", magic=1, attributes=0, crc=0, timestamp=1234
)
Expand All @@ -65,7 +66,7 @@ def test_encode_message_v1():
assert encoded == expect


def test_decode_message():
def test_decode_message() -> None:
encoded = b"".join(
[
struct.pack(">i", -1427009701), # CRC
Expand All @@ -82,7 +83,7 @@ def test_decode_message():
assert decoded_message == msg


def test_decode_message_validate_crc():
def test_decode_message_validate_crc() -> None:
encoded = b"".join(
[
struct.pack(">i", -1427009701), # CRC
Expand Down Expand Up @@ -110,7 +111,7 @@ def test_decode_message_validate_crc():
assert decoded_message.validate_crc() is False


def test_encode_message_set():
def test_encode_message_set() -> None:
messages = [
Message(value=b"v1", key=b"k1", magic=0, attributes=0, crc=0),
Message(value=b"v2", key=b"k2", magic=0, attributes=0, crc=0),
Expand Down Expand Up @@ -140,7 +141,7 @@ def test_encode_message_set():
assert encoded == expect


def test_decode_message_set():
def test_decode_message_set() -> None:
encoded = b"".join(
[
struct.pack(">q", 0), # MsgSet Offset
Expand Down Expand Up @@ -180,7 +181,7 @@ def test_decode_message_set():
assert decoded_message2 == message2


def test_encode_message_header():
def test_encode_message_header() -> None:
expect = b"".join(
[
struct.pack(">h", 10), # API Key
Expand All @@ -191,12 +192,12 @@ def test_encode_message_header():
]
)

req = GroupCoordinatorRequest[0]("foo")
req = GroupCoordinatorRequest_v0("foo")
header = RequestHeader_v0(req, correlation_id=4, client_id="client3")
assert header.encode() == expect


def test_decode_message_set_partial():
def test_decode_message_set_partial() -> None:
encoded = b"".join(
[
struct.pack(">q", 0), # Msg Offset
Expand Down Expand Up @@ -235,7 +236,7 @@ def test_decode_message_set_partial():
assert decoded_message2 == PartialMessage()


def test_decode_fetch_response_partial():
def test_decode_fetch_response_partial() -> None:
encoded = b"".join(
[
Int32.encode(1), # Num Topics (Array)
Expand Down Expand Up @@ -283,7 +284,8 @@ def test_decode_fetch_response_partial():
b"ar", # Value (truncated)
]
)
resp = FetchResponse[0].decode(io.BytesIO(encoded))
resp = FetchResponse_v0.decode(io.BytesIO(encoded))
assert resp.topics is not None
assert len(resp.topics) == 1
topic, partitions = resp.topics[0]
assert topic == "foobar"
Expand All @@ -294,18 +296,18 @@ def test_decode_fetch_response_partial():
assert m1[1] == (None, None, PartialMessage())


def test_struct_unrecognized_kwargs():
def test_struct_unrecognized_kwargs() -> None:
# Structs should not allow unrecognized kwargs
with pytest.raises(ValueError):
MetadataRequest[0](topicz="foo")
MetadataRequest_v0(topicz="foo")


def test_struct_missing_kwargs():
fr = FetchRequest[0](max_wait_time=100)
def test_struct_missing_kwargs() -> None:
fr = FetchRequest_v0(max_wait_time=100)
assert fr.min_bytes is None


def test_unsigned_varint_serde():
def test_unsigned_varint_serde() -> None:
pairs = {
0: [0],
-1: [0xFF, 0xFF, 0xFF, 0xFF, 0x0F],
Expand All @@ -326,7 +328,7 @@ def test_unsigned_varint_serde():
assert value == UnsignedVarInt32.decode(io.BytesIO(encoded))


def test_compact_data_structs():
def test_compact_data_structs() -> None:
cs = CompactString()
encoded = cs.encode(None)
assert encoded == struct.pack("B", 0)
Expand Down Expand Up @@ -366,7 +368,7 @@ def test_compact_data_structs():

@pytest.mark.parametrize("klass", Request.__subclasses__())
@pytest.mark.parametrize("attr_name", attr_names)
def test_request_type_conformance(klass, attr_name):
def test_request_type_conformance(klass: Type[Request], attr_name: str) -> None:
assert hasattr(klass, attr_name)


Expand All @@ -380,5 +382,5 @@ def test_request_type_conformance(klass, attr_name):

@pytest.mark.parametrize("klass", Response.__subclasses__())
@pytest.mark.parametrize("attr_name", attr_names)
def test_response_type_conformance(klass, attr_name):
def test_response_type_conformance(klass: Type[Response], attr_name: str) -> None:
assert hasattr(klass, attr_name)
Loading
Loading