diff --git a/.ci/test_daemon.py b/.ci/test_daemon.py index eba74ccee6..4e9dbff2a7 100644 --- a/.ci/test_daemon.py +++ b/.ci/test_daemon.py @@ -304,6 +304,14 @@ def run_base_restart_workchain(): assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_TOO_BIG.status, node.exit_status # pylint: disable=no-member assert len(node.called) == 1 + # Check that overriding default handler enabled status works + inputs['add']['y'] = Int(1) + inputs['handler_overrides'] = Dict(dict={'disabled_handler': True}) + results, node = run.get_node(ArithmeticAddBaseWorkChain, **inputs) + assert not node.is_finished_ok, node.process_state + assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_ENABLED_DOOM.status, node.exit_status # pylint: disable=no-member + assert len(node.called) == 1 + def main(): """Launch a bunch of calculation jobs and workchains.""" diff --git a/.ci/workchains.py b/.ci/workchains.py index a3aa5960fa..110334f0ae 100644 --- a/.ci/workchains.py +++ b/.ci/workchains.py @@ -38,6 +38,7 @@ def define(cls, spec): cls.results, ) spec.exit_code(100, 'ERROR_TOO_BIG', message='The sum was too big.') + spec.exit_code(110, 'ERROR_ENABLED_DOOM', message='You should not have done that.') def setup(self): """Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`. @@ -54,6 +55,11 @@ def sanity_check_not_too_big(self, node): if node.is_finished_ok and node.outputs.sum > 10: return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) + @process_handler(priority=460, enabled=False) + def disabled_handler(self, node): + """By default this is not enabled and so should never be called, irrespective of exit codes of sub process.""" + return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) + @process_handler(priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered')) def a_magic_unicorn_appeared(self, node): """As we all know unicorns do not exist so we should never have to deal with it.""" diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 919d1a42f5..481f056755 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -8,18 +8,47 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Base implementation of `WorkChain` class that implements a simple automated restart mechanism for sub processes.""" -import inspect +import functools from aiida import orm from aiida.common import AttributeDict from .context import ToContext, append_ from .workchain import WorkChain -from .utils import ProcessHandlerReport +from .utils import ProcessHandlerReport, process_handler __all__ = ('BaseRestartWorkChain',) +def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=inconsistent-return-statements,unused-argument + """Validator for the `handler_overrides` input port of the `BaseRestartWorkChain. + + The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. a + instance method of the `process_class` that has been decorated with the `process_handler` decorator. The values + should be boolean. + + .. note:: the normal signature of a port validator is `(value, ctx)` but since for the validation here we need a + reference to the process class, we add it and the class is bound to the method in the port declaration in the + `define` method. + + :param process_class: the `BaseRestartWorkChain` (sub) class + :param handler_overrides: the input `Dict` node + :param ctx: the `PortNamespace` in which the port is embedded + """ + if not handler_overrides: + return + + for handler, override in handler_overrides.get_dict().items(): + if not isinstance(handler, str): + return 'The key `{}` is not a string.'.format(handler) + + if not process_class.is_process_handler(handler): + return 'The key `{}` is not a process handler of {}'.format(handler, process_class) + + if not isinstance(override, bool): + return 'The value of key `{}` is not a boolean.'.format(handler) + + class BaseRestartWorkChain(WorkChain): """Base restart work chain. @@ -84,6 +113,11 @@ def define(cls, spec): help='Maximum number of iterations the work chain will restart the process to finish successfully.') spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False), help='If `True`, work directories of all called calculation jobs will be cleaned at the end of execution.') + spec.input('handler_overrides', + valid_type=orm.Dict, required=False, validator=functools.partial(validate_handler_overrides, cls), + help='Mapping where keys are process handler names and the values are a boolean, where `True` will enable ' + 'the corresponding handler and `False` will disable it. This overrides the default value set by the ' + '`enabled` keyword of the `process_handler` decorator with which the method is decorated.') spec.exit_code(301, 'ERROR_SUB_PROCESS_EXCEPTED', message='The sub process excepted.') spec.exit_code(302, 'ERROR_SUB_PROCESS_KILLED', @@ -95,6 +129,8 @@ def define(cls, spec): def setup(self): """Initialize context variables that are used during the logical flow of the `BaseRestartWorkChain`.""" + overrides = self.inputs.handler_overrides.get_dict() if 'handler_overrides' in self.inputs else {} + self.ctx.handler_overrides = overrides self.ctx.process_name = self._process_class.__name__ self.ctx.unhandled_failure = False self.ctx.is_finished = False @@ -166,10 +202,16 @@ def inspect_process(self): # pylint: disable=inconsistent-return-statements,too last_report = None # Sort the handlers with a priority defined, based on their priority in reverse order - for handler in sorted(self._handlers(), key=lambda handler: handler.priority, reverse=True): + for handler in sorted(self.get_process_handlers(), key=lambda handler: handler.priority, reverse=True): + + # Skip if the handler is enabled, either explicitly through `handler_overrides` or by default + if not self.ctx.handler_overrides.get(handler.__name__, handler.enabled): + continue - # Always pass the `node` as args because the `process_handler` decorator relies on this behavior - report = handler(node) + # Even though the `handler` is an instance method, the `get_process_handlers` method returns unbound methods + # so we have to pass in `self` manually. Also, always pass the `node` as an argument because the + # `process_handler` decorator with which the handler is decorated relies on this behavior. + report = handler(self, node) if report is not None and not isinstance(report, ProcessHandlerReport): name = handler.__name__ @@ -251,20 +293,25 @@ def __init__(self, *args, **kwargs): if self._process_class is None or not issubclass(self._process_class, Process): raise ValueError('no valid Process class defined for `_process_class` attribute') - def _handlers(self): - """Return the list of all methods decorated with the `process_handler` decorator. + @classmethod + def is_process_handler(cls, process_handler_name): + """Return whether the given method name corresponds to a process handler of this class. - :return: list of process handler methods + :param process_handler_name: string name of the instance method + :return: boolean, True if corresponds to process handler, False otherwise """ - from .utils import process_handler - - handlers = [] + # pylint: disable=comparison-with-callable + if isinstance(process_handler_name, str): + handler = getattr(cls, process_handler_name, {}) + else: + handler = process_handler_name - for method in inspect.getmembers(self, predicate=inspect.ismethod): - if hasattr(method[1], 'decorator') and method[1].decorator == process_handler: # pylint: disable=comparison-with-callable - handlers.append(method[1]) + return getattr(handler, 'decorator', None) == process_handler - return handlers + @classmethod + def get_process_handlers(cls): + from inspect import getmembers + return [method[1] for method in getmembers(cls) if cls.is_process_handler(method[1])] def on_terminated(self): """Clean the working directories of all child calculation jobs if `clean_workdir=True` in the inputs.""" diff --git a/aiida/engine/processes/workchains/utils.py b/aiida/engine/processes/workchains/utils.py index 99c131dbf0..9869aa3a36 100644 --- a/aiida/engine/processes/workchains/utils.py +++ b/aiida/engine/processes/workchains/utils.py @@ -35,7 +35,7 @@ """ -def process_handler(wrapped=None, *, priority=None, exit_codes=None): +def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): """Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler. The decorator will validate the `priority` and `exit_codes` optional keyword arguments and then add itself as an @@ -56,25 +56,29 @@ def process_handler(wrapped=None, *, priority=None, exit_codes=None): :param cls: the work chain class to register the process handler with :param priority: optional integer that defines the order in which registered handlers will be called during the - handling of a finished process. Higher priorities will be handled first. + handling of a finished process. Higher priorities will be handled first. Default value is `0`. Multiple handlers + with the same priority is allowed, but the order of those is not well defined. :param exit_codes: single or list of `ExitCode` instances. If defined, the handler will return `None` if the exit code set on the `node` does not appear in the `exit_codes`. This is useful to have a handler called only when the process failed with a specific exit code. + :param enabled: boolean, by default True, which will cause the handler to be called during `inspect_process`. When + set to `False`, the handler will be skipped. This static value can be overridden on a per work chain instance + basis through the input `handler_overrides`. """ if wrapped is None: - return partial(process_handler, priority=priority, exit_codes=exit_codes) - - if priority is None: - priority = 0 + return partial(process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled) if not isinstance(priority, int): - raise TypeError('the `priority` should be an integer.') + raise TypeError('the `priority` keyword should be an integer.') if exit_codes is not None and not isinstance(exit_codes, list): exit_codes = [exit_codes] if exit_codes and any([not isinstance(exit_code, ExitCode) for exit_code in exit_codes]): - raise TypeError('`exit_codes` should be an instance of `ExitCode` or list thereof.') + raise TypeError('`exit_codes` keyword should be an instance of `ExitCode` or list thereof.') + + if not isinstance(enabled, bool): + raise TypeError('the `enabled` keyword should be a boolean.') handler_args = getfullargspec(wrapped)[0] @@ -82,7 +86,8 @@ def process_handler(wrapped=None, *, priority=None, exit_codes=None): raise TypeError('process handler `{}` has invalid signature: should be (self, node)'.format(wrapped.__name__)) wrapped.decorator = process_handler - wrapped.priority = priority if priority else 0 + wrapped.priority = priority + wrapped.enabled = enabled @decorator def wrapper(wrapped, instance, args, kwargs): diff --git a/tests/engine/processes/workchains/test_restart.py b/tests/engine/processes/workchains/test_restart.py new file mode 100644 index 0000000000..5cbabaa73b --- /dev/null +++ b/tests/engine/processes/workchains/test_restart.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for `aiida.engine.processes.workchains.restart` module.""" +from aiida.backends.testbase import AiidaTestCase +from aiida.engine.processes.workchains.restart import BaseRestartWorkChain +from aiida.engine.processes.workchains.utils import process_handler + + +class TestBaseRestartWorkChain(AiidaTestCase): + """Tests for the `BaseRestartWorkChain` class.""" + + @staticmethod + def test_is_process_handler(): + """Test the `BaseRestartWorkChain.is_process_handler` class method.""" + + class SomeWorkChain(BaseRestartWorkChain): + """Dummy class.""" + + @process_handler() + def handler_a(self, node): + pass + + def not_a_handler(self, node): + pass + + assert SomeWorkChain.is_process_handler('handler_a') + assert not SomeWorkChain.is_process_handler('not_a_handler') + assert not SomeWorkChain.is_process_handler('unexisting_method') + + @staticmethod + def test_get_process_handler(): + """Test the `BaseRestartWorkChain.get_process_handlers` class method.""" + + class SomeWorkChain(BaseRestartWorkChain): + """Dummy class.""" + + @process_handler + def handler_a(self, node): + pass + + def not_a_handler(self, node): + pass + + assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a'] diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index a047e5a754..00f1e127a3 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -170,9 +170,44 @@ def _(self, node): process = ArithmeticAddBaseWorkChain() # Loop over all handlers, which should be just the one, and call it with the two different nodes - for handler in process._handlers(): # pylint: disable=protected-access + for handler in process.get_process_handlers(): # The `node_match` should match the `exit_codes` filter and so return a report instance - assert isinstance(handler(node_match), ProcessHandlerReport) + assert isinstance(handler(process, node_match), ProcessHandlerReport) # The `node_skip` has a wrong exit status and so should get skipped, returning `None` - assert handler(node_skip) is None + assert handler(process, node_skip) is None + + def test_enabled_keyword_only(self): + """The `enabled` should be keyword only.""" + with self.assertRaises(TypeError): + + class SomeWorkChain(BaseRestartWorkChain): + + @process_handler(True) # pylint: disable=too-many-function-args + def _(self, node): + pass + + class SomeWorkChain(BaseRestartWorkChain): + + @process_handler(enabled=False) + def _(self, node): + pass + + def test_enabled(self): + """The `enabled` should be keyword only.""" + + class SomeWorkChain(BaseRestartWorkChain): + + @process_handler + def enabled_handler(self, node): + pass + + assert SomeWorkChain.enabled_handler.enabled # pylint: disable=no-member + + class SomeWorkChain(BaseRestartWorkChain): + + @process_handler(enabled=False) + def disabled_handler(self, node): + pass + + assert not SomeWorkChain.disabled_handler.enabled # pylint: disable=no-member