Skip to content

Commit

Permalink
BaseRestartWorkChain: allow to override priority in `handler_overri…
Browse files Browse the repository at this point in the history
…des` (#5546)

The `handler_overrides` could so far be used to override the `enabled`
keyword of the corresponding handler. This would allow to disable or
enable a handler on a per process instance basis.

A user may want to do the same for the `priority`. To make this possible
the type of the `handler_overrides` is changed where the values should
now be dictionaries where the keys `enabled` and `priority` are
supported. These can be used to override the original values declared
in the source code of the work chain.

To provide backwards-compatibility, the old syntax is still supported
and automatically converted, with a deprecation warning being displayed.
  • Loading branch information
sphuber authored May 30, 2022
1 parent 408efa7 commit 7b8c61d
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 26 deletions.
83 changes: 58 additions & 25 deletions aiida/engine/processes/workchains/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import functools
from inspect import getmembers
from types import FunctionType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

from aiida import orm
from aiida.common import AttributeDict
from aiida.common.warnings import warn_deprecation

from .context import ToContext, append_
from .utils import ProcessHandlerReport, process_handler # pylint: disable=no-name-in-module
Expand All @@ -31,32 +32,44 @@ def validate_handler_overrides(
handler_overrides: Optional[orm.Dict],
ctx: 'PortNamespace' # pylint: disable=unused-argument
) -> Optional[str]:
"""Validator for the `handler_overrides` input port of the `BaseRestartWorkChain`.
"""Validator for the ``handler_overrides`` input port of the ``BaseRestartWorkChain``.
The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. a
instance method of the `process_class` that has been decorated with the `process_handler` decorator. The values
should be boolean.
The ``handler_overrides`` should be a dictionary where keys are strings that are the name of a process handler, i.e.
an instance method of the ``process_class`` that has been decorated with the ``process_handler`` decorator. The
values should be a dictionary that can specify the keys ``enabled`` and ``priority``.
.. note:: the normal signature of a port validator is `(value, ctx)` but since for the validation here we need a
.. note:: the normal signature of a port validator is ``(value, ctx)`` but since for the validation here we need a
reference to the process class, we add it and the class is bound to the method in the port declaration in the
`define` method.
``define`` method.
:param process_class: the `BaseRestartWorkChain` (sub) class
:param handler_overrides: the input `Dict` node
:param ctx: the `PortNamespace` in which the port is embedded
:param process_class: the ``BaseRestartWorkChain`` (sub) class
:param handler_overrides: the input ``Dict`` node
:param ctx: the ``PortNamespace`` in which the port is embedded
"""
if not handler_overrides:
return None

for handler, override in handler_overrides.get_dict().items():
for handler, overrides in handler_overrides.get_dict().items():
if not isinstance(handler, str):
return f'The key `{handler}` is not a string.'

if not process_class.is_process_handler(handler):
return f'The key `{handler}` is not a process handler of {process_class}'

if not isinstance(override, bool):
return f'The value of key `{handler}` is not a boolean.'
if not isinstance(overrides, (bool, dict)):
return f'The value of key `{handler}` is not a boolean or dictionary.'

if isinstance(overrides, bool):
warn_deprecation(
'Setting a boolean as value for `handler_overrides` is deprecated. Use '
"`{'handler_name': {'enabled': " + f'{overrides}' + '}` instead.',
version=3
)

if isinstance(overrides, dict):
for key in overrides.keys():
if key not in ['enabled', 'priority']:
return f'The value of key `{handler}` contain keys `{key}` which is not supported.'

return None

Expand Down Expand Up @@ -135,9 +148,10 @@ def define(cls, spec: 'ProcessSpec') -> None: # type: ignore[override]
help='If `True`, work directories of all called calculation jobs will be cleaned at the end of execution.')
spec.input('handler_overrides',
valid_type=orm.Dict, required=False, validator=functools.partial(validate_handler_overrides, cls),
help='Mapping where keys are process handler names and the values are a boolean, where `True` will enable '
'the corresponding handler and `False` will disable it. This overrides the default value set by the '
'`enabled` keyword of the `process_handler` decorator with which the method is decorated.')
serializer=orm.to_aiida_type,
help='Mapping where keys are process handler names and the values are a dictionary, where each dictionary '
'can define the ``enabled`` and ``priority`` key, which can be used to toggle the values set on '
'the original process handler declaration.')
spec.exit_code(301, 'ERROR_SUB_PROCESS_EXCEPTED',
message='The sub process excepted.')
spec.exit_code(302, 'ERROR_SUB_PROCESS_KILLED',
Expand Down Expand Up @@ -225,16 +239,11 @@ def inspect_process(self) -> Optional['ExitCode']: # pylint: disable=too-many-b
last_report = None

# Sort the handlers with a priority defined, based on their priority in reverse order
get_priority = lambda handler: handler.priority
for handler in sorted(self.get_process_handlers(), key=get_priority, reverse=True):

# Skip if the handler is enabled, either explicitly through `handler_overrides` or by default
if not self.ctx.handler_overrides.get(handler.__name__, handler.enabled): # type: ignore[attr-defined]
continue
for _, handler in sorted(self.get_process_handlers_by_priority(), key=lambda e: e[0], reverse=True):

# Even though the `handler` is an instance method, the `get_process_handlers` method returns unbound methods
# so we have to pass in `self` manually. Also, always pass the `node` as an argument because the
# `process_handler` decorator with which the handler is decorated relies on this behavior.
# Even though the ``handler`` is an instance method, the ``get_process_handlers_by_priority`` method returns
# unbound methods so we have to pass in ``self`` manually. Also, always pass the ``node`` as an argument
# because the ``process_handler`` decorator with which the handler is decorated relies on this behavior.
report = handler(self, node)

if report is not None and not isinstance(report, ProcessHandlerReport):
Expand Down Expand Up @@ -346,6 +355,30 @@ def is_process_handler(cls, process_handler_name: Union[str, FunctionType]) -> b
def get_process_handlers(cls) -> List[FunctionType]:
return [method[1] for method in getmembers(cls) if cls.is_process_handler(method[1])]

def get_process_handlers_by_priority(self) -> List[Tuple[int, FunctionType]]:
"""Return list of process handlers where overrides from ``inputs.handler_overrides`` are taken into account."""
handlers = []

for handler in self.get_process_handlers():

overrides = self.ctx.handler_overrides.get(handler.__name__, {})

enabled = None
priority = None

if isinstance(overrides, bool):
enabled = overrides
else:
enabled = overrides.pop('enabled', None)
priority = overrides.pop('priority', None)

if enabled is False or not handler.enabled: # type: ignore[attr-defined]
continue

handlers.append((priority or handler.priority, handler)) # type: ignore[attr-defined]

return handlers

def on_terminated(self):
"""Clean the working directories of all child calculation jobs if `clean_workdir=True` in the inputs."""
super().on_terminated()
Expand Down
23 changes: 23 additions & 0 deletions docs/source/howto/workchains_restart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,26 @@ The base restart work chain will detect this exit code and abort the work chain,
└── ArithmeticAddCalculation<1952> Finished [410]
With these basic tools, a broad range of use-cases can be addressed while preventing a lot of boilerplate code.


Handler overrides
=================

It is possible to change the priority of handlers and enable/disable them without changing the source code of the work chain.
These properties of the handlers can be controlled through the ``handler_overrides`` input of the work chain.
This input takes a ``Dict`` node, that has the following form:

.. code-block:: python
handler_overrides = Dict({
'handler_negative_sum': {
'enabled': True,
'priority': 10000
}
})
As you can see, the keys are the name of the handler to affect and the value is a dictionary that can take two keys: ``enabled`` and ``priority``.
To enable or disable a handler, set ``enabled`` to ``True`` or ``False``, respectively.
The ``priority`` key takes an integer and determines the priority of the handler.
Note that the values of the ``handler_overrides`` are fully optional and will override the values configured by the process handler decorator in the source code of the work chain.
The changes also only affect the work chain instance that receives the ``handler_overrides`` input, all other instances of the work chain that will be launched will be unaffected.
24 changes: 23 additions & 1 deletion tests/engine/processes/workchains/test_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,33 @@ def test_is_process_handler():
assert not SomeWorkChain.is_process_handler('unexisting_method')


def test_get_process_handler():
def test_get_process_handlers():
"""Test the `BaseRestartWorkChain.get_process_handlers` class method."""
assert [handler.__name__ for handler in SomeWorkChain.get_process_handlers()] == ['handler_a', 'handler_b']


# yapf: disable
@pytest.mark.parametrize('inputs, priorities', (
({}, [100, 200]),
({'handler_overrides': {'handler_a': {'priority': 50}}}, [50, 100]),
({'handler_overrides': {'handler_a': {'enabled': False}}}, [100]),
({'handler_overrides': {'handler_a': False}}, [100]), # This notation is deprecated
))
# yapf: enable
@pytest.mark.usefixtures('aiida_profile_clean')
def test_get_process_handlers_by_priority(generate_work_chain, inputs, priorities):
"""Test the `BaseRestartWorkChain.get_process_handlers_by_priority` method."""
process = generate_work_chain(SomeWorkChain, inputs)
process.setup()
assert sorted([priority for priority, handler in process.get_process_handlers_by_priority()]) == priorities

# Verify the actual handlers on the class haven't been modified
assert getattr(SomeWorkChain, 'handler_a').priority == 200
assert getattr(SomeWorkChain, 'handler_b').priority == 100
assert getattr(SomeWorkChain, 'handler_a').enabled
assert getattr(SomeWorkChain, 'handler_b').enabled


@pytest.mark.requires_rmq
@pytest.mark.usefixtures('aiida_profile_clean')
def test_excepted_process(generate_work_chain, generate_calculation_node):
Expand Down

0 comments on commit 7b8c61d

Please sign in to comment.