Skip to content

Commit

Permalink
✨ NEW: Expose aio_pika.Connection.add_close_callback (#104)
Browse files Browse the repository at this point in the history
Add `add_close_callback` methods to `RmqCommunicator` and `RmqThreadCommunicator`
  • Loading branch information
chrisjsewell authored Mar 3, 2021
1 parent cf2e140 commit cfc0498
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 7 deletions.
9 changes: 8 additions & 1 deletion kiwipy/rmq/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,13 @@ def loop(self):
"""Get the event loop instance driving this communicator connection."""
return self._connection.loop

def add_close_callback(self, callback: aio_pika.types.CloseCallbackType, weak: bool = False) -> None:
"""Add a callable to be called each time (after) the connection is closed.
:param weak: If True, the callback will be added to a `WeakSet`
"""
self._connection.add_close_callback(callback, weak)

async def get_default_task_queue(self) -> tasks.RmqTaskQueue:
"""Get a default task queue.
Expand Down Expand Up @@ -541,7 +548,7 @@ async def async_connect(
task_prefetch_count=defaults.TASK_PREFETCH_COUNT,
encoder=defaults.ENCODER,
decoder=defaults.DECODER,
testing_mode=False
testing_mode=False,
) -> RmqCommunicator:
# pylint: disable=too-many-arguments
"""Convenience method that returns a connected communicator.
Expand Down
20 changes: 14 additions & 6 deletions kiwipy/rmq/threadcomms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def connect(
encoder=defaults.ENCODER,
decoder=defaults.DECODER,
testing_mode=False,
async_task_timeout=TASK_TIMEOUT
async_task_timeout=TASK_TIMEOUT,
):
# pylint: disable=too-many-arguments
comm = cls(
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
encoder=defaults.ENCODER,
decoder=defaults.DECODER,
testing_mode=False,
async_task_timeout=TASK_TIMEOUT
async_task_timeout=TASK_TIMEOUT,
):
# pylint: disable=too-many-arguments
"""
Expand All @@ -105,11 +105,10 @@ def __init__(
self._loop_scheduler.start() # Star the loop scheduler (i.e. the event loop thread)

# Establish the connection and get a communicator running on our thread
self._communicator = self._loop_scheduler.await_(
self._communicator: communicator.RmqCommunicator = self._loop_scheduler.await_(
communicator.async_connect(
connection_params=connection_params,
connection_factory=connection_factory,

# Messages
message_exchange=message_exchange,
queue_expires=queue_expires,
Expand Down Expand Up @@ -179,6 +178,14 @@ def close(self):
del self._loop
self._closed = True

def add_close_callback(self, callback: aio_pika.types.CloseCallbackType, weak: bool = False) -> None:
"""Add a callable to be called each time (after) the connection is closed.
:param weak: If True, the callback will be added to a `WeakSet`
"""
self._ensure_open()
self._communicator.add_close_callback(callback, weak)

def add_rpc_subscriber(self, subscriber, identifier=None):
self._ensure_open()
return self._loop_scheduler.await_(
Expand Down Expand Up @@ -339,13 +346,13 @@ def connect(
task_prefetch_count=defaults.TASK_PREFETCH_COUNT,
encoder=defaults.ENCODER,
decoder=defaults.DECODER,
testing_mode=False
testing_mode=False,
) -> RmqThreadCommunicator:
"""
Establish a RabbitMQ communicator connection
"""
# pylint: disable=too-many-arguments
return RmqThreadCommunicator.connect(
_communicator = RmqThreadCommunicator.connect(
connection_params=connection_params,
connection_factory=connection_factory,
message_exchange=message_exchange,
Expand All @@ -357,3 +364,4 @@ def connect(
decoder=decoder,
testing_mode=testing_mode
)
return _communicator
19 changes: 19 additions & 0 deletions test/rmq/test_rmq_thread_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,25 @@ def test_task_processing_exception(thread_task_queue: rmq.RmqThreadTaskQueue):
pass


def test_connection_close_callback():
"""Test that a callback set with `add_close_callback` is correctly called."""
result = []

def close_callback(sender, exc): # pylint: disable=unused-argument
result.append('called')

communicator = rmq.connect(
connection_params={'url': 'amqp://guest:guest@localhost:5672/'},
message_exchange=f'{__file__}.{shortuuid.uuid()}',
task_exchange=f'{__file__}.{shortuuid.uuid()}',
task_queue=f'{__file__}.{shortuuid.uuid()}',
testing_mode=True
)
communicator.add_close_callback(close_callback)
communicator.close()
assert result == ['called']


@pytest.mark.skipif(sys.version_info < (3, 6), reason='`pytest-notebook` plugin requires Python >= 3.6')
def test_jupyter_notebook():
"""Test that the `RmqThreadCommunicator` can be used in a Jupyter notebook."""
Expand Down

0 comments on commit cfc0498

Please sign in to comment.