Skip to content

Commit

Permalink
BaseRestartWorkChain: require process handlers to be instance metho…
Browse files Browse the repository at this point in the history
…ds (#3782)

The original implementation of `register_process_handler` allowed an
unbound function to be attached as a method to a subclass of
`BaseRestartWorkChain` from outside of the scope of that class. Since this
also makes it possible to attach process handlers from outside of the module
of the work chain, it risks a loss of provenance. Here we rename the
decorator to `process_handler` and it can only be applied to instance
methods. This forces the handlers to be defined in the same module as the
work chain to which they apply.
  • Loading branch information
sphuber authored Feb 21, 2020
1 parent 11cefed commit 617b2df
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 178 deletions.
39 changes: 18 additions & 21 deletions .ci/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
###########################################################################
from aiida.common import AttributeDict
from aiida.engine import calcfunction, workfunction, WorkChain, ToContext, append_, while_, ExitCode
from aiida.engine import BaseRestartWorkChain, register_process_handler, ProcessHandlerReport
from aiida.engine import BaseRestartWorkChain, process_handler, ProcessHandlerReport
from aiida.engine.persistence import ObjectLoader
from aiida.orm import Int, List, Str
from aiida.plugins import CalculationFactory
Expand Down Expand Up @@ -48,26 +48,23 @@ def setup(self):
super().setup()
self.ctx.inputs = AttributeDict(self.exposed_inputs(ArithmeticAddCalculation, 'add'))


@register_process_handler(ArithmeticAddBaseWorkChain, priority=500)
def sanity_check_not_too_big(self, node):
"""My puny brain cannot deal with numbers that I cannot count on my hand."""
if node.is_finished_ok and node.outputs.sum > 10:
return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG)


@register_process_handler(ArithmeticAddBaseWorkChain, priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered'))
def a_magic_unicorn_appeared(self, node):
"""As we all know unicorns do not exist so we should never have to deal with it."""
raise RuntimeError('this handler should never even have been called')


@register_process_handler(ArithmeticAddBaseWorkChain, priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER)
def error_negative_sum(self, node):
"""What even is a negative number, how can I have minus three melons?!."""
self.ctx.inputs.x = Int(abs(node.inputs.x.value))
self.ctx.inputs.y = Int(abs(node.inputs.y.value))
return ProcessHandlerReport(True)
@process_handler(priority=500)
def sanity_check_not_too_big(self, node):
    """Stop the work chain with `ERROR_TOO_BIG` if the sub process finished ok but its sum exceeds 10.

    :param node: the process node of the sub process that just completed
    :return: a `ProcessHandlerReport` that breaks the handler loop, or `None` if the sum is acceptable
    """
    sum_is_too_big = node.is_finished_ok and node.outputs.sum > 10

    if not sum_is_too_big:
        return None

    return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG)

