Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize the context of a WorkChain before persisting #1354

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion aiida/orm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aiida.common.pluginloader import BaseFactory
from aiida.common.utils import abstractclassmethod

__all__ = ['CalculationFactory', 'DataFactory', 'WorkflowFactory', 'load_node', 'load_workflow']
__all__ = ['CalculationFactory', 'DataFactory', 'WorkflowFactory', 'load_group', 'load_node', 'load_workflow']


def CalculationFactory(module, from_abstract=False):
Expand Down Expand Up @@ -138,6 +138,42 @@ def create_node_id_qb(node_id=None, pk=None, uuid=None,
return qb


def load_group(group_id=None, pk=None, uuid=None, query_with_dashes=True):
    """
    Load a group by its pk or uuid

    Exactly one of ``group_id``, ``pk`` or ``uuid`` has to be supplied; the
    argument validation itself is delegated to ``create_node_id_qb``.

    :param group_id: pk (integer) or uuid (string) of a group
    :param pk: pk of a group
    :param uuid: uuid of a group, or the beginning of the uuid
    :param bool query_with_dashes: allow to query for a uuid with dashes (default=True)
    :returns: the requested group if existing and unique
    :raise InputValidationError: if none or more than one of the arguments are supplied
    :raise TypeError: if the wrong types are provided
    :raise NotExistent: if no matching Group is found
    :raise MultipleObjectsError: if more than one Group was found
    """
    # Imported locally to avoid a circular import at module load time
    from aiida.orm import Group

    kwargs = {
        'node_id': group_id,
        'pk': pk,
        'uuid': uuid,
        'parent_class': Group,
        'query_with_dashes': query_with_dashes
    }

    qb = create_node_id_qb(**kwargs)
    qb.add_projection('node', '*')
    # Fetch at most two rows: enough for one() to detect non-uniqueness
    # without scanning the full result set
    qb.limit(2)

    try:
        return qb.one()[0]
    except MultipleObjectsError:
        # Re-raise with a message specific to groups / partial-uuid queries
        raise MultipleObjectsError('More than one group found. Provide longer starting pattern for uuid.')
    except NotExistent:
        raise NotExistent('No group was found')


def load_node(node_id=None, pk=None, uuid=None, parent_class=None, query_with_dashes=True):
"""
Returns an AiiDA node given its PK or UUID.
Expand Down
87 changes: 87 additions & 0 deletions aiida/utils/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
import collections
from ast import literal_eval
from aiida.common.extendeddicts import AttributeDict
from aiida.orm import Group, Node, load_group, load_node


# Sentinel prefixes used by the (de)serialization helpers below.
# _PREFIX_KEY_TUPLE marks a dictionary key that was originally a tuple and
# was stringified to stay JSON serializable; the value prefixes mark strings
# that encode the UUID of an AiiDA Node or Group, respectively.
_PREFIX_KEY_TUPLE = 'tuple():'
_PREFIX_VALUE_NODE = 'aiida_node:'
_PREFIX_VALUE_GROUP = 'aiida_group:'


def encode_key(key):
    """
    Helper for serialize_data: mappings may use tuples as keys (e.g. the
    pseudo potential input dictionaries), which are not JSON serializable.
    A tuple key is therefore flattened to a prefixed string; any other key
    is passed through untouched.

    :param key: the key to encode
    :return: the encoded key
    """
    if not isinstance(key, tuple):
        return key

    # str() of a tuple yields the same text as '{}'.format(tuple)
    return _PREFIX_KEY_TUPLE + str(key)


def decode_key(key):
    """
    Helper for deserialize_data: inverse of encode_key. A key carrying the
    tuple prefix is parsed back into the original tuple via literal_eval;
    any other key is returned unchanged.

    :param key: the key to decode
    :return: the decoded key
    """
    prefix = _PREFIX_KEY_TUPLE
    if not key.startswith(prefix):
        return key

    # literal_eval safely parses the repr of the original tuple
    return literal_eval(key[len(prefix):])


def serialize_data(data):
    """
    Recursively serialize a value or collection so that it becomes JSON
    serializable: AiiDA Node and Group instances are replaced by a prefixed
    UUID string, mapping keys are run through encode_key (tuples become
    strings) and sequences are rebuilt as plain lists. The structure of the
    input is otherwise preserved.

    :param data: a single value or collection
    :return: the serialized data with the same internal structure
    """
    # Node/Group must be tested before the generic container checks
    if isinstance(data, Node):
        return _PREFIX_VALUE_NODE + data.uuid

    if isinstance(data, Group):
        return _PREFIX_VALUE_GROUP + data.uuid

    # AttributeDict is itself a Mapping, so it has to be checked first to
    # preserve its type on the way out
    if isinstance(data, AttributeDict):
        return AttributeDict({encode_key(k): serialize_data(v) for k, v in data.iteritems()})

    if isinstance(data, collections.Mapping):
        return {encode_key(k): serialize_data(v) for k, v in data.iteritems()}

    # Strings are Sequences too, but must be returned verbatim
    if isinstance(data, collections.Sequence) and not isinstance(data, (str, unicode)):
        return [serialize_data(item) for item in data]

    return data


def deserialize_data(data):
    """
    Recursively deserialize a value or collection produced by serialize_data:
    strings carrying a node/group prefix are resolved back to instances via
    load_node/load_group, encoded tuple keys are decoded and containers are
    rebuilt with the same structure.

    :param data: serialized data
    :return: the deserialized data with keys decoded and node instances loaded from UUID's
    """
    # Handle strings first: they are the only scalars that may encode a
    # Node or Group UUID (plain strings fall through unchanged)
    if isinstance(data, (str, unicode)):
        if data.startswith(_PREFIX_VALUE_NODE):
            return load_node(uuid=data[len(_PREFIX_VALUE_NODE):])
        if data.startswith(_PREFIX_VALUE_GROUP):
            return load_group(uuid=data[len(_PREFIX_VALUE_GROUP):])
        return data

    # AttributeDict is a Mapping subclass; check it first to keep its type
    if isinstance(data, AttributeDict):
        return AttributeDict({decode_key(k): deserialize_data(v) for k, v in data.iteritems()})

    if isinstance(data, collections.Mapping):
        return {decode_key(k): deserialize_data(v) for k, v in data.iteritems()}

    # Strings were already handled above, so no string exclusion is needed
    if isinstance(data, collections.Sequence):
        return [deserialize_data(item) for item in data]

    return data
56 changes: 13 additions & 43 deletions aiida/work/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import plum.persistence.pickle_persistence
from plum.process import Process
from aiida.common.lang import override
from aiida.utils.serialize import serialize_data, deserialize_data
from aiida.work.defaults import class_loader

import glob
Expand Down Expand Up @@ -397,9 +398,13 @@ def _load_checkpoint(self, pid):
def load_checkpoint_from_file_object(self, file_object):
    """
    Load a process checkpoint bundle from an open (pickle) file object and
    deserialize its input entries back into live AiiDA entities.

    NOTE(review): the scraped diff interleaved the removed INPUTS lines with
    the added INPUTS_RAW/INPUTS_PARSED lines; this body is the reconstructed
    post-merge version of the method.

    :param file_object: an open file-like object containing a pickled bundle
    :return: the checkpoint bundle with inputs deserialized and the class
        loader configured
    """
    cp = pickle.load(file_object)

    # Both the raw and the parsed inputs were serialized (nodes -> UUID
    # strings) when the checkpoint was written; undo that here
    inputs = cp[Process.BundleKeys.INPUTS_RAW.value]
    if inputs:
        cp[Process.BundleKeys.INPUTS_RAW.value] = deserialize_data(inputs)

    inputs = cp[Process.BundleKeys.INPUTS_PARSED.value]
    if inputs:
        cp[Process.BundleKeys.INPUTS_PARSED.value] = deserialize_data(inputs)

    cp.set_class_loader(class_loader)
    return cp
Expand All @@ -412,50 +417,15 @@ def get_checkpoint_state(self, pid):
def create_bundle(self, process):
    """
    Create a persistable checkpoint bundle from a process instance.

    The raw and parsed inputs are run through serialize_data so that any
    AiiDA Node/Group instances are replaced by their UUIDs and the bundle
    becomes safely persistable.

    NOTE(review): the scraped diff interleaved this method's new body with
    the lines of the deleted helpers _convert_to_ids and _load_nodes_from;
    this body is the reconstructed post-merge version (the helpers are
    removed by the PR, their role now played by serialize_data).

    :param process: the process instance whose state should be bundled
    :return: the bundle with serialized inputs
    """
    bundle = Bundle()
    process.save_instance_state(bundle)

    inputs = bundle[Process.BundleKeys.INPUTS_RAW.value]
    if inputs:
        bundle[Process.BundleKeys.INPUTS_RAW.value] = serialize_data(inputs)

    inputs = bundle[Process.BundleKeys.INPUTS_PARSED.value]
    if inputs:
        bundle[Process.BundleKeys.INPUTS_PARSED.value] = serialize_data(inputs)

    return bundle

def _clear(self, fileobj):
"""
Expand Down
9 changes: 6 additions & 3 deletions aiida/work/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from aiida.common.lang import override
from aiida.common.utils import get_class_string, get_object_string, \
get_object_from_string
from aiida.orm import load_node, load_workflow
from aiida.orm import load_node, load_workflow, Node
from aiida.utils.serialize import serialize_data, deserialize_data
from plum.wait_ons import Checkpoint, WaitOnAll, WaitOnProcess
from plum.wait import WaitOn
from plum.persistence.bundle import Bundle
Expand Down Expand Up @@ -124,7 +125,9 @@ def setdefault(self, key, default=None):

def save_instance_state(self, out_state):
    """
    Persist the context contents into the given output state mapping.

    Unstored nodes cannot be recovered from a UUID after a restart, so any
    unstored Node in the context is stored first; every value is then run
    through serialize_data (nodes/groups -> UUID strings, tuple keys
    encoded) so the state is safely persistable.

    NOTE(review): the scraped diff interleaved the removed
    ``out_state[k] = v`` line with the added serialization lines; this body
    is the reconstructed post-merge version.

    :param out_state: mapping into which the serialized context is written
    """
    for k, v in self._content.iteritems():
        if isinstance(v, Node) and not v.is_stored:
            v.store()
        out_state[k] = serialize_data(v)

def __init__(self):
super(WorkChain, self).__init__()
Expand Down Expand Up @@ -283,7 +286,7 @@ def on_create(self, pid, inputs, saved_state):
self._context = self.Context()
else:
# Recreate the context
self._context = self.Context(saved_state[self._CONTEXT])
self._context = self.Context(deserialize_data(saved_state[self._CONTEXT]))

# Recreate the stepper
if self._STEPPER_STATE in saved_state:
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements_for_rtd.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ passlib==1.7.1
pathlib2==2.3.0
pgtest==1.1.0
pip==9.0.1
plumpy==0.7.11
plumpy==0.7.12
portalocker==1.1.0
psutil==5.4.0
pycrypto==2.6.1
Expand Down
2 changes: 1 addition & 1 deletion setup_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'psutil==5.4.0',
'meld3==1.0.0',
'numpy==1.12.0',
'plumpy==0.7.11',
'plumpy==0.7.12',
'portalocker==1.1.0',
'SQLAlchemy==1.0.19', # upgrade to SQLalchemy 1.1.5 does break tests, see #465
'SQLAlchemy-Utils==0.33.0',
Expand Down