Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

store type in descriptor serialization data #3325

Merged
merged 6 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,24 @@
object to hash
"""
return sha1(json.dumps(obj).encode("utf-8")).hexdigest()


def j_get_type(data: dict, class_name: str = "object") -> str:
    """Look up the ``type`` entry of a (serialized or input) data dict.

    Parameters
    ----------
    data : dict
        the data that is expected to carry a ``type`` key
    class_name : str, optional
        the name of the class, only used to build the error message,
        by default "object"

    Returns
    -------
    str
        the value stored under ``data["type"]``

    Raises
    ------
    KeyError
        if ``data`` has no ``type`` key; the original lookup error is
        chained as the cause
    """
    try:
        type_name = data["type"]
    except KeyError as err:
        message = f"the type of the {class_name} should be set by `type`"
        raise KeyError(message) from err
    return type_name

Check warning on line 336 in deepmd/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/common.py#L335-L336

Added lines #L335 - L336 were not covered by tests
9 changes: 3 additions & 6 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import numpy as np

from deepmd.dpmodel.descriptor import ( # noqa # TODO: should import all descriptors!
DescrptSeA,
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
Expand Down Expand Up @@ -135,16 +135,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
"descriptor_name": self.descriptor.__class__.__name__,
"fitting_name": self.fitting.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = getattr(
sys.modules[__name__], data["descriptor_name"]
).deserialize(data["descriptor"])
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import numpy as np

from .make_base_descriptor import (
Expand Down
66 changes: 61 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (
Plugin,
)


def make_base_descriptor(
Expand All @@ -33,6 +40,42 @@
class BD(ABC):
"""Base descriptor provides the interfaces of descriptor."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
    """Build the decorator that registers a descriptor plugin.

    Parameters
    ----------
    key : str
        the key under which the descriptor is registered

    Returns
    -------
    Callable
        whatever the plugin registry returns for ``key``; it is meant
        to be applied as a class decorator (see the example)

    Examples
    --------
    >>> @Descriptor.register("some_descrpt")
    ... class SomeDescript(Descriptor):
    ...     pass
    """
    # The registry is shared by every descriptor deriving from BD.
    decorator = BD.__plugins.register(key)
    return decorator

def __new__(cls, *args, **kwargs):
if cls is BD:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))

Check warning on line 69 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L69

Added line #L69 was not covered by tests
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, descrpt_type: str) -> Type["BD"]:
    """Return the descriptor class registered under ``descrpt_type``.

    Raises
    ------
    RuntimeError
        if no descriptor has been registered under ``descrpt_type``
    """
    registry = BD.__plugins.plugins
    if descrpt_type not in registry:
        raise RuntimeError("Unknown descriptor type: " + descrpt_type)
    return registry[descrpt_type]

Check warning on line 77 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L77

Added line #L77 was not covered by tests

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -95,10 +138,23 @@
"""Serialize the obj to dict."""
pass

@abstractclassmethod
def deserialize(cls):
"""Deserialize from a dict."""
pass
@classmethod
def deserialize(cls, data: dict) -> "BD":
    """Deserialize the descriptor.

    When called on the base class itself, the concrete descriptor class
    is looked up from the ``type`` entry of ``data`` and the call is
    forwarded to that class. Concrete descriptors must override this
    method.

    Parameters
    ----------
    data : dict
        The serialized data

    Returns
    -------
    BD
        The deserialized descriptor

    Raises
    ------
    KeyError
        if ``data`` carries no ``type`` entry
    RuntimeError
        if the ``type`` is not a registered descriptor
    NotImplementedError
        if called on a subclass that does not override this method
    """
    if cls is BD:
        # Go through j_get_type so a missing ``type`` yields the same
        # descriptive KeyError as in ``__new__``.
        descrpt_type = j_get_type(data, cls.__name__)
        return BD.get_class_by_type(descrpt_type).deserialize(data)
    raise NotImplementedError(f"Not implemented in class {cls.__name__}")

Check warning on line 157 in deepmd/dpmodel/descriptor/make_base_descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/make_base_descriptor.py#L157

Added line #L157 was not covered by tests

setattr(BD, fwd_method_name, BD.fwd)
delattr(BD, "fwd")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)