@process_handler(priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered'))
def a_magic_unicorn_appeared(self, node):
    """Handler tied to exit status 1000, which is never expected to occur, so reaching the body is an error.

    :param node: the process node of the sub process that just completed
    :raises RuntimeError: always, because this handler is not supposed to be invoked
    """
    message = 'this handler should never even have been called'
    raise RuntimeError(message)

@process_handler(priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER)
def error_negative_sum(self, node):
    """Replace both operands in the context inputs by their absolute values and break the handler loop.

    :param node: the process node of the sub process that just completed
    :return: a `ProcessHandlerReport` with `do_break` set to `True`
    """
    next_inputs = self.ctx.inputs
    previous_inputs = node.inputs

    # Keep the original order: `x` is fully updated before `y` is touched.
    next_inputs.x = Int(abs(previous_inputs.x.value))
    next_inputs.y = Int(abs(previous_inputs.y.value))

    return ProcessHandlerReport(True)


class NestedWorkChain(WorkChain):
Expand Down
73 changes: 38 additions & 35 deletions aiida/engine/processes/workchains/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Base implementation of `WorkChain` class that implements a simple automated restart mechanism for sub processes."""
import inspect

from aiida import orm
from aiida.common import AttributeDict
from aiida.common.lang import classproperty

from .context import ToContext, append_
from .workchain import WorkChain
Expand All @@ -28,8 +29,9 @@ class BaseRestartWorkChain(WorkChain):
are recoverable.
This work chain implements the most basic functionality to achieve this goal. It will launch the sub process,
restarting until it is completed successfully or the maximum number of iterations is reached. It can recover from
errors through error handlers that can be registered to the class through the `register_process_handler` decorator.
restarting until it is completed successfully or the maximum number of iterations is reached. After completion of
the sub process it will be inspected, and a series of process handlers is called successively. These process handlers
are defined as instance methods that are decorated with :meth:`~aiida.engine.process_handler`.
The idea is to sub class this work chain and leverage the generic error handling that is implemented in the few
outline methods. The minimally required outline would look something like the following::
Expand Down Expand Up @@ -57,13 +59,21 @@ class BaseRestartWorkChain(WorkChain):
process will be run with those inputs.
The `_process_class` attribute should be set to the `Process` class that should be run in the loop.
Finally, to define handlers that will be called during the `inspect_process` simply define an instance method with the
signature `(self, node)` and decorate it with the `process_handler` decorator, for example::
@process_handler
def handle_problem(self, node):
if some_problem:
self.ctx.inputs = improved_inputs
return ProcessHandlerReport()
The `process_handler` and `ProcessHandlerReport` support various arguments to control the flow of the logic of the
`inspect_process`. Refer to their respective documentation for details.
"""

_verbose = False
_process_class = None

_handler_entry_point = None
__handlers = tuple()
_considered_handlers_extra = 'considered_handlers'

@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -113,12 +123,12 @@ def run_process(self):
inputs = self._wrap_bare_dict_inputs(self._process_class.spec().inputs, unwrapped_inputs)
node = self.submit(self._process_class, **inputs)

# Add a new empty list to the `called_process_handlers` extra. If any error handlers registered through the
# `register_process_handler` decorator return a `ProcessHandlerReport`, their name will be appended to that
# list.
called_process_handlers = self.node.get_extra('called_process_handlers', [])
called_process_handlers.append([])
self.node.set_extra('called_process_handlers', called_process_handlers)
# Add a new empty list to the `BaseRestartWorkChain._considered_handlers_extra` extra. This will contain the
# name and return value of all instance methods, decorated with `process_handler`, that are called during
# the `inspect_process` outline step.
considered_handlers = self.node.get_extra(self._considered_handlers_extra, [])
considered_handlers.append([])
self.node.set_extra(self._considered_handlers_extra, considered_handlers)

self.report('launching {}<{}> iteration #{}'.format(self.ctx.process_name, node.pk, self.ctx.iteration))

Expand Down Expand Up @@ -153,18 +163,16 @@ def inspect_process(self): # pylint: disable=inconsistent-return-statements,too
if node.is_killed:
return self.exit_codes.ERROR_SUB_PROCESS_KILLED # pylint: disable=no-member

# Sort the handlers with a priority defined, based on their priority in reverse order
handlers = [handler for handler in self._handlers if handler.priority]
handlers = sorted(handlers, key=lambda x: x.priority, reverse=True)

last_report = None

for handler in handlers:
# Call all handlers sorted by their priority in reverse order, i.e. the handler with the highest priority first
for handler in sorted(self._handlers(), key=lambda handler: handler.priority, reverse=True):

report = handler.method(self, node)
# Always pass the `node` as args because the `process_handler` decorator relies on this behavior
report = handler(node)

if report is not None and not isinstance(report, ProcessHandlerReport):
name = handler.method.__name__
name = handler.__name__
raise RuntimeError('handler `{}` returned a value that is not a ProcessHandlerReport'.format(name))

# If an actual report was returned, save it so it is not overridden by next handler returning `None`
Expand Down Expand Up @@ -234,8 +242,6 @@ def results(self): # pylint: disable=inconsistent-return-statements
name, self.ctx.process_name, node.pk))
else:
self.out(name, output)
if self._verbose:
self.report("attaching the node {}<{}> as '{}'".format(output.__class__.__name__, output.pk, name))

