adding an event to catch connection issues
DanielePalaia committed Jul 10, 2023
1 parent 3076c86 commit 6f58020
Showing 7 changed files with 135 additions and 15 deletions.
@@ -0,0 +1,41 @@
import asyncio
import signal

from rstream import (
    AMQPMessage,
    Consumer,
    MessageContext,
    amqp_decoder,
)

STREAM = "my-test-stream"


def on_connection_closed(reason: Exception) -> None:
    print("connection has been closed for reason: " + str(reason))


async def consume():
    consumer = Consumer(
        host="localhost",
        port=5552,
        vhost="/",
        username="guest",
        password="guest",
        connection_closed_handler=on_connection_closed,
    )

    loop = asyncio.get_event_loop()
    loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(consumer.close()))

    async def on_message(msg: AMQPMessage, message_context: MessageContext):
        stream = message_context.consumer.get_stream(message_context.subscriber_name)
        offset = message_context.offset
        print("Got message: {} from stream {}, offset {}".format(msg, stream, offset))

    await consumer.start()
    await consumer.subscribe(stream=STREAM, callback=on_message, decoder=amqp_decoder)
    await consumer.run()


asyncio.run(consume())
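
The example above only logs the failure. If the application needs to react, for example to stop consuming or attempt a reconnect, one option (not part of this commit) is to have the callback set an asyncio.Event that the rest of the program waits on. This is a minimal sketch assuming the handler is invoked from inside the running event loop, as the rstream/client.py hunks below suggest; connection_lost and wait_for_disconnect are illustrative names, not library API.

import asyncio

# Event the application can wait on; assumes the callback runs inside the
# same event loop that drives the consumer (illustrative, not library code).
connection_lost = asyncio.Event()


def on_connection_closed(reason: Exception) -> None:
    print("connection has been closed for reason: " + str(reason))
    connection_lost.set()


async def wait_for_disconnect() -> None:
    # e.g. run alongside consumer.run() with asyncio.gather(...)
    await connection_lost.wait()
    print("connection lost, shutting down")
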
@@ -0,0 +1,34 @@
import asyncio
import time

from rstream import AMQPMessage, Producer

STREAM = "my-test-stream"
MESSAGES = 1000000


def on_connection_closed(reason: Exception) -> None:
    print("connection has been closed for reason: " + str(reason))


async def publish():

    # pass the handler so connection issues are reported while publishing
    async with Producer(
        "localhost", username="guest", password="guest", connection_closed_handler=on_connection_closed
    ) as producer:
        # create a stream if it doesn't already exist
        await producer.create_stream(STREAM, exists_ok=True)

        # sending a million messages in AMQP format
        start_time = time.perf_counter()

        for i in range(MESSAGES):
            amqp_message = AMQPMessage(
                body="hello: {}".format(i),
            )
            # send is asynchronous
            await producer.send(stream=STREAM, message=amqp_message)

        end_time = time.perf_counter()
        print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds")


asyncio.run(publish())
46 changes: 38 additions & 8 deletions rstream/client.py
@@ -5,6 +5,7 @@

import asyncio
import logging
+import socket
import ssl
import time
from collections import defaultdict
@@ -40,6 +41,8 @@

DEFAULT_REQUEST_TIMEOUT = 10

+MT = TypeVar("MT")
+CB = Annotated[Callable[[MT], Union[None, Awaitable[None]]], "Message callback type"]

logger = logging.getLogger(__name__)

@@ -62,6 +65,7 @@ def __init__(
        ssl_context: Optional[ssl.SSLContext] = None,
        frame_max: int,
        heartbeat: int,
+       connection_closed_handler: Optional[CB[Exception]] = None,
    ):
        self.host = host
        self.port = port
@@ -89,6 +93,7 @@ def __init__(
        self._handlers: dict[Type[schema.Frame], dict[str, HT[Any]]] = defaultdict(dict)

        self._last_heartbeat: float = 0
+       self._connection_closed_handler = connection_closed_handler

    def start_task(self, name: str, coro: Awaitable[None]) -> None:
        assert name not in self._tasks
@@ -128,7 +133,13 @@ def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> Non
    async def send_frame(self, frame: schema.Frame) -> None:
        logger.debug("Sending frame: %s", frame)
        assert self._conn
-       await self._conn.write_frame(frame)
+       try:
+           await self._conn.write_frame(frame)
+       except socket.error as e:
+           if self._connection_closed_handler is None:
+               print("TCP connection closed")
+           else:
+               self._connection_closed_handler(e)

    def wait_frame(
        self,
@@ -166,7 +177,17 @@ async def _listener(self) -> None:
        while True:
            try:
                frame = await self._conn.read_frame()
-           except ConnectionClosed:
+           except ConnectionClosed as e:
+               if self._connection_closed_handler is not None:
+                   self._connection_closed_handler(e)
+               else:
+                   print("TCP connection closed")
                break
+           except socket.error as e:
+               if self._connection_closed_handler is not None:
+                   self._connection_closed_handler(e)
+               else:
+                   print("TCP connection closed")
+               break

            logger.debug("Received frame: %s", frame)
@@ -523,7 +544,9 @@ def __init__(
        self._heartbeat = heartbeat
        self._clients: dict[Addr, Client] = {}

-   async def get(self, addr: Optional[Addr] = None) -> Client:
+   async def get(
+       self, addr: Optional[Addr] = None, connection_closed_handler: Optional[CB[Exception]] = None
+   ) -> Client:
        """Get a client according to `addr` parameter
        If class param `load_balancer_mode` is True, we create a connection via the LB
@@ -536,20 +559,26 @@ async def get(self, addr: Optional[Addr] = None) -> Client:

        if desired_addr not in self._clients:
            if addr and self.load_balancer_mode:
-               self._clients[desired_addr] = await self._resolve_broker(desired_addr)
+               self._clients[desired_addr] = await self._resolve_broker(
+                   desired_addr, connection_closed_handler
+               )
            else:
-               self._clients[desired_addr] = await self.new(desired_addr)
+               self._clients[desired_addr] = await self.new(
+                   addr=desired_addr, connection_closed_handler=connection_closed_handler
+               )

        assert self._clients[desired_addr].is_started
        return self._clients[desired_addr]

-   async def _resolve_broker(self, addr: Addr) -> Client:
+   async def _resolve_broker(
+       self, addr: Addr, connection_closed_handler: Optional[CB[Exception]] = None
+   ) -> Client:
        desired_host, desired_port = addr.host, str(addr.port)

        connection_attempts = 0

        while connection_attempts < self.max_retries:
-           client = await self.new(self.addr)
+           client = await self.new(addr=self.addr, connection_closed_handler=connection_closed_handler)

            assert client.server_properties is not None

@@ -568,14 +597,15 @@ async def _resolve_broker(self, addr: Addr) -> Client:
f"Failed to connect to {desired_host}:{desired_port} after {self.max_retries} tries"
)

async def new(self, addr: Addr) -> Client:
async def new(self, addr: Addr, connection_closed_handler: Optional[CB[Exception]] = None) -> Client:
host, port = addr
client = Client(
host=host,
port=port,
ssl_context=self.ssl_context,
frame_max=self._frame_max,
heartbeat=self._heartbeat,
connection_closed_handler=connection_closed_handler,
)
await client.start()
await client.authenticate(
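
A note on the callback type: CB[Exception] above accepts any callable that takes the exception as its only argument and returns either None or an awaitable. As the send_frame and _listener hunks show, the handler is invoked synchronously, so a plain function that returns quickly is the natural fit; a coroutine function would satisfy the annotation, but its coroutine would not be awaited by this code. A minimal sketch, where ConnectionClosedHandler, record_connection_error and last_connection_error are illustrative names rather than library API:

from typing import Awaitable, Callable, Optional, Union

# Roughly what Optional[CB[Exception]] expands to (illustrative standalone alias).
ConnectionClosedHandler = Optional[Callable[[Exception], Union[None, Awaitable[None]]]]

last_connection_error: Optional[Exception] = None


def record_connection_error(reason: Exception) -> None:
    # Plain (non-async) callback: the client calls it synchronously from
    # send_frame() and _listener(), so keep the body short and non-blocking.
    global last_connection_error
    last_connection_error = reason
    print("connection closed: " + repr(reason))


handler: ConnectionClosedHandler = record_connection_error
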
9 changes: 7 additions & 2 deletions rstream/consumer.py
@@ -27,6 +27,7 @@

MT = TypeVar("MT")
CB = Annotated[Callable[[MT, Any], Union[None, Awaitable[None]]], "Message callback type"]
+CB_CONN = Annotated[Callable[[MT], Union[None, Awaitable[None]]], "Message callback type"]


@dataclass
@@ -70,6 +71,7 @@ def __init__(
        heartbeat: int = 60,
        load_balancer_mode: bool = False,
        max_retries: int = 20,
+       connection_closed_handler: Optional[CB_CONN[Exception]] = None,
    ):
        self._pool = ClientPool(
            host,
@@ -89,6 +91,7 @@ def __init__(
        self._subscribers: dict[str, _Subscriber] = {}
        self._stop_event = asyncio.Event()
        self._lock = asyncio.Lock()
+       self._connection_closed_handler = connection_closed_handler

    @property
    def default_client(self) -> Client:
@@ -104,7 +107,7 @@ async def __aexit__(self, *_: Any) -> None:
        await self.close()

    async def start(self) -> None:
-       self._default_client = await self._pool.get()
+       self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)

    def stop(self) -> None:
        self._stop_event.set()
@@ -129,7 +132,9 @@ async def _get_or_create_client(self, stream: str) -> Client:
        if stream not in self._clients:
            leader, replicas = await self.default_client.query_leader_and_replicas(stream)
            broker = random.choice(replicas) if replicas else leader
-           self._clients[stream] = await self._pool.get(Addr(broker.host, broker.port))
+           self._clients[stream] = await self._pool.get(
+               addr=Addr(broker.host, broker.port), connection_closed_handler=self._connection_closed_handler
+           )

        return self._clients[stream]

8 changes: 6 additions & 2 deletions rstream/producer.py
@@ -74,6 +74,7 @@ def __init__(
        max_retries: int = 20,
        default_batch_publishing_delay: float = 0.2,
        default_context_switch_value: int = 1000,
+       connection_closed_handler: Optional[CB[Exception]] = None,
    ):
        self._pool = ClientPool(
            host,
@@ -102,6 +103,7 @@ def __init__(
        self._default_batch_publishing_delay = default_batch_publishing_delay
        self._default_context_switch_counter = 0
        self._default_context_switch_value = default_context_switch_value
+       self._connection_closed_handler = connection_closed_handler

    @property
    def default_client(self) -> Client:
@@ -117,7 +119,7 @@ async def __aexit__(self, *_: Any) -> None:
        await self.close()

    async def start(self) -> None:
-       self._default_client = await self._pool.get()
+       self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)

    async def close(self) -> None:
        # flush messages still in buffer
@@ -142,7 +144,9 @@ async def close(self) -> None:
    async def _get_or_create_client(self, stream: str) -> Client:
        if stream not in self._clients:
            leader, _ = await self.default_client.query_leader_and_replicas(stream)
-           self._clients[stream] = await self._pool.get(Addr(leader.host, leader.port))
+           self._clients[stream] = await self._pool.get(
+               Addr(leader.host, leader.port), self._connection_closed_handler
+           )

        return self._clients[stream]

8 changes: 6 additions & 2 deletions rstream/superstream_consumer.py
@@ -45,6 +45,7 @@ def __init__(
        load_balancer_mode: bool = False,
        max_retries: int = 20,
        super_stream: str,
+       connection_closed_handler: Optional[CB[Exception]] = None,
    ):
        self._pool = ClientPool(
            host,
@@ -75,6 +76,7 @@ def __init__(
        self._consumers: dict[str, Consumer] = {}
        self._stop_event = asyncio.Event()
        self._subscribers: dict[str, str] = defaultdict(str)
+       self._connection_closed_handler = connection_closed_handler

    @property
    def default_client(self) -> Client:
@@ -90,7 +92,7 @@ async def __aexit__(self, *_: Any) -> None:
        await self.close()

    async def start(self) -> None:
-       self._default_client = await self._pool.get()
+       self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)

    def stop(self) -> None:
        self._stop_event.set()
@@ -112,7 +114,9 @@ async def _get_or_create_client(self, stream: str) -> Client:
        if stream not in self._clients:
            leader, replicas = await self.default_client.query_leader_and_replicas(stream)
            broker = random.choice(replicas) if replicas else leader
-           self._clients[stream] = await self._pool.get(Addr(broker.host, broker.port))
+           self._clients[stream] = await self._pool.get(
+               Addr(broker.host, broker.port), connection_closed_handler=self._connection_closed_handler
+           )

        return self._clients[stream]

4 changes: 3 additions & 1 deletion rstream/superstream_producer.py
@@ -51,6 +51,7 @@ def __init__(
        load_balancer_mode: bool = False,
        max_retries: int = 20,
        default_batch_publishing_delay: float = 0.2,
+       connection_closed_handler: Optional[CB[Exception]] = None
    ):
        self._pool = ClientPool(
            host,
@@ -81,6 +82,7 @@ def __init__(
        self._default_client: Optional[Client] = None
        self._producer: Producer | None = None
        self._routing_strategy: RoutingStrategy
+       self._connection_closed_handler = connection_closed_handler

    async def _get_producer(self) -> Producer:
        if self._producer is None:
@@ -126,7 +128,7 @@ async def __aexit__(self, *_: Any) -> None:
        await self.close()

    async def start(self) -> None:
-       self._default_client = await self._pool.get()
+       self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)
        self.super_stream_metadata = DefaultSuperstreamMetadata(self.super_stream, self._default_client)
        if self.routing == RouteType.Hash:
            self._routing_strategy = HashRoutingMurmurStrategy(self.routing_extractor)
