diff --git a/aiida/backends/tests/engine/test_calcfunctions.py b/aiida/backends/tests/engine/test_calcfunctions.py index 0cac84bd55..d39ab1a88f 100644 --- a/aiida/backends/tests/engine/test_calcfunctions.py +++ b/aiida/backends/tests/engine/test_calcfunctions.py @@ -20,6 +20,23 @@ EXECUTION_COUNTER = 0 +@calcfunction +def add_calcfunction(data): + return Int(data.value + 1) + + +@calcfunction +def return_stored_calcfunction(): + return Int(2).store() + + +@calcfunction +def execution_counter_calcfunction(data): + global EXECUTION_COUNTER # pylint: disable=global-statement + EXECUTION_COUNTER += 1 + return Int(data.value + 1) + + class TestCalcFunction(AiidaTestCase): """Tests for calcfunctions. @@ -31,11 +48,7 @@ def setUp(self): self.assertIsNone(Process.current()) self.default_int = Int(256) - @calcfunction - def test_calcfunction(data): - return Int(data.value + 1) - - self.test_calcfunction = test_calcfunction + self.test_calcfunction = add_calcfunction def tearDown(self): super().tearDown() @@ -56,12 +69,8 @@ def test_calcfunction_links(self): def test_calcfunction_return_stored(self): """Verify that a calcfunction will raise when a stored node is returned.""" - @calcfunction - def test_calcfunction(): - return Int(2).store() - with self.assertRaises(ValueError): - test_calcfunction.run_get_node() + return_stored_calcfunction.run_get_node() def test_calcfunction_default_linkname(self): """Verify that a calcfunction that returns a single Data node gets a default link label.""" @@ -74,20 +83,14 @@ def test_calcfunction_default_linkname(self): def test_calcfunction_caching(self): """Verify that a calcfunction can be cached.""" - @calcfunction - def test_calcfunction(data): - global EXECUTION_COUNTER # pylint: disable=global-statement - EXECUTION_COUNTER += 1 - return Int(data.value + 1) - self.assertEqual(EXECUTION_COUNTER, 0) - _, original = test_calcfunction.run_get_node(Int(5)) + _, original = execution_counter_calcfunction.run_get_node(Int(5)) 
self.assertEqual(EXECUTION_COUNTER, 1) # Caching a CalcFunctionNode should be possible with enable_caching(CalcFunctionNode): input_node = Int(5) - result, cached = test_calcfunction.run_get_node(input_node) + result, cached = execution_counter_calcfunction.run_get_node(input_node) self.assertEqual(EXECUTION_COUNTER, 1) # Calculation function body should not have been executed self.assertTrue(result.is_stored) @@ -99,16 +102,21 @@ def test_calcfunction_caching_change_code(self): """Verify that changing the source codde of a calcfunction invalidates any existing cached nodes.""" result_original = self.test_calcfunction(self.default_int) - with enable_caching(CalcFunctionNode): - - @calcfunction - def test_calcfunction(data): - """This calcfunction has a different source code from the one setup in the setUp method.""" - return Int(data.value + 2) + # Intentionally using the same name, to check that caching anyway + # distinguishes between the calcfunctions. + @calcfunction + def add_calcfunction(data): # pylint: disable=redefined-outer-name + """This calcfunction has a different source code from the one created at the module level.""" + return Int(data.value + 2) - result_cached, cached = test_calcfunction.run_get_node(self.default_int) + with enable_caching(CalcFunctionNode): + result_cached, cached = add_calcfunction.run_get_node(self.default_int) self.assertNotEqual(result_original, result_cached) self.assertFalse(cached.is_created_from_cache) + # Test that the locally-created calcfunction can be cached in principle + result2_cached, cached2 = add_calcfunction.run_get_node(self.default_int) + self.assertNotEqual(result_original, result2_cached) + self.assertTrue(cached2.is_created_from_cache) def test_calcfunction_do_not_store_provenance(self): """Run the function without storing the provenance.""" diff --git a/aiida/backends/tests/engine/test_process.py b/aiida/backends/tests/engine/test_process.py index f16d6be3b2..9346ca984f 100644 --- 
a/aiida/backends/tests/engine/test_process.py +++ b/aiida/backends/tests/engine/test_process.py @@ -19,6 +19,7 @@ from aiida.common.lang import override from aiida.engine import ExitCode, ExitCodesNamespace, Process, run, run_get_pk, run_get_node from aiida.engine.processes.ports import PortNamespace +from aiida.manage.caching import enable_caching from aiida.plugins import CalculationFactory @@ -183,6 +184,42 @@ def test_exit_codes(self): with self.assertRaises(AttributeError): ArithmeticAddCalculation.get_exit_statuses(['NON_EXISTING_EXIT_CODE_LABEL']) + def test_exit_codes_invalidate_cache(self): + """ + Test that returning an exit code with 'invalidates_cache' set to ``True`` + indeed means that the ProcessNode will not be cached from. + """ + # Sanity check that caching works when the exit code is not returned. + with enable_caching(): + _, node1 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False)) + _, node2 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False)) + self.assertEqual(node1.get_extra('_aiida_hash'), node2.get_extra('_aiida_hash')) + self.assertIn('_aiida_cached_from', node2.extras) + + with enable_caching(): + _, node3 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True)) + _, node4 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True)) + self.assertEqual(node3.get_extra('_aiida_hash'), node4.get_extra('_aiida_hash')) + self.assertNotIn('_aiida_cached_from', node4.extras) + + def test_valid_cache_hook(self): + """ + Test that the is_valid_cache behavior can be specified from + the method in the Process sub-class. + """ + # Sanity check that caching works when the hook returns True. 
+ with enable_caching(): + _, node1 = run_get_node(test_processes.IsValidCacheHook) + _, node2 = run_get_node(test_processes.IsValidCacheHook) + self.assertEqual(node1.get_extra('_aiida_hash'), node2.get_extra('_aiida_hash')) + self.assertIn('_aiida_cached_from', node2.extras) + + with enable_caching(): + _, node3 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True)) + _, node4 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True)) + self.assertEqual(node3.get_extra('_aiida_hash'), node4.get_extra('_aiida_hash')) + self.assertNotIn('_aiida_cached_from', node4.extras) + def test_process_type_with_entry_point(self): """ For a process with a registered entry point, the process_type will be its formatted entry point string diff --git a/aiida/backends/tests/utils/processes.py b/aiida/backends/tests/utils/processes.py index 9a3a29be93..7115b4fdc5 100644 --- a/aiida/backends/tests/utils/processes.py +++ b/aiida/backends/tests/utils/processes.py @@ -11,7 +11,7 @@ import plumpy -from aiida.orm import Data, WorkflowNode +from aiida.orm import Data, WorkflowNode, CalcJobNode, Bool from aiida.engine import Process @@ -80,3 +80,39 @@ def run(self): def next_step(self): pass + + +class InvalidateCaching(Process): + """A process which invalidates cache for some exit codes.""" + + _node_class = CalcJobNode + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('return_exit_code', valid_type=Bool) + spec.exit_code( + 123, 'GENERIC_EXIT_CODE', message='This process should not be used as cache.', invalidates_cache=True + ) + + def run(self): # pylint: disable=inconsistent-return-statements + if self.inputs.return_exit_code: + return self.exit_codes.GENERIC_EXIT_CODE # pylint: disable=no-member + + +class IsValidCacheHook(Process): + """A process which overrides the hook for checking if it is valid cache.""" + + _node_class = CalcJobNode + + @classmethod + def define(cls, spec): + super().define(spec) + 
spec.input('not_valid_cache', valid_type=Bool, default=lambda: Bool(False)) + + def run(self): + pass + + @classmethod + def is_valid_cache(cls, node): + return super().is_valid_cache(node) and not node.inputs.not_valid_cache.value diff --git a/aiida/engine/processes/exit_code.py b/aiida/engine/processes/exit_code.py index cdc16237e0..2178518be6 100644 --- a/aiida/engine/processes/exit_code.py +++ b/aiida/engine/processes/exit_code.py @@ -15,8 +15,8 @@ __all__ = ('ExitCode', 'ExitCodesNamespace') -ExitCode = namedtuple('ExitCode', 'status message') -ExitCode.__new__.__defaults__ = (0, None) +ExitCode = namedtuple('ExitCode', ['status', 'message', 'invalidates_cache']) +ExitCode.__new__.__defaults__ = (0, None, False) """ A namedtuple to define an exit code for a :class:`~aiida.engine.processes.process.Process`. @@ -29,6 +29,9 @@ :param message: optional message with more details about the failure mode :type message: str + +:param invalidates_cache: optional flag, indicating that a process should not be used in caching +:type invalidates_cache: bool """ diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index 0c877bc850..2fd6736e73 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -894,6 +894,22 @@ def _get_namespace_list(namespace=None, agglomerate=True): namespace_list.extend(['.'.join(split_ns[:i]) for i in range(1, len(split_ns) + 1)]) return namespace_list + @classmethod + def is_valid_cache(cls, node): + """Check if the given node can be cached from. + + .. warning :: When overriding this method, make sure to call + super().is_valid_cache(node) and respect its output. Otherwise, + the 'invalidates_cache' keyword on exit codes will not work. + + This method allows extending the behavior of `ProcessNode.is_valid_cache` + from `Process` sub-classes, for example in plug-ins. 
+ """ + try: + return not cls.spec().exit_codes(node.exit_status).invalidates_cache + except ValueError: + return True + def get_query_string_from_process_type_string(process_type_string): # pylint: disable=invalid-name """ diff --git a/aiida/engine/processes/process_spec.py b/aiida/engine/processes/process_spec.py index a3d7ba7a91..da9d562303 100644 --- a/aiida/engine/processes/process_spec.py +++ b/aiida/engine/processes/process_spec.py @@ -48,13 +48,15 @@ def exit_codes(self): """ return self._exit_codes - def exit_code(self, status, label, message): + def exit_code(self, status, label, message, invalidates_cache=False): """ Add an exit code to the ProcessSpec :param status: the exit status integer :param label: a label by which the exit code can be addressed :param message: a more detailed description of the exit code + :param invalidates_cache: when set to `True`, a process exiting + with this exit code will not be considered for caching """ if not isinstance(status, int): raise TypeError('status should be of integer type and not of {}'.format(type(status))) @@ -68,7 +70,10 @@ def exit_code(self, status, label, message): if not isinstance(message, str): raise TypeError('message should be of basestring type and not of {}'.format(type(message))) - self._exit_codes[label] = ExitCode(status, message) + if not isinstance(invalidates_cache, bool): + raise TypeError('invalidates_cache should be of type bool and not of {}'.format(type(invalidates_cache))) + + self._exit_codes[label] = ExitCode(status, message, invalidates_cache=invalidates_cache) class CalcJobProcessSpec(ProcessSpec): diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py index f18df995e6..e8cc039093 100644 --- a/aiida/orm/nodes/process/process.py +++ b/aiida/orm/nodes/process/process.py @@ -479,7 +479,23 @@ def is_valid_cache(self): :returns: True if this process node is valid to be used for caching, False otherwise """ - return super().is_valid_cache and 
self.is_finished + if not (super().is_valid_cache and self.is_finished): + return False + try: + process_class = self.process_class + except ValueError as exc: + self.logger.warning( + "Not considering {} for caching, '{!r}' when accessing its process class.".format(self, exc) + ) + return False + # For process functions, the `process_class` does not have an + # is_valid_cache attribute + try: + is_valid_cache_func = process_class.is_valid_cache + except AttributeError: + return True + + return is_valid_cache_func(self) def _get_objects_to_hash(self): """ diff --git a/docs/source/developer_guide/core/caching.rst b/docs/source/developer_guide/core/caching.rst index 916d8b3391..b8ec01d423 100644 --- a/docs/source/developer_guide/core/caching.rst +++ b/docs/source/developer_guide/core/caching.rst @@ -23,11 +23,19 @@ Below are some methods you can use to control how the hashes of calculation and Controlling caching ------------------- -There are two methods you can use to disable caching for particular nodes: +There are several methods you can use to disable caching for particular nodes: + +On the level of generic :class:`aiida.orm.nodes.Node`: * The :meth:`~aiida.orm.nodes.Node.is_valid_cache` property determines whether a particular node can be used as a cache. This is used for example to disable caching from failed calculations. * Node classes have a ``_cachable`` attribute, which can be set to ``False`` to completely switch off caching for nodes of that class. This avoids performing queries for the hash altogether. +On the level of :class:`aiida.engine.processes.process.Process` and :class:`aiida.orm.nodes.process.ProcessNode`: + +* The :meth:`ProcessNode.is_valid_cache <aiida.orm.nodes.process.process.ProcessNode.is_valid_cache>` calls :meth:`Process.is_valid_cache <aiida.engine.processes.process.Process.is_valid_cache>`, passing the node itself. This can be used in :class:`~aiida.engine.processes.process.Process` subclasses (e.g. in calculation plugins) to implement custom ways of invalidating the cache. +* The ``spec.exit_code`` has a keyword argument ``invalidates_cache``.
If this is set to ``True``, returning that exit code means the process is no longer considered a valid cache. This is implemented in :meth:`Process.is_valid_cache <aiida.engine.processes.process.Process.is_valid_cache>`. + + The ``WorkflowNode`` example ............................ @@ -48,6 +56,3 @@ When modifying the hashing/caching behaviour of your classes, keep in mind that * False positives, where two different nodes get the same hash by mistake False negatives are **highly preferrable** because they only increase the runtime of your calculations, while false positives can lead to wrong results. - - -