def __init__(self, *args, **kwargs):
"""Construct the instance."""
Expand All @@ -245,23 +251,20 @@ def __init__(self, *args, **kwargs):
if self._process_class is None or not issubclass(self._process_class, Process):
raise ValueError('no valid Process class defined for `_process_class` attribute')

@classproperty
def _handlers(cls): # pylint: disable=no-self-argument
"""Return the tuple of all registered handlers for this class and of any parent class.
def _handlers(self):
"""Return the list of all methods decorated with the `process_handler` decorator.
:return: tuple of handler methods
:return: list of process handler methods
"""
return getattr(super(), '__handlers', tuple()) + cls.__handlers
from .utils import process_handler

@classmethod
def register_handler(cls, name, handler):
"""Register a new handler to this class.
handlers = []

:param name: the name under which to register the handler
:param handler: a method with the signature `self, node`.
"""
setattr(cls, name, handler)
cls.__handlers = cls.__handlers + (handler,)
for method in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(method[1], 'decorator') and method[1].decorator == process_handler: # pylint: disable=comparison-with-callable
handlers.append(method[1])

return handlers

def on_terminated(self):
"""Clean the working directories of all child calculation jobs if `clean_workdir=True` in the inputs."""
Expand Down
98 changes: 41 additions & 57 deletions aiida/engine/processes/workchains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,13 @@
###########################################################################
"""Utilities for `WorkChain` implementations."""
from collections import namedtuple
from functools import wraps
from functools import partial
from inspect import getfullargspec
from wrapt import decorator

from ..exit_code import ExitCode

__all__ = ('ProcessHandler', 'ProcessHandlerReport', 'register_process_handler')

ProcessHandler = namedtuple('ProcessHandler', 'method priority exit_codes')
ProcessHandler.__new__.__defaults__ = (None, 0, None)
"""A namedtuple to define a process handler for a :class:`aiida.engine.BaseRestartWorkChain`.
The `method` element refers to a function decorated by the `register_process_handler` that has turned it into a bound
method of the target `WorkChain` class. The method takes an instance of a :class:`~aiida.orm.ProcessNode` as its sole
argument. The method can return an optional `ProcessHandlerReport` to signal whether other handlers still need to be
considered or whether the work chain should be terminated immediately. The priority determines in which order the
handler methods are executed, with the higher priority being executed first.
:param method: the decorated process handling function turned into a bound work chain class method
:param priority: integer denoting the process handler's priority
:param exit_codes: single or list of `ExitCode` instances. A handler that defines this should only be called for a given
completed process if its exit status is a member of `exit_codes`.
"""
__all__ = ('ProcessHandlerReport', 'process_handler')

ProcessHandlerReport = namedtuple('ProcessHandlerReport', 'do_break exit_code')
ProcessHandlerReport.__new__.__defaults__ = (False, ExitCode())
Expand All @@ -49,16 +35,12 @@
"""


