Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hook and exit_code keyword for more nuanced cache validation. #3637

Merged
merged 2 commits into from
Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions aiida/backends/tests/engine/test_calcfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,23 @@
EXECUTION_COUNTER = 0


@calcfunction
def add_calcfunction(data):
    """Return a new ``Int`` node whose value is that of ``data`` incremented by one."""
    incremented = data.value + 1
    return Int(incremented)


@calcfunction
def return_stored_calcfunction():
    """Return an ``Int`` node that has already been stored before being returned."""
    node = Int(2)
    return node.store()


@calcfunction
def execution_counter_calcfunction(data):
    """Bump the module-level ``EXECUTION_COUNTER`` and return ``data`` incremented by one.

    The counter lets callers detect whether the function body was actually executed.
    """
    global EXECUTION_COUNTER  # pylint: disable=global-statement
    EXECUTION_COUNTER = EXECUTION_COUNTER + 1
    incremented = data.value + 1
    return Int(incremented)


class TestCalcFunction(AiidaTestCase):
"""Tests for calcfunctions.

Expand All @@ -31,11 +48,7 @@ def setUp(self):
self.assertIsNone(Process.current())
self.default_int = Int(256)

@calcfunction
def test_calcfunction(data):
return Int(data.value + 1)

self.test_calcfunction = test_calcfunction
self.test_calcfunction = add_calcfunction

def tearDown(self):
super().tearDown()
Expand All @@ -56,12 +69,8 @@ def test_calcfunction_links(self):
def test_calcfunction_return_stored(self):
"""Verify that a calcfunction will raise when a stored node is returned."""

@calcfunction
def test_calcfunction():
return Int(2).store()

with self.assertRaises(ValueError):
test_calcfunction.run_get_node()
return_stored_calcfunction.run_get_node()

def test_calcfunction_default_linkname(self):
"""Verify that a calcfunction that returns a single Data node gets a default link label."""
Expand All @@ -74,20 +83,14 @@ def test_calcfunction_default_linkname(self):
def test_calcfunction_caching(self):
"""Verify that a calcfunction can be cached."""

@calcfunction
def test_calcfunction(data):
global EXECUTION_COUNTER # pylint: disable=global-statement
EXECUTION_COUNTER += 1
return Int(data.value + 1)

self.assertEqual(EXECUTION_COUNTER, 0)
_, original = test_calcfunction.run_get_node(Int(5))
_, original = execution_counter_calcfunction.run_get_node(Int(5))
self.assertEqual(EXECUTION_COUNTER, 1)

# Caching a CalcFunctionNode should be possible
with enable_caching(CalcFunctionNode):
input_node = Int(5)
result, cached = test_calcfunction.run_get_node(input_node)
result, cached = execution_counter_calcfunction.run_get_node(input_node)

self.assertEqual(EXECUTION_COUNTER, 1) # Calculation function body should not have been executed
self.assertTrue(result.is_stored)
Expand All @@ -99,16 +102,21 @@ def test_calcfunction_caching_change_code(self):
"""Verify that changing the source code of a calcfunction invalidates any existing cached nodes."""
result_original = self.test_calcfunction(self.default_int)

with enable_caching(CalcFunctionNode):

@calcfunction
def test_calcfunction(data):
"""This calcfunction has a different source code from the one setup in the setUp method."""
return Int(data.value + 2)
# Intentionally using the same name, to check that caching anyway
# distinguishes between the calcfunctions.
@calcfunction
def add_calcfunction(data): # pylint: disable=redefined-outer-name
"""This calcfunction has a different source code from the one created at the module level."""
return Int(data.value + 2)

result_cached, cached = test_calcfunction.run_get_node(self.default_int)
with enable_caching(CalcFunctionNode):
result_cached, cached = add_calcfunction.run_get_node(self.default_int)
self.assertNotEqual(result_original, result_cached)
self.assertFalse(cached.is_created_from_cache)
# Test that the locally-created calcfunction can be cached in principle
result2_cached, cached2 = add_calcfunction.run_get_node(self.default_int)
self.assertNotEqual(result_original, result2_cached)
self.assertTrue(cached2.is_created_from_cache)

def test_calcfunction_do_not_store_provenance(self):
"""Run the function without storing the provenance."""
Expand Down
37 changes: 37 additions & 0 deletions aiida/backends/tests/engine/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aiida.common.lang import override
from aiida.engine import ExitCode, ExitCodesNamespace, Process, run, run_get_pk, run_get_node
from aiida.engine.processes.ports import PortNamespace
from aiida.manage.caching import enable_caching
from aiida.plugins import CalculationFactory


Expand Down Expand Up @@ -183,6 +184,42 @@ def test_exit_codes(self):
with self.assertRaises(AttributeError):
ArithmeticAddCalculation.get_exit_statuses(['NON_EXISTING_EXIT_CODE_LABEL'])

def test_exit_codes_invalidate_cache(self):
    """Check that an exit code declared with ``invalidates_cache=True`` blocks caching.

    A process node that finished with such an exit code must not be cached from.
    """
    process_class = test_processes.InvalidateCaching

    # Sanity check: when the exit code is not returned, the second run is taken from the cache.
    with enable_caching():
        _, first = run_get_node(process_class, return_exit_code=orm.Bool(False))
        _, second = run_get_node(process_class, return_exit_code=orm.Bool(False))
        self.assertEqual(first.get_extra('_aiida_hash'), second.get_extra('_aiida_hash'))
        self.assertIn('_aiida_cached_from', second.extras)

    # The hashes still match, but the invalidating exit code must prevent the cache from being used.
    with enable_caching():
        _, third = run_get_node(process_class, return_exit_code=orm.Bool(True))
        _, fourth = run_get_node(process_class, return_exit_code=orm.Bool(True))
        self.assertEqual(third.get_extra('_aiida_hash'), fourth.get_extra('_aiida_hash'))
        self.assertNotIn('_aiida_cached_from', fourth.extras)

def test_valid_cache_hook(self):
    """Check that a ``Process`` sub-class can control caching through ``is_valid_cache``."""
    process_class = test_processes.IsValidCacheHook

    # Sanity check: with the default input the hook returns True and the second run is cached.
    with enable_caching():
        _, first = run_get_node(process_class)
        _, second = run_get_node(process_class)
        self.assertEqual(first.get_extra('_aiida_hash'), second.get_extra('_aiida_hash'))
        self.assertIn('_aiida_cached_from', second.extras)

    # With ``not_valid_cache`` set, the hashes still match but the hook must veto the cache.
    with enable_caching():
        _, third = run_get_node(process_class, not_valid_cache=orm.Bool(True))
        _, fourth = run_get_node(process_class, not_valid_cache=orm.Bool(True))
        self.assertEqual(third.get_extra('_aiida_hash'), fourth.get_extra('_aiida_hash'))
        self.assertNotIn('_aiida_cached_from', fourth.extras)

def test_process_type_with_entry_point(self):
"""
For a process with a registered entry point, the process_type will be its formatted entry point string
Expand Down
38 changes: 37 additions & 1 deletion aiida/backends/tests/utils/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import plumpy