@BaseDescriptor.register("se_e2_a")
class DescrptSeA(NativeOP, BaseDescriptor):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.
Expand Down Expand Up @@ -313,6 +314,8 @@ def call(
def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
"@class": "Descriptor",
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"type": "se_e2_a",
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
Expand All @@ -339,6 +342,8 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "DescrptSeA":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data.pop("@class", None)
data.pop("type", None)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")
Expand Down
9 changes: 3 additions & 6 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from deepmd.dpmodel import (
FittingOutputDef,
)
from deepmd.pt.model.descriptor.se_a import ( # noqa # TODO: should import all descriptors!!!
DescrptSeA,
from deepmd.pt.model.descriptor.descriptor import (
Descriptor,
)
from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
Expand Down Expand Up @@ -98,16 +98,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
"descriptor_name": self.descriptor.__class__.__name__,
"fitting_name": self.fitting_net.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = getattr(
sys.modules[__name__], data["descriptor_name"]
).deserialize(data["descriptor"])
descriptor_obj = Descriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
Expand Down
41 changes: 33 additions & 8 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
Dict,
List,
Optional,
Type,
)

import torch

from deepmd.common import (
j_get_type,
)
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
Expand Down Expand Up @@ -92,16 +96,37 @@

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
if descrpt_type in Descriptor.__plugins.plugins:
cls = Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, descrpt_type: str) -> Type["Descriptor"]:
    """Return the descriptor class registered under ``descrpt_type``.

    Raises
    ------
    RuntimeError
        if no descriptor has been registered under ``descrpt_type``
    """
    registry = Descriptor.__plugins.plugins
    if descrpt_type not in registry:
        raise RuntimeError("Unknown descriptor type: " + descrpt_type)
    return registry[descrpt_type]

Check warning on line 107 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L107

Added line #L107 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "Descriptor":
    """Deserialize the descriptor.

    When called on ``Descriptor`` itself, the concrete descriptor class
    is looked up from the ``type`` entry of ``data`` and the call is
    forwarded to that class. Concrete descriptors must override this
    method.

    Parameters
    ----------
    data : dict
        The serialized data

    Returns
    -------
    Descriptor
        The deserialized descriptor

    Raises
    ------
    KeyError
        if ``data`` carries no ``type`` entry
    RuntimeError
        if the ``type`` is not a registered descriptor
    NotImplementedError
        if called on a subclass that does not override this method
    """
    if cls is Descriptor:
        # Go through j_get_type so a missing ``type`` yields the same
        # descriptive KeyError as in ``__new__``.
        descrpt_type = j_get_type(data, cls.__name__)
        return Descriptor.get_class_by_type(descrpt_type).deserialize(data)
    raise NotImplementedError(f"Not implemented in class {cls.__name__}")

Check warning on line 128 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L128

Added line #L128 was not covered by tests


class DescriptorBlock(torch.nn.Module, ABC):
"""The building block of descriptor.
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ def set_stat_mean_and_stddev(
def serialize(self) -> dict:
obj = self.sea
return {
"@class": "Descriptor",
"type": "se_e2_a",
"rcut": obj.rcut,
"rcut_smth": obj.rcut_smth,
"sel": obj.sel,
Expand All @@ -219,6 +221,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeA":
data = data.copy()
data.pop("@class", None)
data.pop("type", None)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")
Expand Down
17 changes: 9 additions & 8 deletions deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

import numpy as np

from deepmd.common import (
j_get_type,
)
from deepmd.tf.env import (
GLOBAL_TF_FLOAT_PRECISION,
tf,
Expand Down Expand Up @@ -67,19 +70,15 @@ class SomeDescript(Descriptor):
return Descriptor.__plugins.register(key)

@classmethod
def get_class_by_input(cls, input: dict):
try:
descrpt_type = input["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
def get_class_by_type(cls, descrpt_type: str):
if descrpt_type in Descriptor.__plugins.plugins:
return Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
cls = cls.get_class_by_input(kwargs)
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@abstractmethod
Expand Down Expand Up @@ -507,7 +506,7 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
The local data refer to the current class
"""
# call subprocess
cls = cls.get_class_by_input(local_jdata)
cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__))
return cls.update_sel(global_jdata, local_jdata)

@classmethod
Expand All @@ -530,7 +529,9 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
The deserialized descriptor
"""
if cls is Descriptor:
return Descriptor.get_class_by_input(data).deserialize(data, suffix=suffix)
return Descriptor.get_class_by_type(
j_get_type(data, cls.__name__)
).deserialize(data, suffix=suffix)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

def serialize(self, suffix: str = "") -> dict:
Expand Down
4 changes: 4 additions & 0 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,8 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeA:
raise NotImplementedError("Not implemented in class %s" % cls.__name__)
data = data.copy()
data.pop("@class", None)
data.pop("type", None)
embedding_net_variables = cls.deserialize_network(
data.pop("embeddings"), suffix=suffix
)
Expand Down Expand Up @@ -1418,6 +1420,8 @@ def serialize(self, suffix: str = "") -> dict:
# but instead a part of the input data. Maybe the interface should be refactored...

return {
"@class": "Descriptor",
"type": "se_e2_a",
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
Expand Down
Loading