Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

store type in fitting serialization data #3331

Merged
merged 3 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from typing import (
Dict,
List,
Expand All @@ -12,9 +11,8 @@
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
Expand Down Expand Up @@ -135,16 +133,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
"fitting_name": self.fitting.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
Expand All @@ -22,6 +25,7 @@
)


@BaseFitting.register("dipole")
@fitting_check_output
class DipoleFitting(GeneralFitting):
r"""Fitting rotationally equivariant diploe of the system.
Expand Down Expand Up @@ -142,6 +146,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "dipole"
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["r_differentiable"] = self.r_differentiable
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)


@InvarFitting.register("ener")
class EnergyFittingNet(InvarFitting):
def __init__(
self,
Expand Down Expand Up @@ -70,3 +71,10 @@ def deserialize(cls, data: dict) -> "GeneralFitting":
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
**super().serialize(),
"type": "ener",
}
3 changes: 3 additions & 0 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __getitem__(self, key):
def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
"@class": "Fitting",
"var_name": self.var_name,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
Expand Down Expand Up @@ -240,6 +241,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
nets = data.pop("nets")
obj = cls(**data)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)


@GeneralFitting.register("invar")
@fitting_check_output
class InvarFitting(GeneralFitting):
r"""Fitting the energy (or a rotationally invariant porperty of `dim_out`) of the system. The force and the virial can also be trained.
Expand Down Expand Up @@ -162,6 +163,7 @@ def __init__(

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "invar"
data["dim_out"] = self.dim_out
data["atom_ener"] = self.atom_ener
return data
Expand Down
66 changes: 61 additions & 5 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
Callable,
Dict,
Optional,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)
from deepmd.utils.plugin import (
Plugin,
)


def make_base_fitting(
Expand All @@ -33,6 +40,42 @@
class BF(ABC):
"""Base fitting provides the interfaces of fitting net."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
callable[[object], object]
the registered descriptor

Examples
--------
>>> @Fitting.register("some_fitting")
class SomeFitting(Fitting):
pass
"""
return BF.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BF:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, fitting_type: str) -> Type["BF"]:
if fitting_type in BF.__plugins.plugins:
return BF.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown fitting type: " + fitting_type)

Check warning on line 77 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L77

Added line #L77 was not covered by tests

@abstractmethod
def output_def(self) -> FittingOutputDef:
"""Returns the output def of the fitting net."""
Expand Down Expand Up @@ -65,10 +108,23 @@
"""Serialize the obj to dict."""
pass

@abstractclassmethod
def deserialize(cls):
"""Deserialize from a dict."""
pass
@classmethod
def deserialize(cls, data: dict) -> "BF":
"""Deserialize the fitting.

Parameters
----------
data : dict
The serialized data

Returns
-------
BF
The deserialized fitting
"""
if cls is BF:
return BF.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Check warning on line 127 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L127

Added line #L127 was not covered by tests

setattr(BF, fwd_method_name, BF.fwd)
delattr(BF, "fwd")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
Expand All @@ -25,6 +28,7 @@
)


@BaseFitting.register("polar")
@fitting_check_output
class PolarFitting(GeneralFitting):
r"""Fitting rotationally equivariant polarizability of the system.
Expand Down Expand Up @@ -166,6 +170,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["fit_diag"] = self.fit_diag
Expand Down
11 changes: 3 additions & 8 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
import sys
from typing import (
Dict,
List,
Expand All @@ -16,9 +15,8 @@
from deepmd.pt.model.descriptor.descriptor import (
Descriptor,
)
from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
)
from deepmd.pt.utils.utils import (
dict_to_device,
Expand Down Expand Up @@ -98,16 +96,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
"fitting_name": self.fitting_net.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = Descriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from .fitting import (
Fitting,
)
from .polarizability import (
PolarFittingNet,
)
from .type_predict import (
TypePredictNet,
)
Expand All @@ -31,4 +34,5 @@
"Fitting",
"BaseFitting",
"TypePredictNet",
"PolarFittingNet",
]
2 changes: 2 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
log = logging.getLogger(__name__)


@GeneralFitting.register("dipole")
class DipoleFittingNet(GeneralFitting):
"""Construct a dipole fitting net.

Expand Down Expand Up @@ -111,6 +112,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "dipole"
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["r_differentiable"] = self.r_differentiable
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
log = logging.getLogger(__name__)


@GeneralFitting.register("invar")
@fitting_check_output
class InvarFitting(GeneralFitting):
"""Construct a fitting net for energy.
Expand Down Expand Up @@ -129,6 +130,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "invar"
data["dim_out"] = self.dim_out
data["atom_ener"] = self.atom_ener
return data
Expand Down Expand Up @@ -238,6 +240,13 @@ def deserialize(cls, data: dict) -> "GeneralFitting":
data.pop("dim_out")
return super().deserialize(data)

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
**super().serialize(),
"type": "ener",
}


@Fitting.register("direct_force")
@Fitting.register("direct_force_ener")
Expand Down
Loading
Loading