From 35fc3ae5790023022d4d78cf2fe7274a72b590d2 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 18 Oct 2023 13:30:21 +0200 Subject: [PATCH] ORM: Add `NodeCaching.CACHED_FROM_KEY` for `_aiida_cached_from` constant The `_aiida_cached_from` key is used to store the UUID, of the node from which a node was cached, into the extras. It appeared in a few places as a string literal. It is now added as the `CACHED_FROM_KEY` class variable of `NodeCaching`. --- .github/system_tests/test_daemon.py | 5 +++-- aiida/orm/nodes/caching.py | 3 ++- aiida/orm/nodes/node.py | 2 +- tests/engine/test_process.py | 9 +++++---- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/system_tests/test_daemon.py b/.github/system_tests/test_daemon.py index fce3a827cc..c3f8c14ecb 100644 --- a/.github/system_tests/test_daemon.py +++ b/.github/system_tests/test_daemon.py @@ -37,6 +37,7 @@ from aiida.engine.processes import CalcJob, Process from aiida.manage.caching import enable_caching from aiida.orm import CalcJobNode, Dict, Int, List, Str, load_code, load_node +from aiida.orm.nodes.caching import NodeCaching from aiida.plugins import CalculationFactory, WorkflowFactory from aiida.workflows.arithmetic.add_multiply import add, add_multiply from tests.utils.memory import get_instances # pylint: disable=import-error @@ -207,14 +208,14 @@ def validate_cached(cached_calcs): print_report(calc.pk) valid = False - if '_aiida_cached_from' not in calc.base.extras or calc.base.caching.get_hash( + if NodeCaching.CACHED_FROM_KEY not in calc.base.extras or calc.base.caching.get_hash( ) != calc.base.extras.get('_aiida_hash'): print(f'Cached calculation<{calc.pk}> has invalid hash') print_report(calc.pk) valid = False if isinstance(calc, CalcJobNode): - original_calc = load_node(calc.base.extras.get('_aiida_cached_from')) + original_calc = load_node(calc.base.extras.get(NodeCaching.CACHED_FROM_KEY)) files_original = original_calc.base.repository.list_object_names() files_cached = calc.base.repository.list_object_names() diff --git a/aiida/orm/nodes/caching.py b/aiida/orm/nodes/caching.py index 9c3c4db07a..1fe342e5af 100644 --- a/aiida/orm/nodes/caching.py +++ b/aiida/orm/nodes/caching.py @@ -18,6 +18,7 @@ class NodeCaching: # The keys in the extras that are used to store the hash of the node and whether it should be used in caching. _HASH_EXTRA_KEY: str = '_aiida_hash' _VALID_CACHE_KEY: str = '_aiida_valid_cache' + CACHED_FROM_KEY: str = '_aiida_cached_from' def __init__(self, node: 'Node') -> None: """Initialize the caching interface.""" @@ -82,7 +83,7 @@ def get_cache_source(self) -> str | None: :return: source node UUID or None """ - return self._node.base.extras.get('_aiida_cached_from', None) + return self._node.base.extras.get(self.CACHED_FROM_KEY, None) @property def is_created_from_cache(self) -> bool: diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index a5d7400607..06a2680d9d 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -539,7 +539,7 @@ def _store_from_cache(self, cache_node: 'Node') -> None: self._store(clean=False) self._add_outputs_from_cache(cache_node) - self.base.extras.set('_aiida_cached_from', cache_node.uuid) + self.base.extras.set(self.base.caching.CACHED_FROM_KEY, cache_node.uuid) def _add_outputs_from_cache(self, cache_node: 'Node') -> None: """Replicate the output links and nodes from the cached node onto this node.""" diff --git a/tests/engine/test_process.py b/tests/engine/test_process.py index 95d40827a7..2604748646 100644 --- a/tests/engine/test_process.py +++ b/tests/engine/test_process.py @@ -20,6 +20,7 @@ from aiida.engine import ExitCode, ExitCodesNamespace, Process, run, run_get_node, run_get_pk from aiida.engine.processes.ports import PortNamespace from aiida.manage.caching import enable_caching +from aiida.orm.nodes.caching import NodeCaching from aiida.plugins import CalculationFactory from tests.utils import processes as test_processes @@ -210,13 +211,13 @@ def test_exit_codes_invalidate_cache(self): _, node1 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False)) _, node2 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False)) assert node1.base.extras.get('_aiida_hash') == node2.base.extras.get('_aiida_hash') - assert '_aiida_cached_from' in node2.base.extras + assert NodeCaching.CACHED_FROM_KEY in node2.base.extras with enable_caching(): _, node3 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True)) _, node4 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True)) assert node3.base.extras.get('_aiida_hash') == node4.base.extras.get('_aiida_hash') - assert '_aiida_cached_from' not in node4.base.extras + assert NodeCaching.CACHED_FROM_KEY not in node4.base.extras def test_valid_cache_hook(self): """ @@ -228,13 +229,13 @@ def test_valid_cache_hook(self): _, node1 = run_get_node(test_processes.IsValidCacheHook) _, node2 = run_get_node(test_processes.IsValidCacheHook) assert node1.base.extras.get('_aiida_hash') == node2.base.extras.get('_aiida_hash') - assert '_aiida_cached_from' in node2.base.extras + assert NodeCaching.CACHED_FROM_KEY in node2.base.extras with enable_caching(): _, node3 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True)) _, node4 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True)) assert node3.base.extras.get('_aiida_hash') == node4.base.extras.get('_aiida_hash') - assert '_aiida_cached_from' not in node4.base.extras + assert NodeCaching.CACHED_FROM_KEY not in node4.base.extras def test_process_type_with_entry_point(self): """For a process with a registered entry point, the process_type will be its formatted entry point string."""