from aiida.orm import Data, WorkflowNode
from aiida.orm import Data, WorkflowNode, CalcJobNode, Bool
from aiida.engine import Process


Expand Down Expand Up @@ -80,3 +80,39 @@ def run(self):

def next_step(self):
pass


class InvalidateCaching(Process):
    """Test process that declares an exit code with ``invalidates_cache=True``."""

    _node_class = CalcJobNode

    @classmethod
    def define(cls, spec):
        """Declare the ``return_exit_code`` input and the cache-invalidating exit code."""
        super().define(spec)
        spec.input('return_exit_code', valid_type=Bool)
        spec.exit_code(
            123,
            'GENERIC_EXIT_CODE',
            message='This process should not be used as cache.',
            invalidates_cache=True,
        )

    def run(self):
        """Return ``GENERIC_EXIT_CODE`` if the ``return_exit_code`` input is truthy, else ``None``."""
        if not self.inputs.return_exit_code:
            return None
        return self.exit_codes.GENERIC_EXIT_CODE  # pylint: disable=no-member


class IsValidCacheHook(Process):
    """Test process that overrides the ``is_valid_cache`` hook using its ``not_valid_cache`` input."""

    _node_class = CalcJobNode

    @classmethod
    def define(cls, spec):
        """Declare the optional ``not_valid_cache`` input, defaulting to ``Bool(False)``."""
        super().define(spec)
        spec.input('not_valid_cache', valid_type=Bool, default=lambda: Bool(False))

    def run(self):
        """Do nothing; only the caching behaviour of this process is of interest."""

    @classmethod
    def is_valid_cache(cls, node):
        """Return whether ``node`` may be cached from.

        The parent verdict is respected first; after that the node is rejected
        when its ``not_valid_cache`` input holds a true value.
        """
        if not super().is_valid_cache(node):
            return False
        flagged_invalid = node.inputs.not_valid_cache.value
        return not flagged_invalid
7 changes: 5 additions & 2 deletions aiida/engine/processes/exit_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

__all__ = ('ExitCode', 'ExitCodesNamespace')

ExitCode = namedtuple('ExitCode', 'status message')
ExitCode.__new__.__defaults__ = (0, None)
ExitCode = namedtuple('ExitCode', ['status', 'message', 'invalidates_cache'])
sphuber marked this conversation as resolved.
Show resolved Hide resolved
ExitCode.__new__.__defaults__ = (0, None, False)
sphuber marked this conversation as resolved.
Show resolved Hide resolved
"""
A namedtuple to define an exit code for a :class:`~aiida.engine.processes.process.Process`.

Expand All @@ -29,6 +29,9 @@

:param message: optional message with more details about the failure mode
:type message: str

:param invalidates_cache: optional flag, indicating that a process should not be used in caching
:type invalidates_cache: bool
"""


