Skip to content

Commit

Permalink
apply PluginVariant and make_plugin_registry to classes (#3346)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Feb 28, 2024
1 parent b1de9e6 commit 004ebd6
Show file tree
Hide file tree
Showing 15 changed files with 59 additions and 255 deletions.
33 changes: 5 additions & 28 deletions deepmd/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
)

from deepmd.utils.plugin import (
Plugin,
PluginVariant,
make_plugin_registry,
)

if TYPE_CHECKING:
Expand All @@ -33,7 +33,7 @@
)


class Backend(PluginVariant):
class Backend(PluginVariant, make_plugin_registry("backend")):
r"""General backend class.
Examples
Expand All @@ -44,24 +44,6 @@ class Backend(PluginVariant):
... pass
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a backend plugin.
Parameters
----------
key : str
the key of a backend
Returns
-------
Callable[[object], object]
the decorator to register backend
"""
return Backend.__plugins.register(key.lower())

@staticmethod
def get_backend(key: str) -> Type["Backend"]:
"""Get the backend by key.
Expand All @@ -76,12 +58,7 @@ def get_backend(key: str) -> Type["Backend"]:
Backend
the backend
"""
try:
backend = Backend.__plugins.get_plugin(key.lower())
except KeyError:
raise KeyError(f"Backend {key} is not registered.")
assert isinstance(backend, type)
return backend
return Backend.get_class_by_type(key)

@staticmethod
def get_backends() -> Dict[str, Type["Backend"]]:
Expand All @@ -92,7 +69,7 @@ def get_backends() -> Dict[str, Type["Backend"]]:
list
all the registered backends
"""
return Backend.__plugins.plugins
return Backend.get_plugins()

@staticmethod
def get_backends_by_feature(
Expand All @@ -112,7 +89,7 @@ def get_backends_by_feature(
"""
return {
key: backend
for key, backend in Backend.__plugins.plugins.items()
for key, backend in Backend.get_backends().items()
if backend.features & feature
}

Expand Down
38 changes: 3 additions & 35 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Type,
)

from deepmd.common import (
Expand All @@ -17,7 +15,8 @@
DPPath,
)
from deepmd.utils.plugin import (
Plugin,
PluginVariant,
make_plugin_registry,
)


Expand All @@ -37,45 +36,14 @@ def make_base_descriptor(
"""

class BD(ABC):
class BD(ABC, PluginVariant, make_plugin_registry("descriptor")):
"""Base descriptor provides the interfaces of descriptor."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a descriptor plugin.
Parameters
----------
key : str
the key of a descriptor
Returns
-------
Descriptor
the registered descriptor
Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return BD.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BD:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, descrpt_type: str) -> Type["BD"]:
if descrpt_type in BD.__plugins.plugins:
return BD.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down
38 changes: 3 additions & 35 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
abstractmethod,
)
from typing import (
Callable,
Dict,
Optional,
Type,
)

from deepmd.common import (
Expand All @@ -17,7 +15,8 @@
FittingOutputDef,
)
from deepmd.utils.plugin import (
Plugin,
PluginVariant,
make_plugin_registry,
)


Expand All @@ -37,45 +36,14 @@ def make_base_fitting(
"""

class BF(ABC):
class BF(ABC, PluginVariant, make_plugin_registry("fitting")):
"""Base fitting provides the interfaces of fitting net."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.
Parameters
----------
key : str
the key of a descriptor
Returns
-------
callable[[object], object]
the registered descriptor
Examples
--------
>>> @Fitting.register("some_fitting")
class SomeFitting(Fitting):
pass
"""
return BF.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BF:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, fitting_type: str) -> Type["BF"]:
if fitting_type in BF.__plugins.plugins:
return BF.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown fitting type: " + fitting_type)

@abstractmethod
def output_def(self) -> FittingOutputDef:
"""Returns the output def of the fitting net."""
Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
)

from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)


def make_base_model() -> Type[object]:
class BaseBaseModel(ABC, make_plugin_registry("model")):
class BaseBaseModel(ABC, PluginVariant, make_plugin_registry("model")):
"""Base class for final exported model that will be directly used for inference.
The class defines some abstractmethods that will be directly called by the
Expand Down
37 changes: 5 additions & 32 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
abstractmethod,
)
from typing import (
Callable,
Dict,
List,
Optional,
Expand All @@ -22,60 +21,34 @@
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.plugin import (
Plugin,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)

log = logging.getLogger(__name__)


class DescriptorBlock(torch.nn.Module, ABC):
class DescriptorBlock(torch.nn.Module, ABC, make_plugin_registry("DescriptorBlock")):
"""The building block of descriptor.
Given the input descriptor, provide with the atomic coordinates,
atomic types and neighbor list, calculate the new descriptor.
"""

__plugins = Plugin()
local_cluster = False

@staticmethod
def register(key: str) -> Callable:
"""Register a DescriptorBlock plugin.
Parameters
----------
key : str
the key of a DescriptorBlock
Returns
-------
DescriptorBlock
the registered DescriptorBlock
Examples
--------
>>> @DescriptorBlock.register("some_descrpt")
class SomeDescript(DescriptorBlock):
pass
"""
return DescriptorBlock.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is DescriptorBlock:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of DescriptorBlock should be set by `type`")
if descrpt_type in DescriptorBlock.__plugins.plugins:
cls = DescriptorBlock.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown DescriptorBlock type: " + descrpt_type)
cls = cls.get_class_by_type(descrpt_type)
return super().__new__(cls)

@abstractmethod
Expand Down
38 changes: 4 additions & 34 deletions deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
)
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Expand All @@ -21,12 +20,14 @@
tf,
)
from deepmd.tf.utils import (
Plugin,
PluginVariant,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)


class Descriptor(PluginVariant):
class Descriptor(PluginVariant, make_plugin_registry("descriptor")):
r"""The abstract class for descriptors. All specific descriptors should
be based on this class.
Expand All @@ -45,37 +46,6 @@ class Descriptor(PluginVariant):
that can be called by other classes.
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a descriptor plugin.
Parameters
----------
key : str
the key of a descriptor
Returns
-------
Descriptor
the registered descriptor
Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return Descriptor.__plugins.register(key)

@classmethod
def get_class_by_type(cls, descrpt_type: str):
if descrpt_type in Descriptor.__plugins.plugins:
return Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
Expand Down
Loading

0 comments on commit 004ebd6

Please sign in to comment.