# -*- coding: utf-8 -*-
"""Data plugin that allows to easily wrap objects that are JSON-able."""
import importlib
import json
import typing

from aiida.orm import Data

try:
    from typing import Protocol
except ImportError:  # Python <3.8 doesn't have `Protocol` in the stdlib
    from typing_extensions import Protocol  # type: ignore[misc]


class JsonSerializableProtocol(Protocol):
    """Structural type describing objects that expose an ``as_dict`` method."""

    def as_dict(self) -> typing.MutableMapping[typing.Any, typing.Any]:
        """Return a dictionary representation of the object."""
        ...


class JsonableData(Data):
    """Data plugin that allows to easily wrap objects that are JSON-able.

    Any class that implements the ``as_dict`` method, returning a dictionary that is a JSON serializable representation
    of the object, can be wrapped and stored by this data plugin.

    As an example, take the ``Molecule`` class of the ``pymatgen`` library, which respects the spec described above. To
    store an instance as a ``JsonableData`` simply pass an instance as an argument to the constructor as follows::

        from pymatgen.core import Molecule
        molecule = Molecule(['H'], [[0, 0, 0]])
        node = JsonableData(molecule)
        node.store()

    Since ``Molecule.as_dict`` returns a dictionary that is JSON-serializable, the data plugin will call it and store
    the dictionary as the attributes of the ``JsonableData`` node in the database.

    .. note:: A JSON-serializable dictionary means a dictionary that when passed to ``json.dumps`` does not raise an
        exception but produces a valid JSON string representation of the dictionary.

    If the wrapped class implements a class-method ``from_dict``, the wrapped instance can easily be recovered from a
    previously stored node that was optionally loaded from the database. The ``from_dict`` method should simply accept
    a single argument which is the dictionary that is returned by the ``as_dict`` method. If these criteria are
    satisfied, an instance wrapped and stored in a ``JsonableData`` node can be recovered through the ``obj`` property::

        loaded = load_node(node.pk)
        molecule = loaded.obj

    Of course, this requires that the class of the originally wrapped instance can be imported in the current
    environment, or an ``ImportError`` will be raised.
    """

    def __init__(self, obj: JsonSerializableProtocol, *args, **kwargs):
        """Construct the node for the to be wrapped object.

        :param obj: the object to wrap. It needs a callable ``as_dict`` attribute whose return value can be serialized
            by ``json.dumps``.
        :param args: positional arguments forwarded to the parent constructor.
        :param kwargs: keyword arguments forwarded to the parent constructor.
        :raises TypeError: if ``obj`` is ``None``, if it has no callable ``as_dict`` attribute, or if the dictionary
            it returns cannot be serialized to JSON.
        """
        if obj is None:
            raise TypeError('the `obj` argument cannot be `None`.')

        if not callable(getattr(obj, 'as_dict', None)):
            raise TypeError('the `obj` argument does not have the required `as_dict` method.')

        super().__init__(*args, **kwargs)

        self._obj = obj

        # Work on a shallow copy: adding the bookkeeping keys below must not modify the dictionary handed out by
        # ``as_dict``, which may well be a reference to state that is owned by the wrapped object itself.
        dictionary = dict(obj.as_dict())
        dictionary.setdefault('@class', obj.__class__.__name__)
        dictionary.setdefault('@module', obj.__class__.__module__)

        # Even though the dictionary returned by ``as_dict`` should be JSON-serializable and therefore this should be
        # sufficient to be able to generate a JSON representation and thus store it in the database, there is a
        # difference in the JSON serializers used by Python's ``json`` module and those of the PostgreSQL database that
        # is used for the database backend. Python's ``json`` module automatically serializes the ``inf`` and ``nan``
        # float constants to the Javascript equivalent strings, however, PostgreSQL does not. If we were to pass the
        # dictionary from ``as_dict`` straight to the attributes and it were to contain any of these floats, the storing
        # of the node would fail, even though technically it is JSON-serializable using the default Python module. To
        # work around this asymmetry, we perform a serialization round-trip with the ``JsonEncoder`` and ``JsonDecoder``
        # where in the deserialization, the encoded float constants are not deserialized, but instead the string
        # placeholders are kept. This now ensures that the full dictionary will be serializable by PostgreSQL.
        try:
            serialized = json.loads(json.dumps(dictionary), parse_constant=lambda x: x)
        except TypeError as exc:
            raise TypeError(f'the object `{obj}` is not JSON-serializable and therefore cannot be stored.') from exc

        self.set_attribute_many(serialized)

    @classmethod
    def _deserialize_float_constants(cls, data: typing.Any) -> typing.Any:
        """Deserialize the contents of a dictionary ``data`` deserializing infinity and NaN string constants.

        The ``data`` dictionary is recursively checked for the ``Infinity``, ``-Infinity`` and ``NaN`` strings, which
        are the Javascript string equivalents to the Python ``float('inf')``, ``-float('inf')`` and ``float('nan')``
        float constants. If one of the strings is encountered, the Python float constant is returned and otherwise the
        original value is returned.

        :param data: the value to convert; dictionaries and lists are traversed recursively.
        :return: a new structure in which the three placeholder strings are replaced by the matching float constants.
        """
        if isinstance(data, dict):
            return {k: cls._deserialize_float_constants(v) for k, v in data.items()}
        if isinstance(data, list):
            return [cls._deserialize_float_constants(v) for v in data]
        if data == 'Infinity':
            return float('inf')
        if data == '-Infinity':
            return -float('inf')
        if data == 'NaN':
            return float('nan')
        return data

    def _get_object(self) -> JsonSerializableProtocol:
        """Return the cached wrapped object.

        .. note:: If the object is not yet present in memory, for example if the node was loaded from the database,
            the object will first be reconstructed from the state stored in the node attributes.

        :return: the wrapped object.
        :raises ImportError: if the module recorded under ``@module`` cannot be imported, or if it does not contain
            the class recorded under ``@class``.
        """
        try:
            return self._obj
        except AttributeError:
            # Nothing cached yet. The reconstruction is done outside of this handler so that an exception raised by
            # ``from_dict`` is not reported as having occurred while handling this ``AttributeError``.
            pass

        # Shallow copy, so that popping the bookkeeping keys cannot alter the mapping returned by ``self.attributes``.
        attributes = dict(self.attributes)
        class_name = attributes.pop('@class')
        module_name = attributes.pop('@module')

        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:
            raise ImportError(f'the objects module `{module_name}` can not be imported.') from exc

        try:
            cls = getattr(module, class_name)
        except AttributeError as exc:
            raise ImportError(
                f'the objects module `{module_name}` does not contain the class `{class_name}`.'
            ) from exc

        deserialized = self._deserialize_float_constants(attributes)
        self._obj = cls.from_dict(deserialized)

        return self._obj

    @property
    def obj(self) -> JsonSerializableProtocol:
        """Return the wrapped object.

        .. note:: This property caches the deserialized object, this means that when the node is loaded from the
            database, the object is deserialized only once and stored in memory as an attribute. Subsequent calls will
            simply return this cached object and not reload it from the database. This is fine, since nodes that are
            loaded from the database are by definition stored and therefore immutable, making it safe to assume that the
            object that is represented can not change. Note, however, that the caching also applies to unstored nodes.
            That means that manually changing the attributes of an unstored ``JsonableData`` can lead to inconsistencies
            with the object returned by this property.

        """
        return self._get_object()
