Skip to content

Commit

Permalink
👌 IMPROVE: Add dataclass serialisation to context (#5833)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell authored Dec 14, 2022
1 parent 476d4b8 commit 40857b6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
25 changes: 25 additions & 0 deletions aiida/orm/utils/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
"""
from __future__ import annotations

from dataclasses import asdict, is_dataclass
from enum import Enum
from functools import partial
import inspect
from typing import Any, Protocol, Type, overload

from plumpy import Bundle, get_object_loader # type: ignore[attr-defined]
Expand All @@ -28,6 +30,7 @@
from aiida.common import AttributeDict

_ENUM_TAG = '!enum'
_DATACLASS_TAG = '!dataclass'
_NODE_TAG = '!aiida_node'
_GROUP_TAG = '!aiida_group'
_COMPUTER_TAG = '!aiida_computer'
Expand All @@ -51,6 +54,25 @@ def enum_constructor(loader: yaml.Loader, serialized: yaml.Node) -> Enum:
return enum


def represent_dataclass(dumper: yaml.Dumper, obj: Any) -> yaml.MappingNode:
"""Represent an arbitrary dataclass in yaml."""
loader = get_object_loader()
data = {
'__type__': loader.identify_object(obj.__class__),
'__fields__': asdict(obj),
}
return dumper.represent_mapping(_DATACLASS_TAG, data)


def dataclass_constructor(loader: yaml.Loader, serialized: yaml.Node) -> Any:
"""Construct a dataclass from the serialized representation."""
deserialized = loader.construct_mapping(serialized, deep=True) # type: ignore[arg-type]
identifier = deserialized['__type__']
cls = get_object_loader().load_object(identifier)
data = deserialized['__fields__']
return cls(**data)


def represent_node(dumper: yaml.Dumper, node: orm.Node) -> yaml.ScalarNode:
"""Represent a node in yaml."""
if not node.is_stored:
Expand Down Expand Up @@ -136,6 +158,8 @@ def represent_data(self, data):
return represent_computer(self, data)
if isinstance(data, orm.Group):
return represent_group(self, data)
if is_dataclass(data) and not inspect.isclass(data):
return represent_dataclass(self, data)

return super().represent_data(data)

Expand Down Expand Up @@ -163,6 +187,7 @@ class AiiDALoader(yaml.Loader):
yaml.add_constructor(_GROUP_TAG, group_constructor, Loader=AiiDALoader)
yaml.add_constructor(_COMPUTER_TAG, computer_constructor, Loader=AiiDALoader)
yaml.add_constructor(_ENUM_TAG, enum_constructor, Loader=AiiDALoader)
yaml.add_constructor(_DATACLASS_TAG, dataclass_constructor, Loader=AiiDALoader)


@overload
Expand Down
17 changes: 17 additions & 0 deletions tests/orm/utils/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the :mod:`aiida.orm.utils.serialize` module."""
from dataclasses import dataclass
import types
import uuid

Expand Down Expand Up @@ -162,3 +163,19 @@ def test_enum():

deserialized = serialize.deserialize_unsafe(serialized)
assert deserialized == enum


@dataclass
class DataClass:
"""A dataclass for testing."""
my_value: int


def test_dataclass():
"""Test serialization and deserialization of a ``dataclass``."""
obj = DataClass(1)
serialized = serialize.serialize(obj)
assert isinstance(serialized, str)

deserialized = serialize.deserialize_unsafe(serialized)
assert deserialized == obj

0 comments on commit 40857b6

Please sign in to comment.