Skip to content

Commit

Permalink
Separate create_broker and get_broker where get won't change state
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 27, 2024
1 parent 03f7a5b commit 02a939e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 14 deletions.
14 changes: 8 additions & 6 deletions src/aiida/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@
__all__ = ('Broker',)


# FIXME: make me a protocol
class Broker:
"""Interface for a message broker that facilitates communication with and between process runners."""

def __init__(self, profile: 'Profile') -> None:
"""Construct a new instance.
:param profile: The profile.
"""
self._profile = profile
# def __init__(self, profile: 'Profile') -> None:
# """Construct a new instance.
#
# :param profile: The profile.
# """
# self._profile = profile

@abc.abstractmethod
# FIXME: make me a property
def get_coordinator(self) -> 'Coordinator':
"""Return an instance of coordinator."""

Expand Down
12 changes: 9 additions & 3 deletions src/aiida/brokers/rabbitmq/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@
class RabbitmqBroker(Broker):
"""Implementation of the message broker interface using RabbitMQ through ``kiwipy``."""

def __init__(self, profile: Profile) -> None:
def __init__(self, profile: Profile, loop=None) -> None:
"""Construct a new instance.
:param profile: The profile.
"""
self._profile = profile
self._communicator: 'RmqThreadCommunicator | None' = None
self._prefix = f'aiida-{self._profile.uuid}'
# FIXME: ??? should make the event loop setable??
self._loop = asyncio.get_event_loop()
self._coordinator = None
self._loop = loop or asyncio.get_event_loop()

def __str__(self):
try:
Expand All @@ -59,6 +59,12 @@ def iterate_tasks(self):
yield task

def get_coordinator(self):
if self._coordinator is not None:
return self._coordinator

return self.create_coordinator()

def create_coordinator(self):
if self._communicator is None:
self._communicator = self._create_communicator()
# Check whether a compatible version of RabbitMQ is being used.
Expand Down
4 changes: 3 additions & 1 deletion src/aiida/cmdline/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import asyncio
from contextlib import contextmanager

from click_spinner import spinner
Expand Down Expand Up @@ -325,7 +326,8 @@ def start_daemon():

assert profile is not None

if manager.get_broker() is None:
loop = asyncio.get_event_loop()
if manager.create_broker(loop) is None:
echo.echo_critical(
f'profile `{profile.name}` does not define a broker and so cannot use this functionality.'
f'See {URL_NO_BROKER} for more details.'
Expand Down
14 changes: 10 additions & 4 deletions src/aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ def get_profile_storage(self) -> 'StorageBackend':

return self._profile_storage

def get_broker(self) -> 'Broker' | None:
def get_broker(self) -> 'Broker | None':
return self._broker

def create_broker(self, loop) -> 'Broker | None':
"""Return an instance of :class:`aiida.brokers.broker.Broker` if the profile defines a broker.
:returns: The broker of the profile, or ``None`` if the profile doesn't define one.
Expand All @@ -307,7 +310,7 @@ def get_broker(self) -> 'Broker' | None:
entry_point = 'core.rabbitmq'

broker_cls = BrokerFactory(entry_point)
self._broker = broker_cls(self._profile)
self._broker = broker_cls(self._profile, loop)

return self._broker

Expand Down Expand Up @@ -421,11 +424,14 @@ def create_runner(
_default_poll_interval = 0.0 if profile.is_test_profile else self.get_option('runner.poll.interval')
_default_broker_submit = False
_default_persister = self.get_persister()
_default_broker = self.get_broker()
_default_loop = asyncio.get_event_loop()

loop = loop or _default_loop
_default_broker = self.create_broker(loop)

runner = runners.Runner(
poll_interval=poll_interval or _default_poll_interval,
loop=loop or asyncio.get_event_loop(),
loop=loop,
broker=broker or _default_broker,
broker_submit=broker_submit or _default_broker_submit,
persister=persister or _default_persister,
Expand Down

0 comments on commit 02a939e

Please sign in to comment.