Skip to content

Commit

Permalink
plumpy.ProcessListener made persistent
Browse files Browse the repository at this point in the history
solves #273

We implement the persistence of ProcessListener by deriving the class
ProcessListener and EventHelper from persistence.Savable.
The class EventHelper is moved to a new file because of a circular
import with utils and persistence
  • Loading branch information
rikigigi committed Aug 29, 2023
1 parent 44d27d1 commit 4fd678b
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 50 deletions.
52 changes: 52 additions & 0 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable, Set, Type

from . import persistence

if TYPE_CHECKING:
from .process_listener import ProcessListener # pylint: disable=cyclic-import

_LOGGER = logging.getLogger(__name__)


@persistence.auto_persist('_listeners')
class EventHelper(persistence.Savable):

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)
26 changes: 24 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
# -*- coding: utf-8 -*-
import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional

from . import persistence
from .utils import SAVED_STATE_TYPE, protected

__all__ = ['ProcessListener']

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import


class ProcessListener(metaclass=abc.ABCMeta):
@persistence.auto_persist('_params')
class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta):

# region Persistence methods

def __init__(self) -> None:
super().__init__()
self._params: Dict[str, Any] = {}

def init(self, **kwargs: Any) -> None:
self._params = kwargs

@protected
def load_instance_state(
self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext]
) -> None:
super().load_instance_state(saved_state, load_context)
self.init(**saved_state['_params'])

# endregion

def on_process_created(self, process: 'Process') -> None:
"""
Expand Down
23 changes: 16 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
return func_wrapper


@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
@persistence.auto_persist(
'_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper'
)
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
"""
The Process class is the base for any unit of work in plumpy.
Expand Down Expand Up @@ -289,7 +292,7 @@ def __init__(

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator

Expand Down Expand Up @@ -597,6 +600,9 @@ def save_instance_state(
if self.outputs:
out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs)

print(type(out_state))
print(out_state)

@protected
def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None:
"""Load the process from its saved instance state.
Expand All @@ -612,7 +618,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

Expand All @@ -621,6 +627,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
else:
self._loop = asyncio.get_event_loop()

print('saved_state')
print(saved_state)

self._state: process_states.State = self.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
Expand Down Expand Up @@ -661,11 +670,11 @@ def add_process_listener(self, listener: ProcessListener) -> None:
"""
assert (listener != self), 'Cannot listen to yourself!' # type: ignore
self.__event_helper.add_listener(listener)
self._event_helper.add_listener(listener)

def remove_process_listener(self, listener: ProcessListener) -> None:
"""Remove a process listener from the process."""
self.__event_helper.remove_listener(listener)
self._event_helper.remove_listener(listener)

@protected
def set_logger(self, logger: logging.Logger) -> None:
Expand Down Expand Up @@ -778,7 +787,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None:
"""Output is about to be emitted."""

def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)

@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
Expand Down Expand Up @@ -891,7 +900,7 @@ def on_close(self) -> None:
self._closed = True

def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
self.__event_helper.fire_event(evt, self, *args, **kwargs)
self._event_helper.fire_event(evt, self, *args, **kwargs)

# endregion

Expand Down
41 changes: 0 additions & 41 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@
PID_TYPE = Hashable # pylint: disable=invalid-name


class EventHelper:

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.
:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method
"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)


class Frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
Expand Down

0 comments on commit 4fd678b

Please sign in to comment.