diff --git a/aiida/orm/nodes/data/msonable.py b/aiida/orm/nodes/data/msonable.py index cb2853d7c0..359e7ea1a5 100644 --- a/aiida/orm/nodes/data/msonable.py +++ b/aiida/orm/nodes/data/msonable.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- """Data plugin for classes that implement the ``MSONable`` class of the ``monty`` library.""" import importlib +import json -from monty.json import MSONable +from monty.json import MSONable, MontyDecoder, MontyEncoder from aiida.orm import Data @@ -44,7 +45,37 @@ def __init__(self, obj, *args, **kwargs): super().__init__(*args, **kwargs) self._obj = obj - self.set_attribute_many(obj.as_dict()) + + # Serialize the object by calling ``as_dict`` and performing a roundtrip through JSON encoding using the + # ``MontyEncoder`` for the encoding to string part. This is necessary to recursively serialize objects that can + # not be written to JSON, such as ``datetime`` objects and ``numpy`` arrays. + serialized = json.loads(json.dumps(obj.as_dict(), cls=MontyEncoder), parse_constant=lambda x: x) + + # The ``parse_constant`` pass-through above keeps the float constants infinity and NaN as their string markers, + # which is necessary because the ``json`` standard module would otherwise deserialize them to Python values + # that can not be written to JSON. + self.set_attribute_many(serialized) + + @classmethod + def _deserialize_float_constants(cls, data): + """Deserialize the contents of a dictionary ``data`` deserializing infinity and NaN string constants. + + The ``data`` dictionary is recursively checked for the ``Infinity``, ``-Infinity`` and ``NaN`` strings, which + are the Javascript string equivalents to the Python ``float('inf')``, ``-float('inf')`` and ``float('nan')`` + float constants. If one of the strings is encountered, the Python float constant is returned and otherwise the + original value is returned. 
+ """ + if isinstance(data, dict): + return {k: cls._deserialize_float_constants(v) for k, v in data.items()} + if isinstance(data, list): + return [cls._deserialize_float_constants(v) for v in data] + if data == 'Infinity': + return float('inf') + if data == '-Infinity': + return -float('inf') + if data == 'NaN': + return float('nan') + return data def _get_object(self): """Return the cached wrapped MSONable object. @@ -72,7 +103,24 @@ def _get_object(self): f'the objects module `{module_name}` does not contain the class `{class_name}`.' ) from exc - self._obj = cls.from_dict(attributes) + # First we need to deserialize any infinity or nan float string markers that were serialized in the + # constructor of this node when it was created. There the decoding step in the JSON roundtrip defined a + # pass-through for the ``parse_constant`` argument, which means that the serialized versions of the float + # constants (i.e. the strings ``Infinity`` etc.) are not deserialized to the Python float constants. Here we + # first need to explicitly deserialize them. One would think that we could simply let the ``json.loads`` in + # the following step take care of this, however, since the attributes would first be serialized by the + # ``json.dumps`` call, the string placeholders would be dumped again to an actual string, which would then + # no longer be recognized by ``json.loads`` as the Javascript notation of the float constants and so it will + # leave them as plain strings. + deserialized = self._deserialize_float_constants(attributes) + + # As a second step, we perform a JSON-serialization roundtrip using the ``MontyDecoder`` for deserialization + # which is necessary to recursively deserialize objects such as ``datetime`` and ``numpy`` arrays. + deserialized = json.loads(json.dumps(deserialized), cls=MontyDecoder) + + # Finally, reconstruct the original ``MSONable`` class from the fully deserialized data. 
+ self._obj = cls.from_dict(deserialized) + return self._obj @property diff --git a/tests/orm/data/test_msonable.py b/tests/orm/data/test_msonable.py index 180de57ce3..8b3d508a13 100644 --- a/tests/orm/data/test_msonable.py +++ b/tests/orm/data/test_msonable.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- """Tests for the :class:`aiida.orm.nodes.data.msonable.MsonableData` data type.""" +import datetime +import math + from monty.json import MSONable +import numpy import pymatgen import pytest @@ -31,7 +35,7 @@ def as_dict(self): @classmethod def from_dict(cls, d): """Reconstruct an instance from a serialized version.""" - return cls(d['data']) + return cls(**d) def test_construct(): @@ -104,7 +108,7 @@ def test_load(): @pytest.mark.usefixtures('clear_database_before_test') def test_obj(): """Test the ``MsonableData.obj`` property.""" - data = {'a': 1} + data = [1, float('inf'), float('-inf'), float('nan'), numpy.arange(10), datetime.datetime.now()] obj = MsonableClass(data) node = MsonableData(obj) node.store() @@ -114,7 +118,19 @@ def test_obj(): loaded = load_node(node.pk) assert isinstance(node.obj, MsonableClass) - assert loaded.obj.data == data + + for left, right in zip(loaded.obj.data, data): + + # Need this explicit case for NaN because of the peculiarity in Python where ``float('nan') != float('nan')`` + if isinstance(left, float) and math.isnan(left): + assert math.isnan(right) + continue + + try: + # This is needed to match numpy arrays + assert (left == right).all() + except AttributeError: + assert left == right @pytest.mark.usefixtures('clear_database_before_test')