Skip to content

Commit

Permalink
improving disconnection Information report
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielePalaia authored and Daniele Palaia committed Oct 11, 2023
1 parent 11da180 commit 78f2b22
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
from rstream import (
AMQPMessage,
Consumer,
DisconnectionErrorInfo,
MessageContext,
amqp_decoder,
)

STREAM = "my-test-stream"


async def on_connection_closed(reason: Exception) -> None:
print("connection has been closed for reason: " + str(reason))
async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None:
print(
"connection has been closed from stream: "
+ str(disconnection_info.streams)
+ " for reason: "
+ str(disconnection_info.reason)
)


async def consume():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import asyncio
import time

from rstream import AMQPMessage, Producer
from rstream import (
AMQPMessage,
DisconnectionErrorInfo,
Producer,
)

STREAM = "my-test-stream"
MESSAGES = 1000000


async def on_connection_closed(reason: Exception) -> None:
print("connection has been closed for reason: " + str(reason))
async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None:
print(
"connection has been closed from stream: "
+ str(disconnection_info.streams)
+ " for reason: "
+ str(disconnection_info.reason)
)


async def publish():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import asyncio
import signal

from rstream import (
AMQPMessage,
ConsumerOffsetSpecification,
DisconnectionErrorInfo,
MessageContext,
OffsetType,
SuperStreamConsumer,
amqp_decoder,
)

cont = 0


async def on_message(msg: AMQPMessage, message_context: MessageContext):
stream = await message_context.consumer.stream(message_context.subscriber_name)
offset = message_context.offset
print("Received message: {} from stream: {} - message offset: {}".format(msg, stream, offset))


async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None:
print(
"connection has been closed from stream: "
+ str(disconnection_info.streams)
+ " for reason: "
+ str(disconnection_info.reason)
)


async def consume():
consumer = SuperStreamConsumer(
host="localhost",
port=5552,
vhost="/",
username="guest",
password="guest",
super_stream="test_super_stream",
connection_closed_handler=on_connection_closed,
)

loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(consumer.close()))
offset_specification = ConsumerOffsetSpecification(OffsetType.FIRST, None)
await consumer.start()
await consumer.subscribe(
callback=on_message, decoder=amqp_decoder, offset_specification=offset_specification
)
await consumer.run()


asyncio.run(consume())
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import asyncio
import time

from rstream import (
AMQPMessage,
DisconnectionErrorInfo,
RouteType,
SuperStreamProducer,
)

SUPER_STREAM = "test_super_stream"
MESSAGES = 10000000


async def publish():
# this value will be hashed using mumh3 hashing algorithm to decide the partition resolution for the message
async def routing_extractor(message: AMQPMessage) -> str:
return message.application_properties["id"]

async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None:

print(
"connection has been closed from stream: "
+ str(disconnection_info.streams)
+ " for reason: "
+ str(disconnection_info.reason)
)

async with SuperStreamProducer(
"localhost",
username="guest",
password="guest",
routing_extractor=routing_extractor,
routing=RouteType.Hash,
connection_closed_handler=on_connection_closed,
super_stream=SUPER_STREAM,
) as super_stream_producer:

# sending a million of messages in AMQP format
start_time = time.perf_counter()

for i in range(MESSAGES):
amqp_message = AMQPMessage(
body="hello: {}".format(i),
application_properties={"id": "{}".format(i)},
)
# send is asynchronous
await super_stream_producer.send(message=amqp_message)

end_time = time.perf_counter()
print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds")


asyncio.run(publish())
3 changes: 3 additions & 0 deletions rstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from importlib import metadata

from .utils import DisconnectionErrorInfo

try:
__version__ = metadata.version(__package__)
__license__ = metadata.metadata(__package__)["license"]
Expand Down Expand Up @@ -59,4 +61,5 @@
"StreamDoesNotExist",
"OffsetSpecification",
"EventContext",
"DisconnectionErrorInfo",
]
43 changes: 33 additions & 10 deletions rstream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from .connection import Connection, ConnectionClosed
from .schema import OffsetSpecification
from .utils import DisconnectionErrorInfo

