Skip to content

Commit

Permalink
store type in descriptor serialization data (#3325)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Feb 23, 2024
1 parent 543276a commit 649fdca
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 56 deletions.
21 changes: 21 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,24 @@ def get_hash(obj) -> str:
object to hash
"""
return sha1(json.dumps(obj).encode("utf-8")).hexdigest()


def j_get_type(data: dict, class_name: str = "object") -> str:
"""Get the type from the data.
Parameters
----------
data : dict
the data
class_name : str, optional
the name of the class for error message, by default "object"
Returns
-------
str
the type
"""
try:
return data["type"]
except KeyError as e:
raise KeyError(f"the type of the {class_name} should be set by `type`") from e
9 changes: 3 additions & 6 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import numpy as np

from deepmd.dpmodel.descriptor import ( # noqa # TODO: should import all descriptors!
DescrptSeA,
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
Expand Down Expand Up @@ -135,16 +135,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
"descriptor_name": self.descriptor.__class__.__name__,
"fitting_name": self.fitting.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = getattr(
sys.modules[__name__], data["descriptor_name"]
).deserialize(data["descriptor"])
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import numpy as np

from .make_base_descriptor import (
Expand Down
66 changes: 61 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (
Plugin,
)


def make_base_descriptor(
Expand All @@ -33,6 +40,42 @@ def make_base_descriptor(
class BD(ABC):
"""Base descriptor provides the interfaces of descriptor."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a descriptor plugin.
Parameters
----------
key : str
the key of a descriptor
Returns
-------
Descriptor
the registered descriptor
Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return BD.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BD:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, descrpt_type: str) -> Type["BD"]:
if descrpt_type in BD.__plugins.plugins:
return BD.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -95,10 +138,23 @@ def serialize(self) -> dict:
"""Serialize the obj to dict."""
pass

@abstractclassmethod
def deserialize(cls):
"""Deserialize from a dict."""
pass
@classmethod
def deserialize(cls, data: dict) -> "BD":
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
BD
The deserialized descriptor
"""
if cls is BD:
return BD.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

setattr(BD, fwd_method_name, BD.fwd)
delattr(BD, "fwd")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)


@BaseDescriptor.register("se_e2_a")
class DescrptSeA(NativeOP, BaseDescriptor):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.
Expand Down Expand Up @@ -313,6 +314,8 @@ def call(
def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
"@class": "Descriptor",
"type": "se_e2_a",
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
Expand All @@ -339,6 +342,8 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "DescrptSeA":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data.pop("@class", None)
data.pop("type", None)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")
Expand Down
9 changes: 3 additions & 6 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from deepmd.dpmodel import (
FittingOutputDef,
)
from deepmd.pt.model.descriptor.se_a import ( # noqa # TODO: should import all descriptors!!!
DescrptSeA,
from deepmd.pt.model.descriptor.descriptor import (
Descriptor,
)
from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
Expand Down Expand Up @@ -98,16 +98,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
"descriptor_name": self.descriptor.__class__.__name__,
"fitting_name": self.fitting_net.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = getattr(
sys.modules[__name__], data["descriptor_name"]
).deserialize(data["descriptor"])
descriptor_obj = Descriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
Expand Down
41 changes: 33 additions & 8 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
Dict,
List,
Optional,
Type,
)

import torch

from deepmd.common import (
j_get_type,
)
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
Expand Down Expand Up @@ -92,16 +96,37 @@ def data_stat_key(self):

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
if descrpt_type in Descriptor.__plugins.plugins:
cls = Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, descrpt_type: str) -> Type["Descriptor"]:
if descrpt_type in Descriptor.__plugins.plugins:
return Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

@classmethod
def deserialize(cls, data: dict) -> "Descriptor":
"""Deserialize the model.
There is no suffix in a native DP model, but it is important
for the TF backend.
Parameters
----------
data : dict
The serialized data
Returns
-------
Descriptor
The deserialized descriptor
"""
if cls is Descriptor:
return Descriptor.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)


class DescriptorBlock(torch.nn.Module, ABC):
"""The building block of descriptor.
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ def set_stat_mean_and_stddev(
def serialize(self) -> dict:
obj = self.sea
return {
"@class": "Descriptor",
"type": "se_e2_a",
"rcut": obj.rcut,
"rcut_smth": obj.rcut_smth,
"sel": obj.sel,
Expand All @@ -219,6 +221,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "DescrptSeA":
data = data.copy()
data.pop("@class", None)
data.pop("type", None)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")
Expand Down
17 changes: 9 additions & 8 deletions deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

import numpy as np

from deepmd.common import (
j_get_type,
)
from deepmd.tf.env import (
GLOBAL_TF_FLOAT_PRECISION,
tf,
Expand Down Expand Up @@ -67,19 +70,15 @@ class SomeDescript(Descriptor):
return Descriptor.__plugins.register(key)

@classmethod
def get_class_by_input(cls, input: dict):
try:
descrpt_type = input["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
def get_class_by_type(cls, descrpt_type: str):
if descrpt_type in Descriptor.__plugins.plugins:
return Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
cls = cls.get_class_by_input(kwargs)
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@abstractmethod
Expand Down Expand Up @@ -507,7 +506,7 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
The local data refer to the current class
"""
# call subprocess
cls = cls.get_class_by_input(local_jdata)
cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__))
return cls.update_sel(global_jdata, local_jdata)

@classmethod
Expand All @@ -530,7 +529,9 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
The deserialized descriptor
"""
if cls is Descriptor:
return Descriptor.get_class_by_input(data).deserialize(data, suffix=suffix)
return Descriptor.get_class_by_type(
j_get_type(data, cls.__name__)
).deserialize(data, suffix=suffix)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

def serialize(self, suffix: str = "") -> dict:
Expand Down
4 changes: 4 additions & 0 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,8 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeA:
raise NotImplementedError("Not implemented in class %s" % cls.__name__)
data = data.copy()
data.pop("@class", None)
data.pop("type", None)
embedding_net_variables = cls.deserialize_network(
data.pop("embeddings"), suffix=suffix
)
Expand Down Expand Up @@ -1418,6 +1420,8 @@ def serialize(self, suffix: str = "") -> dict:
# but instead a part of the input data. Maybe the interface should be refactored...

return {
"@class": "Descriptor",
"type": "se_e2_a",
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
Expand Down
Loading

0 comments on commit 649fdca

Please sign in to comment.