Skip to content

Commit

Permalink
Add the MsonableData data plugin
Browse files Browse the repository at this point in the history
The `MSONable` class, provided by the `monty` library, is designed to
make arbitrary classes easily JSON-serializable such that they can be
serialized and, for example, stored in a database. The class is used
extensively in `pymatgen` and many of their classes are MSONable. Since
the `pymatgen` classes are used often in the current AiiDA community, it
would be nice if these objects can be easily stored in the provenance
graph.

The new `MsonableData` data plugin, wraps any instance of a `MSONable`
class. When constructed, it calls `as_dict` which by definition returns
a JSON-serialized version of the object, which we can therefore store in
the nodes attributes and allow it to be stored. The `obj` property will
return the original MSONable instance as it deserializes it from the
serialized version stored in the attributes.
  • Loading branch information
sphuber committed Jul 8, 2021
1 parent 58df44c commit 2baa9db
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 0 deletions.
81 changes: 81 additions & 0 deletions aiida/orm/nodes/data/msonable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
"""Data plugin for classes that implement the ``MSONable`` class of the ``monty`` library."""
import importlib

from monty.json import MSONable

from aiida.orm import Data


class MsonableData(Data):
"""Data plugin that allows to easily wrap objects that are MSONable.
To use this class, simply construct it passing an isntance of any ``MSONable`` class and store it, for example:
from pymatgen.core import Molecule
molecule = Molecule(['H']. [0, 0, 0])
node = MsonableData(molecule)
node.store()
After storing, the node can be loaded like any other node and the original MSONable instance can be retrieved:
loaded = load_node(node.pk)
molecule = loaded.obj
.. note:: As the ``MSONable`` mixin class requires, the wrapped object needs to implement the methods ``as_dict``
and ``from_dict``. A default implementation should be present on the ``MSONable`` base class itself, but it
might need to be overridden in a specific implementation.
"""

def __init__(self, obj, *args, **kwargs):
"""Construct the node from the pymatgen object."""
if obj is None:
raise TypeError('the `obj` argument cannot be `None`.')

if not isinstance(obj, MSONable):
raise TypeError('the `obj` argument needs to implement the ``MSONable`` class.')

for method in ['as_dict', 'from_dict']:
if not hasattr(obj, method) or not callable(getattr(obj, method)):
raise TypeError(f'the `obj` argument does not have the required `{method}` method.')

super().__init__(*args, **kwargs)

self._obj = obj
self.set_attribute_many(obj.as_dict())

def _get_object(self):
"""Return the cached wrapped MSONable object.
.. note:: If the object is not yet present in memory, for example if the node was loaded from the database,
the object will first be reconstructed from the state stored in the node attributes.
"""
try:
return self._obj
except AttributeError:
attributes = self.attributes
class_name = attributes['@class']
module_name = attributes['@module']

try:
module = importlib.import_module(module_name)
except ImportError as exc:
raise ImportError(f'the objects module `{module_name}` can not be imported.') from exc

try:
cls = getattr(module, class_name)
except AttributeError as exc:
raise ImportError(
f'the objects module `{module_name}` does not contain the class `{class_name}`.'
) from exc

self._obj = cls.from_dict(attributes)
return self._obj

@property
def obj(self):
"""Return the wrapped MSONable object."""
return self._get_object()
1 change: 1 addition & 0 deletions setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
"folder = aiida.orm.nodes.data.folder:FolderData",
"int = aiida.orm.nodes.data.int:Int",
"list = aiida.orm.nodes.data.list:List",
"msonable = aiida.orm.nodes.data.msonable:MsonableData",
"numeric = aiida.orm.nodes.data.numeric:NumericType",
"orbital = aiida.orm.nodes.data.orbital:OrbitalData",
"remote = aiida.orm.nodes.data.remote.base:RemoteData",
Expand Down
149 changes: 149 additions & 0 deletions tests/orm/data/test_msonable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
"""Tests for the :class:`aiida.orm.nodes.data.msonable.MsonableData` data type."""
from monty.json import MSONable
import pymatgen
import pytest

