Skip to content

Commit

Permalink
Add support for float constants infinity and nan
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Jul 21, 2021
1 parent 2baa9db commit 385b4a4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 6 deletions.
54 changes: 51 additions & 3 deletions aiida/orm/nodes/data/msonable.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
"""Data plugin for classes that implement the ``MSONable`` class of the ``monty`` library."""
import importlib
import json

from monty.json import MSONable
from monty.json import MSONable, MontyDecoder, MontyEncoder

from aiida.orm import Data

Expand Down Expand Up @@ -44,7 +45,37 @@ def __init__(self, obj, *args, **kwargs):
super().__init__(*args, **kwargs)

self._obj = obj
self.set_attribute_many(obj.as_dict())

# Serialize the object by calling ``as_dict`` and performing a roundtrip through JSON encoding using the
# ``MontyEncode`` for the encoding to string part. This is necessary to recursively serialize objects that can
# not be written to JSON, such as ``datetime`` objects and ``numpy`` arrays.
serialized = json.loads(json.dumps(obj.as_dict(), cls=MontyEncoder), parse_constant=lambda x: x)

# Then we apply our own custom serializer that serializes the float constants infinity and nan to a string value
# which is necessary because the serializer of the ``json`` standard module deserializes to the Python values
# that can not be written to JSON.
self.set_attribute_many(serialized)

@classmethod
def _deserialize_float_constants(cls, data):
"""Deserialize the contents of a dictionary ``data`` deserializing infinity and NaN string constants.
The ``data`` dictionary is recursively checked for the ``Infinity``, ``-Infinity`` and ``NaN`` strings, which
are the Javascript string equivalents to the Python ``float('inf')``, ``-float('inf')`` and ``float('nan')``
float constants. If one of the strings is encountered, the Python float constant is returned and otherwise the
original value is returned.
"""
if isinstance(data, dict):
return {k: cls._deserialize_float_constants(v) for k, v in data.items()}
if isinstance(data, list):
return [cls._deserialize_float_constants(v) for v in data]
if data == 'Infinity':
return float('inf')
if data == '-Infinity':
return -float('inf')
if data == 'NaN':
return float('nan')
return data

def _get_object(self):
"""Return the cached wrapped MSONable object.
Expand Down Expand Up @@ -72,7 +103,24 @@ def _get_object(self):
f'the objects module `{module_name}` does not contain the class `{class_name}`.'
) from exc

self._obj = cls.from_dict(attributes)
# First we need to deserialize any infinity or nan float string markers that were serialized in the
# constructor of this node when it was created. There the decoding step in the JSON roundtrip defined a
# pass-through for the ``parse_constant`` argument, which means that the serialized versions of the float
# constants (i.e. the strings ``Infinity`` etc.) are not deserialized in the Python float constants. Here we
# need to first explicit deserialize them. One would think that we could simply let the ``json.loads`` in
# the following step take care of this, however, since the attributes would first be serialized by the
# ``json.dumps`` call, the string placeholders would be dumped again to an actual string, which would then
# no longer be recognized by ``json.loads`` as the Javascript notation of the float constants and so it will
# leave them as separate strings.
deserialized = self._deserialize_float_constants(attributes)

# As a second step, we perform a JSON-serialization roundtrip using the ``MontyDecoder`` for deserialization
# which is necessary to recursively deserialize objects such as ``datetime`` and ``numpy`` arrays.
deserialized = json.loads(json.dumps(deserialized), cls=MontyDecoder)

# Finally, reconstruct the original ``MSONable`` class from the fully deserialized data.
self._obj = cls.from_dict(deserialized)

return self._obj

@property
Expand Down
22 changes: 19 additions & 3 deletions tests/orm/data/test_msonable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# -*- coding: utf-8 -*-
"""Tests for the :class:`aiida.orm.nodes.data.msonable.MsonableData` data type."""
import datetime
import math

from monty.json import MSONable
import numpy
import pymatgen
import pytest

Expand Down Expand Up @@ -31,7 +35,7 @@ def as_dict(self):
@classmethod
def from_dict(cls, d):
"""Reconstruct an instance from a serialized version."""
return cls(d['data'])
return cls(**d)


def test_construct():
Expand Down Expand Up @@ -104,7 +108,7 @@ def test_load():
@pytest.mark.usefixtures('clear_database_before_test')
def test_obj():
"""Test the ``MsonableData.obj`` property."""
data = {'a': 1}
data = [1, float('inf'), float('-inf'), float('nan'), numpy.arange(10), datetime.datetime.now()]
obj = MsonableClass(data)
node = MsonableData(obj)
node.store()
Expand All @@ -114,7 +118,19 @@ def test_obj():

loaded = load_node(node.pk)
assert isinstance(node.obj, MsonableClass)
assert loaded.obj.data == data

for left, right in zip(loaded.obj.data, data):

# Need this explicit case to compare NaN because of the peculiarity in Python where ``float(nan) != float(nan)``
if isinstance(left, float) and math.isnan(left):
assert math.isnan(right)
continue

try:
# This is needed to match numpy arrays
assert (left == right).all()
except AttributeError:
assert left == right


@pytest.mark.usefixtures('clear_database_before_test')
Expand Down

0 comments on commit 385b4a4

Please sign in to comment.