FT = TypeVar("FT", bound=schema.Frame)
HT = Annotated[
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(
ssl_context: Optional[ssl.SSLContext] = None,
frame_max: int,
heartbeat: int,
connection_closed_handler: Optional[CB[Exception]] = None,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
):
self.host = host
self.port = port
Expand All @@ -93,9 +94,13 @@ def __init__(

self._last_heartbeat: float = 0
self._connection_closed_handler = connection_closed_handler

self._frames: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)
self._is_not_closed: bool = True

self._streams: list[str] = []
self._conn_is_closed: bool = False

def start_task(self, name: str, coro: Awaitable[None]) -> None:
assert name not in self._tasks
task = self._tasks[name] = asyncio.create_task(coro)
Expand Down Expand Up @@ -131,16 +136,21 @@ def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> Non
else:
self._handlers[frame_cls].clear()

def get_connection(self) -> Optional[Connection]:
return self._conn

async def send_frame(self, frame: schema.Frame) -> None:
logger.debug("Sending frame: %s", frame)
assert self._conn
try:
await self._conn.write_frame(frame)
except socket.error as e:
self._conn_is_closed = True
if self._connection_closed_handler is None:
print("TCP connection closed")
else:
result = self._connection_closed_handler(e)
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result

Expand Down Expand Up @@ -203,16 +213,20 @@ async def _listener(self) -> None:
try:
frame = await self._conn.read_frame()
except ConnectionClosed as e:
self._conn_is_closed = True

if self._connection_closed_handler is not None:
result = self._connection_closed_handler(e)
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result
else:
print("TCP connection closed")
break
except socket.error as e:
if self._connection_closed_handler is not None:
result = self._connection_closed_handler(e)
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result
else:
Expand All @@ -235,6 +249,7 @@ async def _listener(self) -> None:
maybe_coro = handler(frame)
if maybe_coro is not None:
await maybe_coro

except Exception:
logger.exception("Error while running handler %s of frame %s", handler, frame)

Expand Down Expand Up @@ -270,7 +285,7 @@ def is_started(self) -> bool:
async def close(self) -> None:
logger.info("Stopping client %s:%s", self.host, self.port)

if self._conn is None:
if self._conn_is_closed is True:
return

if self.is_started:
Expand All @@ -290,8 +305,9 @@ async def close(self) -> None:
for subscriber_name in self._frames:
await self.stop_task(f"run_delivery_handlers_{subscriber_name}")

await self._conn.close()
self._conn = None
if self._conn is not None:
await self._conn.close()
self._conn = None

self.server_properties = None
self._tasks.clear()
Expand Down Expand Up @@ -417,6 +433,7 @@ async def query_leader_and_replicas(
assert len(metadata_resp.metadata) == 1
metadata = metadata_resp.metadata[0]
assert metadata.name == stream
self._streams.append(stream)

brokers = {broker.reference: broker for broker in metadata_resp.brokers}
leader = brokers[metadata.leader_ref]
Expand Down Expand Up @@ -494,6 +511,8 @@ async def declare_publisher(self, stream: str, reference: str, publisher_id: int
)

async def delete_publisher(self, publisher_id: int) -> None:
if self._conn is None:
return
await self.sync_request(
schema.DeletePublisher(
self._corr_id_seq.next(),
Expand Down Expand Up @@ -584,7 +603,9 @@ def __init__(
self._clients: dict[Addr, Client] = {}

async def get(
self, addr: Optional[Addr] = None, connection_closed_handler: Optional[CB[Exception]] = None
self,
addr: Optional[Addr] = None,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
) -> Client:
"""Get a client according to `addr` parameter
Expand All @@ -610,7 +631,7 @@ async def get(
return self._clients[desired_addr]

async def _resolve_broker(
self, addr: Addr, connection_closed_handler: Optional[CB[Exception]] = None
self, addr: Addr, connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None
) -> Client:
desired_host, desired_port = addr.host, str(addr.port)

Expand All @@ -636,7 +657,9 @@ async def _resolve_broker(
f"Failed to connect to {desired_host}:{desired_port} after {self.max_retries} tries"
)

async def new(self, addr: Addr, connection_closed_handler: Optional[CB[Exception]] = None) -> Client:
async def new(
self, addr: Addr, connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None
) -> Client:
host, port = addr
client = Client(
host=host,
Expand Down
6 changes: 4 additions & 2 deletions rstream/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OffsetType,
)
from .schema import OffsetSpecification
from .utils import DisconnectionErrorInfo

MT = TypeVar("MT")
CB = Annotated[Callable[[MT, Any], Union[None, Awaitable[None]]], "Message callback type"]
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
heartbeat: int = 60,
load_balancer_mode: bool = False,
max_retries: int = 20,
connection_closed_handler: Optional[CB_CONN[Exception]] = None,
connection_closed_handler: Optional[CB_CONN[DisconnectionErrorInfo]] = None,
):
self._pool = ClientPool(
host,
Expand Down Expand Up @@ -331,7 +332,8 @@ async def stream_exists(self, stream: str) -> bool:
return await self.default_client.stream_exists(stream)

async def stream(self, subscriber_name) -> str:

if subscriber_name not in self._subscribers:
return ""
return self._subscribers[subscriber_name].stream

def get_stream(self, subscriber_name) -> str:
Expand Down
9 changes: 6 additions & 3 deletions rstream/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
CompressionType,
ICompressionCodec,
)
from .utils import RawMessage
from .utils import DisconnectionErrorInfo, RawMessage

MessageT = TypeVar("MessageT", _MessageProtocol, bytes)
MT = TypeVar("MT")
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
max_retries: int = 20,
default_batch_publishing_delay: float = 0.2,
default_context_switch_value: int = 1000,
connection_closed_handler: Optional[CB[Exception]] = None,
connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None,
):
self._pool = ClientPool(
host,
Expand Down Expand Up @@ -125,7 +125,6 @@ async def start(self) -> None:
async def close(self) -> None:
# flush messages still in buffer
if self.task is not None:

for stream in self._buffered_messages:
await self._publish_buffered_messages(stream)
self.task.cancel()
Expand Down Expand Up @@ -393,6 +392,10 @@ async def _timer(self):

async def _publish_buffered_messages(self, stream: str) -> None:

if stream in self._clients:
if self._clients[stream].get_connection() is None:
return

async with self._buffered_messages_lock:
if len(self._buffered_messages[stream]):
await self._send_batch(stream, self._buffered_messages[stream], sync=False)
Expand Down
Loading

0 comments on commit 78f2b22

Please sign in to comment.