Skip to content

Commit

Permalink
ORM: Add NodeCaching.CACHED_FROM_KEY for _aiida_cached_from constant
Browse files Browse the repository at this point in the history
The `_aiida_cached_from` key is used to store the UUID, of the node from
which a node was cached, into the extras. It appeared in a few places as
a string literal. It is now added as the `CACHED_FROM_KEY` class
variable of `NodeCaching`.
  • Loading branch information
sphuber committed Oct 21, 2023
1 parent b0546e8 commit 35fc3ae
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
5 changes: 3 additions & 2 deletions .github/system_tests/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from aiida.engine.processes import CalcJob, Process
from aiida.manage.caching import enable_caching
from aiida.orm import CalcJobNode, Dict, Int, List, Str, load_code, load_node
from aiida.orm.nodes.caching import NodeCaching
from aiida.plugins import CalculationFactory, WorkflowFactory
from aiida.workflows.arithmetic.add_multiply import add, add_multiply
from tests.utils.memory import get_instances # pylint: disable=import-error
Expand Down Expand Up @@ -207,14 +208,14 @@ def validate_cached(cached_calcs):
print_report(calc.pk)
valid = False

if '_aiida_cached_from' not in calc.base.extras or calc.base.caching.get_hash(
if NodeCaching.CACHED_FROM_KEY not in calc.base.extras or calc.base.caching.get_hash(
) != calc.base.extras.get('_aiida_hash'):
print(f'Cached calculation<{calc.pk}> has invalid hash')
print_report(calc.pk)
valid = False

if isinstance(calc, CalcJobNode):
original_calc = load_node(calc.base.extras.get('_aiida_cached_from'))
original_calc = load_node(calc.base.extras.get(NodeCaching.CACHED_FROM_KEY))
files_original = original_calc.base.repository.list_object_names()
files_cached = calc.base.repository.list_object_names()

Expand Down
3 changes: 2 additions & 1 deletion aiida/orm/nodes/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class NodeCaching:
# The keys in the extras that are used to store the hash of the node and whether it should be used in caching.
_HASH_EXTRA_KEY: str = '_aiida_hash'
_VALID_CACHE_KEY: str = '_aiida_valid_cache'
CACHED_FROM_KEY: str = '_aiida_cached_from'

def __init__(self, node: 'Node') -> None:
"""Initialize the caching interface."""
Expand Down Expand Up @@ -82,7 +83,7 @@ def get_cache_source(self) -> str | None:
:return: source node UUID or None
"""
return self._node.base.extras.get('_aiida_cached_from', None)
return self._node.base.extras.get(self.CACHED_FROM_KEY, None)

@property
def is_created_from_cache(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _store_from_cache(self, cache_node: 'Node') -> None:

self._store(clean=False)
self._add_outputs_from_cache(cache_node)
self.base.extras.set('_aiida_cached_from', cache_node.uuid)
self.base.extras.set(self.base.caching.CACHED_FROM_KEY, cache_node.uuid)

def _add_outputs_from_cache(self, cache_node: 'Node') -> None:
"""Replicate the output links and nodes from the cached node onto this node."""
Expand Down
9 changes: 5 additions & 4 deletions tests/engine/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aiida.engine import ExitCode, ExitCodesNamespace, Process, run, run_get_node, run_get_pk
from aiida.engine.processes.ports import PortNamespace
from aiida.manage.caching import enable_caching
from aiida.orm.nodes.caching import NodeCaching
from aiida.plugins import CalculationFactory
from tests.utils import processes as test_processes

Expand Down Expand Up @@ -210,13 +211,13 @@ def test_exit_codes_invalidate_cache(self):
_, node1 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False))
_, node2 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(False))
assert node1.base.extras.get('_aiida_hash') == node2.base.extras.get('_aiida_hash')
assert '_aiida_cached_from' in node2.base.extras
assert NodeCaching.CACHED_FROM_KEY in node2.base.extras

with enable_caching():
_, node3 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True))
_, node4 = run_get_node(test_processes.InvalidateCaching, return_exit_code=orm.Bool(True))
assert node3.base.extras.get('_aiida_hash') == node4.base.extras.get('_aiida_hash')
assert '_aiida_cached_from' not in node4.base.extras
assert NodeCaching.CACHED_FROM_KEY not in node4.base.extras

def test_valid_cache_hook(self):
"""
Expand All @@ -228,13 +229,13 @@ def test_valid_cache_hook(self):
_, node1 = run_get_node(test_processes.IsValidCacheHook)
_, node2 = run_get_node(test_processes.IsValidCacheHook)
assert node1.base.extras.get('_aiida_hash') == node2.base.extras.get('_aiida_hash')
assert '_aiida_cached_from' in node2.base.extras
assert NodeCaching.CACHED_FROM_KEY in node2.base.extras

with enable_caching():
_, node3 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True))
_, node4 = run_get_node(test_processes.IsValidCacheHook, not_valid_cache=orm.Bool(True))
assert node3.base.extras.get('_aiida_hash') == node4.base.extras.get('_aiida_hash')
assert '_aiida_cached_from' not in node4.base.extras
assert NodeCaching.CACHED_FROM_KEY not in node4.base.extras

def test_process_type_with_entry_point(self):
"""For a process with a registered entry point, the process_type will be its formatted entry point string."""
Expand Down

0 comments on commit 35fc3ae

Please sign in to comment.