Expand Down
16 changes: 16 additions & 0 deletions aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,22 @@ def _get_namespace_list(namespace=None, agglomerate=True):
namespace_list.extend(['.'.join(split_ns[:i]) for i in range(1, len(split_ns) + 1)])
return namespace_list

@classmethod
def is_valid_cache(cls, node):
    """Return whether the given node can be used as a cache source.

    .. warning:: When overriding this method, make sure to call
        ``super().is_valid_cache(node)`` and respect its output. Otherwise,
        the ``invalidates_cache`` keyword on exit codes will not work.

    This hook allows ``Process`` sub-classes, for example in plug-ins, to
    extend the behavior of ``ProcessNode.is_valid_cache``.

    :param node: the process node whose ``exit_status`` is looked up in the spec's exit codes
    :returns: ``False`` if the matching exit code has ``invalidates_cache`` set, ``True``
        otherwise, including when the lookup raises a ``ValueError``
    """
    try:
        exit_code = cls.spec().exit_codes(node.exit_status)
    except ValueError:
        # The lookup failed for this exit status, so nothing vetoes caching.
        return True
    return not exit_code.invalidates_cache


def get_query_string_from_process_type_string(process_type_string): # pylint: disable=invalid-name
"""
Expand Down
9 changes: 7 additions & 2 deletions aiida/engine/processes/process_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ def exit_codes(self):
"""
return self._exit_codes

def exit_code(self, status, label, message):
def exit_code(self, status, label, message, invalidates_cache=False):
"""
Add an exit code to the ProcessSpec

:param status: the exit status integer
:param label: a label by which the exit code can be addressed
:param message: a more detailed description of the exit code
:param invalidates_cache: when set to `True`, a process exiting
with this exit code will not be considered for caching
"""
if not isinstance(status, int):
raise TypeError('status should be of integer type and not of {}'.format(type(status)))
Expand All @@ -68,7 +70,10 @@ def exit_code(self, status, label, message):
if not isinstance(message, str):
raise TypeError('message should be of basestring type and not of {}'.format(type(message)))

self._exit_codes[label] = ExitCode(status, message)
if not isinstance(invalidates_cache, bool):
raise TypeError('invalidates_cache should be of type bool and not of {}'.format(type(invalidates_cache)))

self._exit_codes[label] = ExitCode(status, message, invalidates_cache=invalidates_cache)


class CalcJobProcessSpec(ProcessSpec):
Expand Down
18 changes: 17 additions & 1 deletion aiida/orm/nodes/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,23 @@ def is_valid_cache(self):

:returns: True if this process node is valid to be used for caching, False otherwise
"""
return super().is_valid_cache and self.is_finished
if not (super().is_valid_cache and self.is_finished):
return False
sphuber marked this conversation as resolved.
Show resolved Hide resolved
try:
process_class = self.process_class
except ValueError as exc:
self.logger.warning(
"Not considering {} for caching, '{!r}' when accessing its process class.".format(self, exc)
)
return False
# For process functions, the `process_class` does not have an
# is_valid_cache attribute
try:
is_valid_cache_func = process_class.is_valid_cache
except AttributeError:
return True

return is_valid_cache_func(self)

def _get_objects_to_hash(self):
"""
Expand Down
13 changes: 9 additions & 4 deletions docs/source/developer_guide/core/caching.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@ Below are some methods you can use to control how the hashes of calculation and
Controlling caching
-------------------

There are two methods you can use to disable caching for particular nodes:
There are several methods you can use to disable caching for particular nodes:

On the level of generic :class:`aiida.orm.nodes.Node`:

* The :meth:`~aiida.orm.nodes.Node.is_valid_cache` property determines whether a particular node can be used as a cache. This is used for example to disable caching from failed calculations.
* Node classes have a ``_cachable`` attribute, which can be set to ``False`` to completely switch off caching for nodes of that class. This avoids performing queries for the hash altogether.

On the level of :class:`aiida.engine.processes.process.Process` and :class:`aiida.orm.nodes.process.ProcessNode`:

* The :meth:`ProcessNode.is_valid_cache <aiida.orm.nodes.process.ProcessNode.is_valid_cache>` calls :meth:`Process.is_valid_cache <aiida.engine.processes.process.Process.is_valid_cache>`, passing the node itself. This can be used in :class:`~aiida.engine.processes.process.Process` subclasses (e.g. in calculation plugins) to implement custom ways of invalidating the cache.
sphuber marked this conversation as resolved.
Show resolved Hide resolved
* The ``spec.exit_code`` has a keyword argument ``invalidates_cache``. If this is set to ``True``, returning that exit code means the process is no longer considered a valid cache. This is implemented in :meth:`Process.is_valid_cache <aiida.engine.processes.process.Process.is_valid_cache>`.


The ``WorkflowNode`` example
............................

Expand All @@ -48,6 +56,3 @@ When modifying the hashing/caching behaviour of your classes, keep in mind that
* False positives, where two different nodes get the same hash by mistake

False negatives are **highly preferable** because they only increase the runtime of your calculations, while false positives can lead to wrong results.