Skip to content

Commit

Permalink
Chore: refactor dpmodel (#3663)
Browse files Browse the repository at this point in the history
To decouple Model from DPAtomicModel, allowing make_model on different
atomic model.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Apr 15, 2024
1 parent b3d6c70 commit 25435c0
Show file tree
Hide file tree
Showing 29 changed files with 339 additions and 215 deletions.
4 changes: 2 additions & 2 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
NativeOP,
)
from .model import (
DPModel,
DPModelCommon,
)
from .output_def import (
FittingOutputDef,
Expand All @@ -19,7 +19,7 @@
)

__all__ = [
"DPModel",
"DPModelCommon",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"NativeOP",
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
Expand All @@ -23,7 +23,7 @@
)

__all__ = [
"DPModel",
"DPModelCommon",
"SpinModel",
"make_model",
]
17 changes: 14 additions & 3 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ class BaseBaseModel(ABC, PluginVariant, make_plugin_registry("model")):

def __new__(cls, *args, **kwargs):
if inspect.isabstract(cls):
cls = cls.get_class_by_type(kwargs.get("type", "standard"))
# getting model type based on fitting type
model_type = kwargs.get("type", "standard")
if model_type == "standard":
model_type = kwargs.get("fitting", {}).get("type", "ener")
cls = cls.get_class_by_type(model_type)
return super().__new__(cls)

@abstractmethod
Expand Down Expand Up @@ -118,7 +122,10 @@ def deserialize(cls, data: dict) -> "BaseBaseModel":
The deserialized model
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
model_type = data.get("type", "standard")
if model_type == "standard":
model_type = data.get("fitting", {}).get("type", "ener")
return cls.get_class_by_type(model_type).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

model_def_script: str
Expand Down Expand Up @@ -151,7 +158,11 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
local_jdata : dict
The local data refer to the current class
"""
cls = cls.get_class_by_type(local_jdata.get("type", "standard"))
# getting model type based on fitting type
model_type = local_jdata.get("type", "standard")
if model_type == "standard":
model_type = local_jdata.get("fitting", {}).get("type", "ener")
cls = cls.get_class_by_type(model_type)
return cls.update_sel(global_jdata, local_jdata)

return BaseBaseModel
Expand Down
13 changes: 1 addition & 12 deletions deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel)):
class DPModelCommon:
@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Expand Down
27 changes: 27 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .dp_model import (
DPModelCommon,
)
from .make_model import (
make_model,
)

DPEnergyModel_ = make_model(DPAtomicModel)


@BaseModel.register("ener")
class EnergyModel(DPModelCommon, DPEnergyModel_):
def __init__(
self,
*args,
**kwargs,
):
DPModelCommon.__init__(self)
DPEnergyModel_.__init__(self, *args, **kwargs)
10 changes: 5 additions & 5 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
from deepmd.dpmodel.model.spin_model import (
SpinModel,
Expand All @@ -16,8 +16,8 @@
)


def get_standard_model(data: dict) -> DPModel:
"""Get a standard DPModel from a dictionary.
def get_standard_model(data: dict) -> EnergyModel:
"""Get a EnergyModel from a dictionary.
Parameters
----------
Expand All @@ -41,7 +41,7 @@ def get_standard_model(data: dict) -> DPModel:
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")
return DPModel(
return EnergyModel(
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
Expand Down
11 changes: 8 additions & 3 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.make_model import (
make_model,
)
from deepmd.utils.spin import (
Spin,
Expand Down Expand Up @@ -259,7 +262,9 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "SpinModel":
backbone_model_obj = DPModel.deserialize(data["backbone_model"])
backbone_model_obj = make_model(DPAtomicModel).deserialize(
data["backbone_model"]
)
spin = Spin.deserialize(data["spin"])
return cls(
backbone_model=backbone_model_obj,
Expand Down
16 changes: 16 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,37 @@
from .base_atomic_model import (
BaseAtomicModel,
)
from .dipole_atomic_model import (
DPDipoleAtomicModel,
)
from .dos_atomic_model import (
DPDOSAtomicModel,
)
from .dp_atomic_model import (
DPAtomicModel,
)
from .energy_atomic_model import (
DPEnergyAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .pairtab_atomic_model import (
PairTabAtomicModel,
)
from .polar_atomic_model import (
DPPolarAtomicModel,
)

__all__ = [
"BaseAtomicModel",
"DPAtomicModel",
"DPDOSAtomicModel",
"DPEnergyAtomicModel",
"PairTabAtomicModel",
"LinearEnergyAtomicModel",
"DPPolarAtomicModel",
"DPDipoleAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
28 changes: 28 additions & 0 deletions deepmd/pt/model/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
)

import torch

from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPDipoleAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DipoleFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# dipole not applying bias
return ret
14 changes: 14 additions & 0 deletions deepmd/pt/model/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.model.task.dos import (
DOSFittingNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPDOSAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DOSFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)
20 changes: 20 additions & 0 deletions deepmd/pt/model/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
InvarFitting,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPEnergyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert (
isinstance(fitting, EnergyFittingNet)
or isinstance(fitting, EnergyFittingNetDirect)
or isinstance(fitting, InvarFitting)
)
super().__init__(descriptor, fitting, type_map, **kwargs)
28 changes: 28 additions & 0 deletions deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
)

import torch

from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPPolarAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, PolarFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# TODO: migrate bias
return ret
26 changes: 23 additions & 3 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
Spin,
)

from .dipole_model import (
DipoleModel,
)
from .dos_model import (
DOSModel,
)
from .dp_model import (
DPModel,
DPModelCommon,
)
from .dp_zbl_model import (
DPZBLModel,
Expand All @@ -51,6 +57,9 @@
from .model import (
BaseModel,
)
from .polar_model import (
PolarModel,
)
from .spin_model import (
SpinEnergyModel,
SpinModel,
Expand Down Expand Up @@ -161,7 +170,18 @@ def get_standard_model(model_params):
atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])

model = DPModel(
if fitting_net["type"] == "dipole":
modelcls = DipoleModel
elif fitting_net["type"] == "polar":
modelcls = PolarModel
elif fitting_net["type"] == "dos":
modelcls = DOSModel
elif fitting_net["type"] in ["ener", "direct_force_ener"]:
modelcls = EnergyModel
else:
raise RuntimeError(f"Unknown fitting type: {fitting_net['type']}")

model = modelcls(
descriptor=descriptor,
fitting=fitting,
type_map=model_params["type_map"],
Expand All @@ -184,7 +204,7 @@ def get_model(model_params):
__all__ = [
"BaseModel",
"get_model",
"DPModel",
"DPModelCommon",
"EnergyModel",
"FrozenModel",
"SpinModel",
Expand Down
Loading

0 comments on commit 25435c0

Please sign in to comment.