BaseRestartWorkChain: add method to enable/disable process handlers
The `process_handler` decorator is updated with a new keyword argument
`enabled`, which is `True` by default. By setting it to `False` the
process handler is disabled and will always be skipped during the
`inspect_process` outline step. This default can be overridden on a
per-instance basis through a new input called `handler_overrides`.

The base spec of `BaseRestartWorkChain` defines this new input
`handler_overrides`, which takes a mapping of process handler names to
booleans. For `True` the process handler is enabled and for `False` it
is disabled, where disabled means that it is skipped during the
`inspect_process` call. The validator on the port ensures that the keys
correspond to actual instance methods of the work chain that are
decorated with `process_handler`. The values specified in
`handler_overrides`, as the name suggests, override the defaults set in
the decorator.
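
For illustration, a minimal sketch of how the decorator default and the per-instance override interact. The subclass `SomeWorkChain` and its handler names are hypothetical; the import paths follow the test modules touched by this commit, and `_process_class`, `define` and the outline are omitted for brevity:

from aiida import orm
from aiida.engine.processes.workchains.restart import BaseRestartWorkChain
from aiida.engine.processes.workchains.utils import process_handler


class SomeWorkChain(BaseRestartWorkChain):
    """Hypothetical subclass of `BaseRestartWorkChain`."""

    @process_handler(priority=400, enabled=False)
    def disabled_by_default(self, node):
        """Skipped during `inspect_process` unless enabled through `handler_overrides`."""

    @process_handler(priority=300)
    def enabled_by_default(self, node):
        """Called during `inspect_process` unless disabled through `handler_overrides`."""


def example_overrides():
    """Per-instance override: enable `disabled_by_default` despite its decorator default."""
    return orm.Dict(dict={'disabled_by_default': True})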
sphuber committed Feb 24, 2020
1 parent 061f6ae commit 2789ea9
Showing 6 changed files with 179 additions and 27 deletions.
8 changes: 8 additions & 0 deletions .ci/test_daemon.py
@@ -304,6 +304,14 @@ def run_base_restart_workchain():
assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_TOO_BIG.status, node.exit_status # pylint: disable=no-member
assert len(node.called) == 1

# Check that overriding default handler enabled status works
inputs['add']['y'] = Int(1)
inputs['handler_overrides'] = Dict(dict={'disabled_handler': True})
results, node = run.get_node(ArithmeticAddBaseWorkChain, **inputs)
assert not node.is_finished_ok, node.process_state
assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_ENABLED_DOOM.status, node.exit_status # pylint: disable=no-member
assert len(node.called) == 1


def main():
"""Launch a bunch of calculation jobs and workchains."""
6 changes: 6 additions & 0 deletions .ci/workchains.py
@@ -38,6 +38,7 @@ def define(cls, spec):
cls.results,
)
spec.exit_code(100, 'ERROR_TOO_BIG', message='The sum was too big.')
spec.exit_code(110, 'ERROR_ENABLED_DOOM', message='You should not have done that.')

def setup(self):
"""Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`.
@@ -54,6 +55,11 @@ def sanity_check_not_too_big(self, node):
if node.is_finished_ok and node.outputs.sum > 10:
return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG)

@process_handler(priority=460, enabled=False)
def disabled_handler(self, node):
"""By default this is not enabled and so should never be called, irrespective of exit codes of sub process."""
return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM)

@process_handler(priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered'))
def a_magic_unicorn_appeared(self, node):
"""As we all know unicorns do not exist so we should never have to deal with it."""
77 changes: 62 additions & 15 deletions aiida/engine/processes/workchains/restart.py
@@ -8,18 +8,47 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Base implementation of `WorkChain` class that implements a simple automated restart mechanism for sub processes."""
import inspect
import functools

from aiida import orm
from aiida.common import AttributeDict

from .context import ToContext, append_
from .workchain import WorkChain
from .utils import ProcessHandlerReport
from .utils import ProcessHandlerReport, process_handler

__all__ = ('BaseRestartWorkChain',)


def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=inconsistent-return-statements,unused-argument
"""Validator for the `handler_overrides` input port of the `BaseRestartWorkChain.
The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. an
instance method of the `process_class` that has been decorated with the `process_handler` decorator. The values
should be booleans.
.. note:: the normal signature of a port validator is `(value, ctx)`, but since the validation here requires a
reference to the process class, it is added as the first argument and bound with `functools.partial` in the port
declaration in the `define` method.
:param process_class: the `BaseRestartWorkChain` (sub) class
:param handler_overrides: the input `Dict` node
:param ctx: the `PortNamespace` in which the port is embedded
"""
if not handler_overrides:
return

for handler, override in handler_overrides.get_dict().items():
if not isinstance(handler, str):
return 'The key `{}` is not a string.'.format(handler)

if not process_class.is_process_handler(handler):
return 'The key `{}` is not a process handler of {}'.format(handler, process_class)

if not isinstance(override, bool):
return 'The value of key `{}` is not a boolean.'.format(handler)


class BaseRestartWorkChain(WorkChain):
"""Base restart work chain.
@@ -84,6 +113,11 @@ def define(cls, spec):
help='Maximum number of iterations the work chain will restart the process to finish successfully.')
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation jobs will be cleaned at the end of execution.')
spec.input('handler_overrides',
valid_type=orm.Dict, required=False, validator=functools.partial(validate_handler_overrides, cls),
help='Mapping where keys are process handler names and the values are booleans, where `True` will enable '
'the corresponding handler and `False` will disable it. This overrides the default value set by the '
'`enabled` keyword of the `process_handler` decorator with which the method is decorated.')
spec.exit_code(301, 'ERROR_SUB_PROCESS_EXCEPTED',
message='The sub process excepted.')
spec.exit_code(302, 'ERROR_SUB_PROCESS_KILLED',
@@ -95,6 +129,8 @@

def setup(self):
"""Initialize context variables that are used during the logical flow of the `BaseRestartWorkChain`."""
overrides = self.inputs.handler_overrides.get_dict() if 'handler_overrides' in self.inputs else {}
self.ctx.handler_overrides = overrides
self.ctx.process_name = self._process_class.__name__
self.ctx.unhandled_failure = False
self.ctx.is_finished = False
@@ -166,10 +202,16 @@ def inspect_process(self): # pylint: disable=inconsistent-return-statements,too
last_report = None

# Sort the handlers with a priority defined, based on their priority in reverse order
for handler in sorted(self._handlers(), key=lambda handler: handler.priority, reverse=True):
for handler in sorted(self.get_process_handlers(), key=lambda handler: handler.priority, reverse=True):

# Skip if the handler is not enabled, either explicitly through `handler_overrides` or by default
if not self.ctx.handler_overrides.get(handler.__name__, handler.enabled):
continue

# Always pass the `node` as args because the `process_handler` decorator relies on this behavior
report = handler(node)
# Even though the `handler` is an instance method, the `get_process_handlers` method returns unbound methods
# so we have to pass in `self` manually. Also, always pass the `node` as an argument because the
# `process_handler` decorator with which the handler is decorated relies on this behavior.
report = handler(self, node)

if report is not None and not isinstance(report, ProcessHandlerReport):
name = handler.__name__
Expand Down Expand Up @@ -251,20 +293,25 @@ def __init__(self, *args, **kwargs):
if self._process_class is None or not issubclass(self._process_class, Process):
raise ValueError('no valid Process class defined for `_process_class` attribute')

def _handlers(self):
"""Return the list of all methods decorated with the `process_handler` decorator.
@classmethod
def is_process_handler(cls, process_handler_name):
"""Return whether the given method name corresponds to a process handler of this class.
:return: list of process handler methods
:param process_handler_name: string name of the instance method
:return: boolean, True if corresponds to process handler, False otherwise
"""
from .utils import process_handler

handlers = []
# pylint: disable=comparison-with-callable
if isinstance(process_handler_name, str):
handler = getattr(cls, process_handler_name, {})
else:
handler = process_handler_name

for method in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(method[1], 'decorator') and method[1].decorator == process_handler: # pylint: disable=comparison-with-callable
handlers.append(method[1])
return getattr(handler, 'decorator', None) == process_handler

return handlers
@classmethod
def get_process_handlers(cls):
from inspect import getmembers
return [method[1] for method in getmembers(cls) if cls.is_process_handler(method[1])]

def on_terminated(self):
"""Clean the working directories of all child calculation jobs if `clean_workdir=True` in the inputs."""
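The net effect of these changes on `inspect_process` is a single lookup per handler, where the per-instance value from `handler_overrides` takes precedence over the decorator default. A minimal sketch of that decision (the function name is hypothetical; the real check lives inline in `inspect_process`):

def handler_is_enabled(handler_overrides, handler):
    """Return whether the handler should be called: the override wins, else the decorator default applies."""
    return handler_overrides.get(handler.__name__, handler.enabled)


# Assuming a handler decorated with `enabled=False`:
# handler_is_enabled({}, handler)                          -> False, skipped by default
# handler_is_enabled({'disabled_handler': True}, handler)  -> True, called for this instance only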
23 changes: 14 additions & 9 deletions aiida/engine/processes/workchains/utils.py
@@ -35,7 +35,7 @@
"""


def process_handler(wrapped=None, *, priority=None, exit_codes=None):
def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True):
"""Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler.
The decorator will validate the `priority` and `exit_codes` optional keyword arguments and then add itself as an
@@ -56,33 +56,38 @@ def process_handler(wrapped=None, *, priority=None, exit_codes=None):
:param cls: the work chain class to register the process handler with
:param priority: optional integer that defines the order in which registered handlers will be called during the
handling of a finished process. Higher priorities will be handled first.
handling of a finished process. Higher priorities will be handled first. Default value is `0`. Multiple handlers
with the same priority are allowed, but the order in which they are called is not well defined.
:param exit_codes: single or list of `ExitCode` instances. If defined, the handler will return `None` if the exit
code set on the `node` does not appear in the `exit_codes`. This is useful to have a handler called only when
the process failed with a specific exit code.
:param enabled: boolean, by default True, which will cause the handler to be called during `inspect_process`. When
set to `False`, the handler will be skipped. This static value can be overridden on a per work chain instance
basis through the input `handler_overrides`.
"""
if wrapped is None:
return partial(process_handler, priority=priority, exit_codes=exit_codes)

if priority is None:
priority = 0
return partial(process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled)

if not isinstance(priority, int):
raise TypeError('the `priority` should be an integer.')
raise TypeError('the `priority` keyword should be an integer.')

if exit_codes is not None and not isinstance(exit_codes, list):
exit_codes = [exit_codes]

if exit_codes and any([not isinstance(exit_code, ExitCode) for exit_code in exit_codes]):
raise TypeError('`exit_codes` should be an instance of `ExitCode` or list thereof.')
raise TypeError('`exit_codes` keyword should be an instance of `ExitCode` or list thereof.')

if not isinstance(enabled, bool):
raise TypeError('the `enabled` keyword should be a boolean.')

handler_args = getfullargspec(wrapped)[0]

if len(handler_args) != 2:
raise TypeError('process handler `{}` has invalid signature: should be (self, node)'.format(wrapped.__name__))

wrapped.decorator = process_handler
wrapped.priority = priority if priority else 0
wrapped.priority = priority
wrapped.enabled = enabled

@decorator
def wrapper(wrapped, instance, args, kwargs):
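To make the updated decorator contract explicit: the keyword-only arguments are validated and then attached as attributes (`priority`, `enabled`, `decorator`) to the decorated method, which is how `inspect_process` sorts the handlers and how `is_process_handler` recognizes them. A short sketch, with a hypothetical work chain and handler name:

from aiida.engine.processes.workchains.restart import BaseRestartWorkChain
from aiida.engine.processes.workchains.utils import process_handler


class SomeWorkChain(BaseRestartWorkChain):
    """Hypothetical work chain used only to illustrate the decorator."""

    @process_handler(priority=500, enabled=False)
    def some_handler(self, node):
        """A process handler must have the exact signature `(self, node)`."""


assert SomeWorkChain.some_handler.priority == 500
assert not SomeWorkChain.some_handler.enabled

# Invalid usages raise `TypeError`: passing the arguments positionally, e.g.
# `process_handler(True)`, or passing a non-boolean value, e.g. `enabled='yes'`.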
51 changes: 51 additions & 0 deletions tests/engine/processes/workchains/test_restart.py
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for `aiida.engine.processes.workchains.restart` module."""
from aiida.backends.testbase import AiidaTestCase
from aiida.engine.processes.workchains.restart import BaseRestartWorkChain
from aiida.engine.processes.workchains.utils import process_handler


class TestBaseRestartWorkChain(AiidaTestCase):
"""Tests for the `BaseRestartWorkChain` class."""

@staticmethod
def test_is_process_handler():
"""Test the `BaseRestartWorkChain.is_process_handler` class method."""

class SomeWorkChain(BaseRestartWorkChain):
"""Dummy class."""

@process_handler()
def handler_a(self, node):
pass

def not_a_handler(self, node):
pass

assert SomeWorkChain.is_process_handler('handler_a')
assert not SomeWorkChain.is_process_handler('not_a_handler')
assert not SomeWorkChain.is_process_handler('unexisting_method')

@staticmethod
def test_get_process_handler():
"""Test the `BaseRestartWorkChain.get_process_handlers` class method."""

class SomeWorkChain(BaseRestartWorkChain):
"""Dummy class."""

@process_handler
def handler_a(self, node):
pass

def not_a_handler(self, node):
pass

assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a']
41 changes: 38 additions & 3 deletions tests/engine/processes/workchains/test_utils.py
@@ -170,9 +170,44 @@ def _(self, node):
process = ArithmeticAddBaseWorkChain()

# Loop over all handlers, which should be just the one, and call it with the two different nodes
for handler in process._handlers(): # pylint: disable=protected-access
for handler in process.get_process_handlers():
# The `node_match` should match the `exit_codes` filter and so return a report instance
assert isinstance(handler(node_match), ProcessHandlerReport)
assert isinstance(handler(process, node_match), ProcessHandlerReport)

# The `node_skip` has a wrong exit status and so should get skipped, returning `None`
assert handler(node_skip) is None
assert handler(process, node_skip) is None

def test_enabled_keyword_only(self):
"""The `enabled` should be keyword only."""
with self.assertRaises(TypeError):

class SomeWorkChain(BaseRestartWorkChain):

@process_handler(True) # pylint: disable=too-many-function-args
def _(self, node):
pass

class SomeWorkChain(BaseRestartWorkChain):

@process_handler(enabled=False)
def _(self, node):
pass

def test_enabled(self):
"""The `enabled` should be keyword only."""

class SomeWorkChain(BaseRestartWorkChain):

@process_handler
def enabled_handler(self, node):
pass

assert SomeWorkChain.enabled_handler.enabled # pylint: disable=no-member

class SomeWorkChain(BaseRestartWorkChain):

@process_handler(enabled=False)
def disabled_handler(self, node):
pass

assert not SomeWorkChain.disabled_handler.enabled # pylint: disable=no-member
