Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple broker and coordinator interface #6675

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- importlib-metadata~=6.0
- numpy~=1.21
- paramiko~=3.0
- plumpy~=0.22.3
- plumpy
- pgsu~=0.3.0
- psutil~=5.6
- psycopg[binary]~=3.0
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
'importlib-metadata~=6.0',
'numpy~=1.21',
'paramiko~=3.0',
'plumpy~=0.22.3',
'plumpy',
'pgsu~=0.3.0',
'psutil~=5.6',
'psycopg[binary]~=3.0',
Expand Down Expand Up @@ -246,6 +246,7 @@ tests = [
'pympler~=1.0',
'coverage~=7.0',
'sphinx~=7.2.0',
'watchdog~=6.0',
'docutils~=0.20'
]
tui = [
Expand Down Expand Up @@ -387,6 +388,7 @@ minversion = '7.0'
testpaths = [
'tests'
]
timeout = 30
xfail_strict = true

[tool.ruff]
Expand Down Expand Up @@ -509,3 +511,6 @@ passenv =
AIIDA_TEST_WORKERS
commands = molecule {posargs:test}
"""

[tool.uv.sources]
plumpy = {git = "https://github.com/unkcpz/plumpy", branch = "rmq-out"}
26 changes: 18 additions & 8 deletions src/aiida/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,35 @@
import abc
import typing as t

from plumpy.controller import ProcessController

if t.TYPE_CHECKING:
from aiida.manage.configuration.profile import Profile
from plumpy.coordinator import Coordinator


__all__ = ('Broker',)


# FIXME: make me a protocol
class Broker:
"""Interface for a message broker that facilitates communication with and between process runners."""

def __init__(self, profile: 'Profile') -> None:
"""Construct a new instance.
# def __init__(self, profile: 'Profile') -> None:
# """Construct a new instance.
#
# :param profile: The profile.
# """
# self._profile = profile

:param profile: The profile.
"""
self._profile = profile
@abc.abstractmethod
# FIXME: make me a property
def get_coordinator(self) -> 'Coordinator':
"""Return an instance of coordinator."""

@abc.abstractmethod
def get_communicator(self):
"""Return an instance of :class:`kiwipy.Communicator`."""
def get_controller(self) -> ProcessController:
"""Return the process controller"""
...

@abc.abstractmethod
def iterate_tasks(self):
Expand Down
35 changes: 27 additions & 8 deletions src/aiida/brokers/rabbitmq/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from __future__ import annotations

import asyncio
import functools
import typing as t

from plumpy import ProcessController
from plumpy.rmq import RemoteProcessThreadController

from aiida.brokers.broker import Broker
from aiida.brokers.rabbitmq.coordinator import RmqLoopCoordinator
from aiida.common.log import AIIDA_LOGGER
from aiida.manage.configuration import get_config_option

Expand All @@ -24,14 +29,16 @@
class RabbitmqBroker(Broker):
"""Implementation of the message broker interface using RabbitMQ through ``kiwipy``."""

def __init__(self, profile: Profile) -> None:
def __init__(self, profile: Profile, loop=None) -> None:
"""Construct a new instance.

:param profile: The profile.
"""
self._profile = profile
self._communicator: 'RmqThreadCommunicator' | None = None
self._communicator: 'RmqThreadCommunicator | None' = None
self._prefix = f'aiida-{self._profile.uuid}'
self._coordinator = None
self._loop = loop or asyncio.get_event_loop()

def __str__(self):
try:
Expand All @@ -47,24 +54,36 @@ def close(self):

def iterate_tasks(self):
"""Return an iterator over the tasks in the launch queue."""
for task in self.get_communicator().task_queue(get_launch_queue_name(self._prefix)):
for task in self.get_coordinator().communicator.task_queue(get_launch_queue_name(self._prefix)):
yield task

def get_communicator(self) -> 'RmqThreadCommunicator':
def get_coordinator(self):
if self._coordinator is not None:
return self._coordinator

return self.create_coordinator()

def create_coordinator(self):
if self._communicator is None:
self._communicator = self._create_communicator()
# Check whether a compatible version of RabbitMQ is being used.
self.check_rabbitmq_version()

return self._communicator
coordinator = RmqLoopCoordinator(self._communicator, self._loop)

return coordinator

def get_controller(self) -> ProcessController:
coordinator = self.get_coordinator()
return RemoteProcessThreadController(coordinator)

def _create_communicator(self) -> 'RmqThreadCommunicator':
"""Return an instance of :class:`kiwipy.Communicator`."""
from kiwipy.rmq import RmqThreadCommunicator

from aiida.orm.utils import serialize

self._communicator = RmqThreadCommunicator.connect(
_communicator = RmqThreadCommunicator.connect(
connection_params={'url': self.get_url()},
message_exchange=get_message_exchange_name(self._prefix),
encoder=functools.partial(serialize.serialize, encoding='utf-8'),
Expand All @@ -78,7 +97,7 @@ def _create_communicator(self) -> 'RmqThreadCommunicator':
testing_mode=self._profile.is_test_profile,
)

return self._communicator
return _communicator

def check_rabbitmq_version(self):
"""Check the version of RabbitMQ that is being connected to and emit warning if it is not compatible."""
Expand Down Expand Up @@ -122,4 +141,4 @@ def get_rabbitmq_version(self):
"""
from packaging.version import parse

return parse(self.get_communicator().server_properties['version'])
return parse(self.get_coordinator().communicator.server_properties['version'])
89 changes: 89 additions & 0 deletions src/aiida/brokers/rabbitmq/coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import concurrent.futures
from asyncio import AbstractEventLoop
from typing import Generic, TypeVar, final

import kiwipy
from plumpy.exceptions import CoordinatorConnectionError
from plumpy.rmq.communications import convert_to_comm

# The public name must match the class defined below; exporting a name that
# does not exist makes ``from ... import *`` fail with ``AttributeError``.
__all__ = ['RmqLoopCoordinator']

U = TypeVar('U', bound=kiwipy.Communicator)


@final
class RmqLoopCoordinator(Generic[U]):
    """Coordinator that delegates to a ``kiwipy`` communicator.

    Subscribers are passed through ``plumpy``'s ``convert_to_comm`` together with
    the event loop given at construction before they are registered on the
    wrapped communicator. All other operations are forwarded unchanged.
    """

    def __init__(self, comm: U, loop: AbstractEventLoop) -> None:
        """Construct a new instance.

        :param comm: The communicator that all calls are delegated to.
        :param loop: The event loop handed to ``convert_to_comm`` when subscribers are registered.
        """
        self._comm = comm
        self._loop = loop

    @property
    def communicator(self) -> U:
        """The inner communicator."""
        return self._comm

    def add_rpc_subscriber(self, subscriber, identifier=None):
        """Convert ``subscriber`` with ``convert_to_comm`` and register it as an RPC subscriber.

        :param subscriber: The callback to register.
        :param identifier: Optional identifier forwarded to the communicator.
        :return: Whatever the communicator's ``add_rpc_subscriber`` returns.
        """
        subscriber = convert_to_comm(subscriber, self._loop)
        return self._comm.add_rpc_subscriber(subscriber, identifier)

    def add_broadcast_subscriber(
        self,
        subscriber,
        subject_filters=None,
        sender_filters=None,
        identifier=None,
    ):
        """Register ``subscriber`` for broadcasts, optionally filtered by subject and sender.

        The subscriber is wrapped in a ``kiwipy.BroadcastFilter`` to which every given
        filter is added, then converted with ``convert_to_comm`` and registered.

        :param subscriber: The callback to register.
        :param subject_filters: Optional iterable of subject filters; ``None`` means no filters.
        :param sender_filters: Optional iterable of sender filters; ``None`` means no filters.
        :param identifier: Optional identifier forwarded to the communicator.
        :return: Whatever the communicator's ``add_broadcast_subscriber`` returns.
        """
        subscriber = kiwipy.BroadcastFilter(subscriber)

        # Loop variables are deliberately not called ``filter`` to avoid shadowing the builtin.
        for subject_filter in subject_filters or []:
            subscriber.add_subject_filter(subject_filter)
        for sender_filter in sender_filters or []:
            subscriber.add_sender_filter(sender_filter)

        subscriber = convert_to_comm(subscriber, self._loop)
        return self._comm.add_broadcast_subscriber(subscriber, identifier)

    def add_task_subscriber(self, subscriber, identifier=None):
        """Convert ``subscriber`` with ``convert_to_comm`` and register it as a task subscriber.

        :param subscriber: The callback to register.
        :param identifier: Optional identifier forwarded to the communicator.
        :return: Whatever the communicator's ``add_task_subscriber`` returns.
        """
        subscriber = convert_to_comm(subscriber, self._loop)
        return self._comm.add_task_subscriber(subscriber, identifier)

    def remove_rpc_subscriber(self, identifier):
        """Remove the RPC subscriber registered under ``identifier``."""
        return self._comm.remove_rpc_subscriber(identifier)

    def remove_broadcast_subscriber(self, identifier):
        """Remove the broadcast subscriber registered under ``identifier``."""
        return self._comm.remove_broadcast_subscriber(identifier)

    def remove_task_subscriber(self, identifier):
        """Remove the task subscriber registered under ``identifier``."""
        return self._comm.remove_task_subscriber(identifier)

    def rpc_send(self, recipient_id, msg):
        """Forward an RPC message to ``recipient_id`` through the communicator."""
        return self._comm.rpc_send(recipient_id, msg)

    def broadcast_send(
        self,
        body,
        sender=None,
        subject=None,
        correlation_id=None,
    ):
        """Send a broadcast through the communicator.

        :raises CoordinatorConnectionError: If the communicator raises ``ChannelInvalidStateError``,
            ``AMQPConnectionError`` or ``concurrent.futures.TimeoutError``; the original exception
            is chained as the cause.
        :return: Whatever the communicator's ``broadcast_send`` returns.
        """
        # Imported lazily, as in the original, so ``aio_pika`` is only needed when broadcasting.
        from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError

        try:
            rsp = self._comm.broadcast_send(body, sender, subject, correlation_id)
        except (ChannelInvalidStateError, AMQPConnectionError, concurrent.futures.TimeoutError) as exc:
            raise CoordinatorConnectionError from exc
        else:
            return rsp

    def task_send(self, task, no_reply=False):
        """Forward ``task`` to the communicator's ``task_send``."""
        return self._comm.task_send(task, no_reply)

    def close(self):
        """Close the inner communicator."""
        self._comm.close()

    def is_closed(self) -> bool:
        """Return `True` if the communicator was closed."""
        return self._comm.is_closed()
19 changes: 10 additions & 9 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import click

from aiida.brokers.broker import Broker
from aiida.cmdline.commands.cmd_verdi import verdi
from aiida.cmdline.params import arguments, options, types
from aiida.cmdline.utils import decorators, echo
Expand Down Expand Up @@ -340,8 +341,8 @@ def process_kill(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
msg_text = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, msg_text=msg_text)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down Expand Up @@ -371,8 +372,8 @@ def process_pause(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
msg_text = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, msg_text=msg_text)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down Expand Up @@ -416,7 +417,7 @@ def process_play(processes, all_entries, timeout, wait):
@decorators.with_dbenv()
@decorators.with_broker
@decorators.only_if_daemon_running(echo.echo_warning, 'daemon is not running, so process may not be reachable')
def process_watch(broker, processes, most_recent_node):
def process_watch(broker: Broker, processes, most_recent_node):
"""Watch the state transitions of processes.

Watch the state transitions for one or multiple running processes."""
Expand All @@ -436,7 +437,7 @@ def process_watch(broker, processes, most_recent_node):

from kiwipy import BroadcastFilter

def _print(communicator, body, sender, subject, correlation_id):
def _print(coordinator, body, sender, subject, correlation_id):
"""Format the incoming broadcast data into a message and echo it to stdout."""
if body is None:
body = 'No message specified'
Expand All @@ -446,7 +447,7 @@ def _print(communicator, body, sender, subject, correlation_id):

echo.echo(f'Process<{sender}> [{subject}|{correlation_id}]: {body}')

communicator = broker.get_communicator()
coordinator = broker.get_coordinator()
echo.echo_report('watching for broadcasted messages, press CTRL+C to stop...')

if most_recent_node:
Expand All @@ -457,7 +458,7 @@ def _print(communicator, body, sender, subject, correlation_id):
echo.echo_error(f'Process<{process.pk}> is already terminated')
continue

communicator.add_broadcast_subscriber(BroadcastFilter(_print, sender=process.pk))
coordinator.add_broadcast_subscriber(BroadcastFilter(_print, sender=process.pk))

try:
# Block this thread indefinitely until interrupt
Expand All @@ -467,7 +468,7 @@ def _print(communicator, body, sender, subject, correlation_id):
echo.echo('') # add a new line after the interrupt character
echo.echo_report('received interrupt, exiting...')
try:
communicator.close()
coordinator.close()
except RuntimeError:
pass

Expand Down
6 changes: 4 additions & 2 deletions src/aiida/cmdline/commands/cmd_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aiida.cmdline.commands.cmd_devel import verdi_devel
from aiida.cmdline.params import arguments, options
from aiida.cmdline.utils import decorators, echo, echo_tabulate
from aiida.manage.manager import Manager

if t.TYPE_CHECKING:
import requests
Expand Down Expand Up @@ -131,12 +132,13 @@ def with_client(ctx, wrapped, _, args, kwargs):

@cmd_rabbitmq.command('server-properties')
@decorators.with_manager
def cmd_server_properties(manager):
def cmd_server_properties(manager: Manager):
"""List the server properties."""
import yaml

data = {}
for key, value in manager.get_communicator().server_properties.items():
# FIXME: should server_properties be a common API for the coordinator?
for key, value in manager.get_coordinator().communicator.server_properties.items():
data[key] = value.decode('utf-8') if isinstance(value, bytes) else value
click.echo(yaml.dump(data, indent=4))

Expand Down
2 changes: 1 addition & 1 deletion src/aiida/cmdline/commands/cmd_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def verdi_status(print_traceback, no_rmq):

if broker:
try:
broker.get_communicator()
broker.get_coordinator()
except Exception as exc:
message = f'Unable to connect to broker: {broker}'
print_status(ServiceStatus.ERROR, 'broker', message, exception=exc, print_traceback=print_traceback)
Expand Down
Loading
Loading