diff --git a/src/aiida/brokers/rabbitmq/broker.py b/src/aiida/brokers/rabbitmq/broker.py index 0ed8bcd0d..7bfcb2fec 100644 --- a/src/aiida/brokers/rabbitmq/broker.py +++ b/src/aiida/brokers/rabbitmq/broker.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import functools import typing as t @@ -10,6 +11,7 @@ from plumpy.rmq.process_control import RemoteProcessController from aiida.brokers.broker import Broker +from aiida.brokers.rabbitmq.coordinator import RmqLoopCoordinator from aiida.common.log import AIIDA_LOGGER from aiida.manage.configuration import get_config_option @@ -36,6 +38,8 @@ def __init__(self, profile: Profile) -> None: self._profile = profile self._communicator: 'RmqThreadCommunicator | None' = None self._prefix = f'aiida-{self._profile.uuid}' + # FIXME: ??? should make the event loop setable?? + self._loop = asyncio.get_event_loop() def __str__(self): try: @@ -60,7 +64,7 @@ def get_coordinator(self): # Check whether a compatible version of RabbitMQ is being used. self.check_rabbitmq_version() - coordinator = RmqCoordinator(self._communicator) + coordinator = RmqLoopCoordinator(self._communicator, self._loop) return coordinator diff --git a/src/aiida/brokers/rabbitmq/coordinator.py b/src/aiida/brokers/rabbitmq/coordinator.py new file mode 100644 index 000000000..6c6a13c7e --- /dev/null +++ b/src/aiida/brokers/rabbitmq/coordinator.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +from asyncio import AbstractEventLoop +from typing import Generic, TypeVar, final +import kiwipy +import concurrent.futures + +from plumpy.exceptions import CoordinatorConnectionError +from plumpy.rmq.communications import convert_to_comm + +__all__ = ['RmqCoordinator'] + +U = TypeVar('U', bound=kiwipy.Communicator) + + +@final +class RmqLoopCoordinator(Generic[U]): + def __init__(self, comm: U, loop: AbstractEventLoop): + self._comm = comm + self._loop = loop + + @property + def communicator(self) -> U: + """The inner communicator.""" + return self._comm + + def add_rpc_subscriber(self, subscriber, identifier=None): + subscriber = convert_to_comm(subscriber, self._loop) + return self._comm.add_rpc_subscriber(subscriber, identifier) + + def add_broadcast_subscriber( + self, + subscriber, + subject_filters=None, + sender_filters=None, + identifier=None, + ): + subscriber = kiwipy.BroadcastFilter(subscriber) + + subject_filters = subject_filters or [] + sender_filters = sender_filters or [] + + for filter in subject_filters: + subscriber.add_subject_filter(filter) + for filter in sender_filters: + subscriber.add_sender_filter(filter) + + subscriber = convert_to_comm(subscriber, self._loop) + return self._comm.add_broadcast_subscriber(subscriber, identifier) + + def add_task_subscriber(self, subscriber, identifier=None): + subscriber = convert_to_comm(subscriber, self._loop) + return self._comm.add_task_subscriber(subscriber, identifier) + + def remove_rpc_subscriber(self, identifier): + return self._comm.remove_rpc_subscriber(identifier) + + def remove_broadcast_subscriber(self, identifier): + return self._comm.remove_broadcast_subscriber(identifier) + + def remove_task_subscriber(self, identifier): + return self._comm.remove_task_subscriber(identifier) + + def rpc_send(self, recipient_id, msg): + return self._comm.rpc_send(recipient_id, msg) + + def broadcast_send( + self, + body, + sender=None, + subject=None, + correlation_id=None, + ): + from aio_pika.exceptions import ChannelInvalidStateError, AMQPConnectionError + + try: + rsp = self._comm.broadcast_send(body, sender, subject, correlation_id) + except (ChannelInvalidStateError, AMQPConnectionError, concurrent.futures.TimeoutError) as exc: + raise CoordinatorConnectionError from exc + else: + return rsp + + def task_send(self, task, no_reply=False): + return self._comm.task_send(task, no_reply) + + def close(self): + self._comm.close() diff --git a/src/aiida/engine/runners.py b/src/aiida/engine/runners.py index 92a62c071..cb14be2b8 100644 --- a/src/aiida/engine/runners.py +++ b/src/aiida/engine/runners.py @@ -93,9 +93,7 @@ def __init__( # FIXME: broker and coordinator overlap the concept there for over-abstraction, remove the abstraction if broker is not None: - _coordinator = broker.get_coordinator() - # FIXME: the wrap should not be needed - self._coordinator = wrap_communicator(_coordinator.communicator, self._loop) + self._coordinator = broker.get_coordinator() self._controller = broker.get_controller() elif self._broker_submit: # FIXME: if broker then broker_submit else False diff --git a/tests/brokers/test_rabbitmq.py b/tests/brokers/test_rabbitmq.py index 58399a7e3..5fb5c2bc5 100644 --- a/tests/brokers/test_rabbitmq.py +++ b/tests/brokers/test_rabbitmq.py @@ -94,12 +94,12 @@ def test_communicator(url): def test_add_rpc_subscriber(coordinator): """Test ``add_rpc_subscriber``.""" - coordinator.add_rpc_subscriber(None) + coordinator.add_rpc_subscriber(lambda: None) def test_add_broadcast_subscriber(coordinator): """Test ``add_broadcast_subscriber``.""" - coordinator.add_broadcast_subscriber(None) + coordinator.add_broadcast_subscriber(lambda: None) @pytest.mark.usefixtures('aiida_profile_clean')