# -*- coding: utf-8 -*-
"""Tests for the :class:`aiida.orm.nodes.data.jsonable.JsonableData` data type."""
import datetime
import math

from pymatgen.core import Molecule
import pytest

from aiida.orm import load_node
from aiida.orm.nodes.data.jsonable import JsonableData


class JsonableClass:
    """Dummy class that implements the required interface."""

    def __init__(self, data):
        """Construct a new object."""
        self._data = data

    @property
    def data(self):
        """Return the data of this instance."""
        return self._data

    def as_dict(self):
        """Represent the object as a JSON-serializable dictionary."""
        return {
            'data': self._data,
        }

    @classmethod
    def from_dict(cls, dictionary):
        """Reconstruct an instance from a serialized version."""
        return cls(dictionary['data'])


def test_construct():
    """Test the ``JsonableData`` constructor."""
    data = {'a': 1}
    obj = JsonableClass(data)
    node = JsonableData(obj)

    assert isinstance(node, JsonableData)
    assert not node.is_stored


def test_constructor_object_none():
    """Test the ``JsonableData`` constructor raises if object is ``None``."""
    with pytest.raises(TypeError, match=r'the `obj` argument cannot be `None`.'):
        JsonableData(None)


def test_invalid_class_no_as_dict():
    """Test the ``JsonableData`` constructor raises if object does not implement ``as_dict``."""

    class InvalidClass:
        pass

    with pytest.raises(TypeError, match=r'the `obj` argument does not have the required `as_dict` method.'):
        JsonableData(InvalidClass())