def register_process_handler(cls, *, priority=None, exit_codes=None):
"""Decorator to register a function as a handler for a :class:`~aiida.engine.BaseRestartWorkChain`.
def process_handler(wrapped=None, *, priority=None, exit_codes=None):
"""Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler.
The function expects two arguments, a work chain class and a priority. The decorator will add the function as a
class method to the work chain class and add an :class:`~aiida.engine.ProcessHandler` tuple to the `__handlers`
private attribute of the work chain. During the `inspect_process` outline method, the work chain will retrieve all
the registered handlers through the :meth:`~aiida.engine.BaseRestartWorkChain._handlers` property and loop over them
sorted with respect to their priority in reverse. If the work chain class defines the
:attr:`~aiida.engine.BaseRestartWorkChain._verbose` attribute and is set to `True`, a report message will be fired
when the process handler is executed.
The decorator will validate the `priority` and `exit_codes` optional keyword arguments and then add itself as an
attribute to the `wrapped` instance method. This is used in the `inspect_process` to return all instance methods of
the class that have been decorated by this function and therefore are considered to be process handlers.
Requirements on the function signature of process handling functions. The function to which the decorator is applied
needs to take two arguments:
Expand All @@ -79,6 +61,9 @@ class method to the work chain class and add an :class:`~aiida.engine.ProcessHan
code set on the `node` does not appear in the `exit_codes`. This is useful to have a handler called only when
the process failed with a specific exit code.
"""
if wrapped is None:
return partial(process_handler, priority=priority, exit_codes=exit_codes)

if priority is None:
priority = 0

Expand All @@ -91,41 +76,40 @@ class method to the work chain class and add an :class:`~aiida.engine.ProcessHan
if exit_codes and any([not isinstance(exit_code, ExitCode) for exit_code in exit_codes]):
raise TypeError('`exit_codes` should be an instance of `ExitCode` or list thereof.')

def process_handler_decorator(handler):
"""Decorate a function to dynamically register a handler to a `WorkChain` class."""

@wraps(handler)
def process_handler(self, node):
"""Wrap handler to add a log to the report if the handler is called and verbosity is turned on."""
verbose = hasattr(cls, '_verbose') and cls._verbose # pylint: disable=protected-access
handler_args = getfullargspec(wrapped)[0]

if exit_codes and node.exit_status not in [exit_code.status for exit_code in exit_codes]:
if verbose:
self.report('skipped {} because of exit code filter'.format(handler.__name__))
return None
if len(handler_args) != 2:
raise TypeError('process handler `{}` has invalid signature: should be (self, node)'.format(wrapped.__name__))

if verbose:
self.report('({}){}'.format(priority, handler.__name__))
wrapped.decorator = process_handler
wrapped.priority = priority if priority else 0

result = handler(self, node)
@decorator
def wrapper(wrapped, instance, args, kwargs):

# If a handler report is returned, attach the handler's name to node's attributes
if isinstance(result, ProcessHandlerReport):
try:
called_process_handlers = self.node.get_extra('called_process_handlers', [])
current_process = called_process_handlers[-1]
except IndexError:
# The extra was never initialized, so we skip this functionality
pass
else:
# Append the name of the handler to the last list in `called_process_handlers` and save it
current_process.append(handler.__name__)
self.node.set_extra('called_process_handlers', called_process_handlers)
# When the handler is called by the `BaseRestartWorkChain`, the node is passed as the only argument
node = args[0]

return result
if exit_codes and node.exit_status not in [exit_code.status for exit_code in exit_codes]:
result = None
else:
result = wrapped(*args, **kwargs)

cls.register_handler(handler.__name__, ProcessHandler(process_handler, priority, exit_codes))
# Append the name and return value of the current process handler to the `considered_handlers` extra.
try:
considered_handlers = instance.node.get_extra(instance._considered_handlers_extra, []) # pylint: disable=protected-access
current_process = considered_handlers[-1]
except IndexError:
# The extra was never initialized, so we skip this functionality
pass
else:
# Append the name and (serialized) return value of the handler to the last list in `considered_handlers` and save it
serialized = result
if isinstance(serialized, ProcessHandlerReport):
serialized = {'do_break': serialized.do_break, 'exit_status': serialized.exit_code.status}
current_process.append((wrapped.__name__, serialized))
instance.node.set_extra(instance._considered_handlers_extra, considered_handlers) # pylint: disable=protected-access

return process_handler
return result

return process_handler_decorator
return wrapper(wrapped) # pylint: disable=no-value-for-parameter
Loading

0 comments on commit 617b2df

Please sign in to comment.