Skip to content

Commit

Permalink
Construct and use RmqLooCoordinator directly
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 27, 2024
1 parent c769906 commit 03f7a5b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 6 deletions.
6 changes: 5 additions & 1 deletion src/aiida/brokers/rabbitmq/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import functools
import typing as t

Expand All @@ -10,6 +11,7 @@
from plumpy.rmq.process_control import RemoteProcessController

from aiida.brokers.broker import Broker
from aiida.brokers.rabbitmq.coordinator import RmqLoopCoordinator
from aiida.common.log import AIIDA_LOGGER
from aiida.manage.configuration import get_config_option

Expand All @@ -36,6 +38,8 @@ def __init__(self, profile: Profile) -> None:
self._profile = profile
self._communicator: 'RmqThreadCommunicator | None' = None
self._prefix = f'aiida-{self._profile.uuid}'
# FIXME: ??? should make the event loop setable??
self._loop = asyncio.get_event_loop()

def __str__(self):
try:
Expand All @@ -60,7 +64,7 @@ def get_coordinator(self):
# Check whether a compatible version of RabbitMQ is being used.
self.check_rabbitmq_version()

coordinator = RmqCoordinator(self._communicator)
coordinator = RmqLoopCoordinator(self._communicator, self._loop)

return coordinator

Expand Down
86 changes: 86 additions & 0 deletions src/aiida/brokers/rabbitmq/coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
from asyncio import AbstractEventLoop
from typing import Generic, TypeVar, final
import kiwipy
import concurrent.futures

from plumpy.exceptions import CoordinatorConnectionError
from plumpy.rmq.communications import convert_to_comm

__all__ = ['RmqCoordinator']

U = TypeVar('U', bound=kiwipy.Communicator)


@final
class RmqLoopCoordinator(Generic[U]):
def __init__(self, comm: U, loop: AbstractEventLoop):
self._comm = comm
self._loop = loop

@property
def communicator(self) -> U:
"""The inner communicator."""
return self._comm

def add_rpc_subscriber(self, subscriber, identifier=None):
subscriber = convert_to_comm(subscriber, self._loop)
return self._comm.add_rpc_subscriber(subscriber, identifier)

def add_broadcast_subscriber(
self,
subscriber,
subject_filters=None,
sender_filters=None,
identifier=None,
):
subscriber = kiwipy.BroadcastFilter(subscriber)

subject_filters = subject_filters or []
sender_filters = sender_filters or []

for filter in subject_filters:
subscriber.add_subject_filter(filter)
for filter in sender_filters:
subscriber.add_sender_filter(filter)

subscriber = convert_to_comm(subscriber, self._loop)
return self._comm.add_broadcast_subscriber(subscriber, identifier)

def add_task_subscriber(self, subscriber, identifier=None):
subscriber = convert_to_comm(subscriber, self._loop)
return self._comm.add_task_subscriber(subscriber, identifier)

def remove_rpc_subscriber(self, identifier):
return self._comm.remove_rpc_subscriber(identifier)

def remove_broadcast_subscriber(self, identifier):
return self._comm.remove_broadcast_subscriber(identifier)

def remove_task_subscriber(self, identifier):
return self._comm.remove_task_subscriber(identifier)

def rpc_send(self, recipient_id, msg):
return self._comm.rpc_send(recipient_id, msg)

def broadcast_send(
self,
body,
sender=None,
subject=None,
correlation_id=None,
):
from aio_pika.exceptions import ChannelInvalidStateError, AMQPConnectionError

try:
rsp = self._comm.broadcast_send(body, sender, subject, correlation_id)
except (ChannelInvalidStateError, AMQPConnectionError, concurrent.futures.TimeoutError) as exc:
raise CoordinatorConnectionError from exc
else:
return rsp

def task_send(self, task, no_reply=False):
return self._comm.task_send(task, no_reply)

def close(self):
self._comm.close()
4 changes: 1 addition & 3 deletions src/aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ def __init__(

# FIXME: broker and coordinator overlap the concept there for over-abstraction, remove the abstraction
if broker is not None:
_coordinator = broker.get_coordinator()
# FIXME: the wrap should not be needed
self._coordinator = wrap_communicator(_coordinator.communicator, self._loop)
self._coordinator = broker.get_coordinator()
self._controller = broker.get_controller()
elif self._broker_submit:
# FIXME: if broker then broker_submit else False
Expand Down
4 changes: 2 additions & 2 deletions tests/brokers/test_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ def test_communicator(url):

def test_add_rpc_subscriber(coordinator):
"""Test ``add_rpc_subscriber``."""
coordinator.add_rpc_subscriber(None)
coordinator.add_rpc_subscriber(lambda: None)


def test_add_broadcast_subscriber(coordinator):
"""Test ``add_broadcast_subscriber``."""
coordinator.add_broadcast_subscriber(None)
coordinator.add_broadcast_subscriber(lambda: None)


@pytest.mark.usefixtures('aiida_profile_clean')
Expand Down

0 comments on commit 03f7a5b

Please sign in to comment.