def test_invalid_class_not_serializable():
    """Test the ``JsonableData`` constructor raises if the dictionary of the object is not JSON-serializable."""
    obj = JsonableClass({'datetime': datetime.datetime.now()})

    with pytest.raises(TypeError, match=r'the object `.*` is not JSON-serializable and therefore cannot be stored.'):
        JsonableData(obj)


@pytest.mark.usefixtures('clear_database_before_test')
def test_store():
    """Test storing a ``JsonableData`` instance."""
    data = {'a': 1}
    obj = JsonableClass(data)
    node = JsonableData(obj)
    assert not node.is_stored

    node.store()
    assert node.is_stored


@pytest.mark.usefixtures('clear_database_before_test')
def test_load():
    """Test loading a ``JsonableData`` instance."""
    data = {'a': 1}
    obj = JsonableClass(data)
    node = JsonableData(obj)
    node.store()

    loaded = load_node(node.pk)
    # The type check has to target the freshly loaded node, not the instance that was constructed above.
    assert isinstance(loaded, JsonableData)
    assert loaded == node


@pytest.mark.usefixtures('clear_database_before_test')
def test_obj():
    """Test the ``JsonableData.obj`` property."""
    data = [1, float('inf'), float('-inf'), float('nan')]
    obj = JsonableClass(data)
    node = JsonableData(obj)
    node.store()

    assert isinstance(node.obj, JsonableClass)
    assert node.obj.data == data

    loaded = load_node(node.pk)
    # Here the object reconstructed from the stored attributes is the one under test.
    assert isinstance(loaded.obj, JsonableClass)

    # ``zip`` stops at the shortest input, so guard against a truncated list passing the loop below unnoticed.
    assert len(loaded.obj.data) == len(data)

    for left, right in zip(loaded.obj.data, data):

        # Need this explicit case to compare NaN because of the peculiarity in Python where ``float(nan) != float(nan)``
        if isinstance(left, float) and math.isnan(left):
            assert math.isnan(right)
            continue

        assert left == right


@pytest.mark.usefixtures('clear_database_before_test')
def test_unimportable_module():
    """Test the ``JsonableData.obj`` property if the associated module cannot be loaded."""
    obj = Molecule(['H'], [[0, 0, 0]])
    node = JsonableData(obj)

    # Artificially change the ``@module`` in the attributes so it becomes unloadable
    node.set_attribute('@module', 'not.existing')
    node.store()

    loaded = load_node(node.pk)

    with pytest.raises(ImportError, match='the objects module `not.existing` can not be imported.'):
        _ = loaded.obj


@pytest.mark.usefixtures('clear_database_before_test')
def test_unimportable_class():
    """Test the ``JsonableData.obj`` property if the associated class cannot be loaded."""
    obj = Molecule(['H'], [[0, 0, 0]])
    node = JsonableData(obj)

    # Artificially change the ``@class`` in the attributes so it becomes unloadable
    node.set_attribute('@class', 'NonExistingClass')
    node.store()

    loaded = load_node(node.pk)

    with pytest.raises(ImportError, match=r'the objects module `.*` does not contain the class `NonExistingClass`.'):
        _ = loaded.obj


@pytest.mark.usefixtures('clear_database_before_test')
def test_msonable():
    """Test that an ``MSONAble`` object can be wrapped, stored and loaded again."""
    obj = Molecule(['H'], [[0, 0, 0]])
    node = JsonableData(obj)
    node.store()
    assert node.is_stored

    loaded = load_node(node.pk)
    assert loaded is not node
    assert loaded.obj == obj