from aiida.orm import load_node
from aiida.orm.nodes.data.msonable import MsonableData


class MsonableClass(MSONable):
"""Dummy class that implements the ``MSONable interface``."""

def __init__(self, data):
"""Construct a new object."""
self._data = data

@property
def data(self):
"""Return the data of this instance."""
return self._data

def as_dict(self):
"""Represent the object as a JSON-serializable dictionary."""
return {
'@module': self.__class__.__module__,
'@class': self.__class__.__name__,
'data': self._data,
}

@classmethod
def from_dict(cls, d):
"""Reconstruct an instance from a serialized version."""
return cls(d['data'])


def test_construct():
"""Test the ``MsonableData`` constructor."""
data = {'a': 1}
obj = MsonableClass(data)
node = MsonableData(obj)

assert isinstance(node, MsonableData)
assert not node.is_stored


def test_constructor_object_none():
"""Test the ``MsonableData`` constructor raises if object is ``None``."""
with pytest.raises(TypeError, match=r'the `obj` argument cannot be `None`.'):
MsonableData(None)


def test_invalid_class_not_msonable():
"""Test the ``MsonableData`` constructor raises if object does not sublass ``MSONable``."""

class InvalidClass:
pass

with pytest.raises(TypeError, match=r'the `obj` argument needs to implement the ``MSONable`` class.'):
MsonableData(InvalidClass())


def test_invalid_class_no_as_dict():
"""Test the ``MsonableData`` constructor raises if object does not sublass ``MSONable``."""

class InvalidClass(MSONable):

@classmethod
def from_dict(cls, d):
pass

# Remove the ``as_dict`` method from the ``MSONable`` base class because that is currently implemented by default.
del MSONable.as_dict

with pytest.raises(TypeError, match=r'the `obj` argument does not have the required `as_dict` method.'):
MsonableData(InvalidClass())


@pytest.mark.usefixtures('clear_database_before_test')
def test_store():
"""Test storing a ``MsonableData`` instance."""
data = {'a': 1}
obj = MsonableClass(data)
node = MsonableData(obj)
assert not node.is_stored

node.store()
assert node.is_stored


@pytest.mark.usefixtures('clear_database_before_test')
def test_load():
"""Test loading a ``MsonableData`` instance."""
data = {'a': 1}
obj = MsonableClass(data)
node = MsonableData(obj)
node.store()

loaded = load_node(node.pk)
assert isinstance(node, MsonableData)
assert loaded == node


@pytest.mark.usefixtures('clear_database_before_test')
def test_obj():
"""Test the ``MsonableData.obj`` property."""
data = {'a': 1}
obj = MsonableClass(data)
node = MsonableData(obj)
node.store()

assert isinstance(node.obj, MsonableClass)
assert node.obj.data == data

loaded = load_node(node.pk)
assert isinstance(node.obj, MsonableClass)
assert loaded.obj.data == data


@pytest.mark.usefixtures('clear_database_before_test')
def test_unimportable_module():
"""Test the ``MsonableData.obj`` property if the associated module cannot be loaded."""
obj = pymatgen.core.Molecule(['H'], [[0, 0, 0]])
node = MsonableData(obj)

# Artificially change the ``@module`` in the attributes so it becomes unloadable
node.set_attribute('@module', 'not.existing')
node.store()

loaded = load_node(node.pk)

with pytest.raises(ImportError, match='the objects module `not.existing` can not be imported.'):
_ = loaded.obj


@pytest.mark.usefixtures('clear_database_before_test')
def test_unimportable_class():
"""Test the ``MsonableData.obj`` property if the associated class cannot be loaded."""
obj = pymatgen.core.Molecule(['H'], [[0, 0, 0]])
node = MsonableData(obj)

# Artificially change the ``@class`` in the attributes so it becomes unloadable
node.set_attribute('@class', 'NonExistingClass')
node.store()

loaded = load_node(node.pk)

with pytest.raises(ImportError, match=r'the objects module `.*` does not contain the class `NonExistingClass`.'):
_ = loaded.obj

0 comments on commit 2baa9db

Please sign in to comment.