Skip to content

Commit

Permalink
Merge pull request #972 from ods/fix-tagged-fields
Browse files Browse the repository at this point in the history
Fix tagged fields handling
  • Loading branch information
ods authored Jan 30, 2024
2 parents e8383ea + c54f257 commit 9166c96
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 32 deletions.
5 changes: 3 additions & 2 deletions aiokafka/admin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,9 @@ async def delete_records(
timeout_ms or self._request_timeout_ms,
)
response = await self._client.send(leader, request)
for topic, partitions in response.topics:
for partition_index, low_watermark, error_code in partitions:
# Starting with v2, DeleteRecordsResponse contains extra field with tags
for topic, partitions, *_ in response.topics:
for partition_index, low_watermark, error_code, *_ in partitions:
if error_code:
err = for_code(error_code)
raise err
Expand Down
26 changes: 13 additions & 13 deletions aiokafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import hashlib
import hmac
import io
import logging
import random
import socket
Expand All @@ -24,7 +25,6 @@
SaslAuthenticateRequest,
SaslHandShakeRequest,
)
from aiokafka.protocol.api import RequestHeader
from aiokafka.protocol.commit import (
GroupCoordinatorResponse_v0 as GroupCoordinatorResponse,
)
Expand Down Expand Up @@ -459,10 +459,8 @@ def send(self, request, expect_response=True):
)

correlation_id = self._next_correlation_id()
header = RequestHeader(
request,
correlation_id=correlation_id,
client_id=self._client_id,
header = request.build_request_header(
correlation_id=correlation_id, client_id=self._client_id
)
message = header.encode() + request.encode()
size = struct.pack(">i", len(message))
Expand All @@ -480,7 +478,7 @@ def send(self, request, expect_response=True):
return self._writer.drain()
fut = self._loop.create_future()
self._requests.append(
(correlation_id, request.RESPONSE_TYPE, fut),
(correlation_id, request, fut),
)
return wait_for(fut, self._request_timeout)

Expand Down Expand Up @@ -569,39 +567,41 @@ async def _read(self_ref):
del self

def _handle_frame(self, resp):
correlation_id, resp_type, fut = self._requests[0]
correlation_id, request, fut = self._requests[0]

if correlation_id is None: # Is a SASL packet, just pass it though
if not fut.done():
fut.set_result(resp)
else:
(recv_correlation_id,) = struct.unpack_from(">i", resp, 0)
resp = io.BytesIO(resp)
response_header = request.parse_response_header(resp)
resp_type = request.RESPONSE_TYPE

if (
self._api_version == (0, 8, 2)
and resp_type is GroupCoordinatorResponse
and correlation_id != 0
and recv_correlation_id == 0
and response_header.correlation_id == 0
):
log.warning(
"Kafka 0.8.2 quirk -- GroupCoordinatorResponse"
" coorelation id does not match request. This"
" correlation id does not match request. This"
" should go away once at least one topic has been"
" initialized on the broker"
)

elif correlation_id != recv_correlation_id:
elif response_header.correlation_id != correlation_id:
error = Errors.CorrelationIdError(
f"Correlation ids do not match: sent {correlation_id},"
f" recv {recv_correlation_id}"
f" recv {response_header.correlation_id}"
)
if not fut.done():
fut.set_exception(error)
self.close(reason=CloseReason.OUT_OF_SYNC)
return

if not fut.done():
response = resp_type.decode(resp[4:])
response = resp_type.decode(resp)
log.debug("%s Response %d: %s", self, correlation_id, response)
fut.set_result(response)

Expand Down
23 changes: 19 additions & 4 deletions aiokafka/protocol/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,17 +1385,32 @@ class DeleteRecordsRequest_v2(Request):
("tags", TaggedFields),
)

def __init__(self, topics, timeout_ms, tags=None):
super().__init__(
[
(
topic,
[
(partition, before_offset, {})
for partition, before_offset in partitions
],
{},
)
for (topic, partitions) in topics
],
timeout_ms,
tags or {},
)


DeleteRecordsRequest = [
DeleteRecordsRequest_v0,
DeleteRecordsRequest_v1,
# FIXME: We have some problems with `TaggedFields`
# DeleteRecordsRequest_v2,
DeleteRecordsRequest_v2,
]

DeleteRecordsResponse = [
DeleteRecordsResponse_v0,
DeleteRecordsResponse_v1,
# FIXME: We have some problems with `TaggedFields`
# DeleteRecordsResponse_v2,
DeleteRecordsResponse_v2,
]
22 changes: 12 additions & 10 deletions aiokafka/protocol/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .types import Array, Int16, Int32, Schema, String, TaggedFields


class RequestHeader(Struct):
class RequestHeader_v0(Struct):
SCHEMA = Schema(
("api_key", Int16),
("api_version", Int16),
Expand All @@ -13,12 +13,12 @@ class RequestHeader(Struct):
)

def __init__(self, request, correlation_id=0, client_id="aiokafka"):
super(RequestHeader, self).__init__(
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id
)


class RequestHeaderV2(Struct):
class RequestHeader_v1(Struct):
# Flexible response / request headers end in field buffer
SCHEMA = Schema(
("api_key", Int16),
Expand All @@ -29,18 +29,18 @@ class RequestHeaderV2(Struct):
)

def __init__(self, request, correlation_id=0, client_id="aiokafka", tags=None):
super(RequestHeaderV2, self).__init__(
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {}
)


class ResponseHeader(Struct):
class ResponseHeader_v0(Struct):
SCHEMA = Schema(
("correlation_id", Int32),
)


class ResponseHeaderV2(Struct):
class ResponseHeader_v1(Struct):
SCHEMA = Schema(
("correlation_id", Int32),
("tags", TaggedFields),
Expand Down Expand Up @@ -81,15 +81,17 @@ def to_object(self):

def build_request_header(self, correlation_id, client_id):
if self.FLEXIBLE_VERSION:
return RequestHeaderV2(
return RequestHeader_v1(
self, correlation_id=correlation_id, client_id=client_id
)
return RequestHeader(self, correlation_id=correlation_id, client_id=client_id)
return RequestHeader_v0(
self, correlation_id=correlation_id, client_id=client_id
)

def parse_response_header(self, read_buffer):
if self.FLEXIBLE_VERSION:
return ResponseHeaderV2.decode(read_buffer)
return ResponseHeader.decode(read_buffer)
return ResponseHeader_v1.decode(read_buffer)
return ResponseHeader_v0.decode(read_buffer)


class Response(Struct):
Expand Down
1 change: 0 additions & 1 deletion aiokafka/protocol/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def encode(self, value):
return UnsignedVarInt32.encode(len(value) + 1) + value


# FIXME: TaggedFields doesn't seem to work properly so they should be avoided
class TaggedFields(AbstractType):
@classmethod
def decode(cls, data):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from aiokafka.protocol.api import Request, RequestHeader, Response
from aiokafka.protocol.api import Request, RequestHeader_v0, Response
from aiokafka.protocol.commit import GroupCoordinatorRequest
from aiokafka.protocol.fetch import FetchRequest, FetchResponse
from aiokafka.protocol.message import Message, MessageSet, PartialMessage
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_encode_message_header():
)

req = GroupCoordinatorRequest[0]("foo")
header = RequestHeader(req, correlation_id=4, client_id="client3")
header = RequestHeader_v0(req, correlation_id=4, client_id="client3")
assert header.encode() == expect


Expand Down

0 comments on commit 9166c96

Please sign in to comment.