From 004ebd6ea540a31536b0b2893768ed9caef622c5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 27 Feb 2024 19:39:45 -0500 Subject: [PATCH] apply PluginVariant and make_plugin_registry to classes (#3346) Signed-off-by: Jinzhe Zeng --- deepmd/backend/backend.py | 33 ++---------- .../descriptor/make_base_descriptor.py | 38 ++------------ deepmd/dpmodel/fitting/make_base_fitting.py | 38 ++------------ deepmd/dpmodel/model/base_model.py | 3 +- deepmd/pt/model/descriptor/descriptor.py | 37 ++----------- deepmd/tf/descriptor/descriptor.py | 38 ++------------ deepmd/tf/fit/fitting.py | 52 ++----------------- deepmd/tf/model/__init__.py | 17 ++++++ deepmd/tf/model/frozen.py | 1 + deepmd/tf/model/linear.py | 1 + deepmd/tf/model/model.py | 47 ++--------------- deepmd/tf/model/multi.py | 1 + deepmd/tf/model/pairtab.py | 1 + deepmd/tf/model/pairwise_dprc.py | 1 + deepmd/utils/plugin.py | 6 +++ 15 files changed, 59 insertions(+), 255 deletions(-) diff --git a/deepmd/backend/backend.py b/deepmd/backend/backend.py index f1ef4cb52a..8f7bca319e 100644 --- a/deepmd/backend/backend.py +++ b/deepmd/backend/backend.py @@ -16,8 +16,8 @@ ) from deepmd.utils.plugin import ( - Plugin, PluginVariant, + make_plugin_registry, ) if TYPE_CHECKING: @@ -33,7 +33,7 @@ ) -class Backend(PluginVariant): +class Backend(PluginVariant, make_plugin_registry("backend")): r"""General backend class. Examples @@ -44,24 +44,6 @@ class Backend(PluginVariant): ... pass """ - __plugins = Plugin() - - @staticmethod - def register(key: str) -> Callable[[object], object]: - """Register a backend plugin. - - Parameters - ---------- - key : str - the key of a backend - - Returns - ------- - Callable[[object], object] - the decorator to register backend - """ - return Backend.__plugins.register(key.lower()) - @staticmethod def get_backend(key: str) -> Type["Backend"]: """Get the backend by key. @@ -76,12 +58,7 @@ def get_backend(key: str) -> Type["Backend"]: Backend the backend """ - try: - backend = Backend.__plugins.get_plugin(key.lower()) - except KeyError: - raise KeyError(f"Backend {key} is not registered.") - assert isinstance(backend, type) - return backend + return Backend.get_class_by_type(key) @staticmethod def get_backends() -> Dict[str, Type["Backend"]]: @@ -92,7 +69,7 @@ def get_backends() -> Dict[str, Type["Backend"]]: list all the registered backends """ - return Backend.__plugins.plugins + return Backend.get_plugins() @staticmethod def get_backends_by_feature( @@ -112,7 +89,7 @@ def get_backends_by_feature( """ return { key: backend - for key, backend in Backend.__plugins.plugins.items() + for key, backend in Backend.get_backends().items() if backend.features & feature } diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 2cdb5abd52..18416ff16b 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -4,10 +4,8 @@ abstractmethod, ) from typing import ( - Callable, List, Optional, - Type, ) from deepmd.common import ( @@ -17,7 +15,8 @@ DPPath, ) from deepmd.utils.plugin import ( - Plugin, + PluginVariant, + make_plugin_registry, ) @@ -37,45 +36,14 @@ def make_base_descriptor( """ - class BD(ABC): + class BD(ABC, PluginVariant, make_plugin_registry("descriptor")): """Base descriptor provides the interfaces of descriptor.""" - __plugins = Plugin() - - @staticmethod - def register(key: str) -> Callable: - """Register a descriptor plugin. - - Parameters - ---------- - key : str - the key of a descriptor - - Returns - ------- - Descriptor - the registered descriptor - - Examples - -------- - >>> @Descriptor.register("some_descrpt") - class SomeDescript(Descriptor): - pass - """ - return BD.__plugins.register(key) - def __new__(cls, *args, **kwargs): if cls is BD: cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) return super().__new__(cls) - @classmethod - def get_class_by_type(cls, descrpt_type: str) -> Type["BD"]: - if descrpt_type in BD.__plugins.plugins: - return BD.__plugins.plugins[descrpt_type] - else: - raise RuntimeError("Unknown descriptor type: " + descrpt_type) - @abstractmethod def get_rcut(self) -> float: """Returns the cut-off radius.""" diff --git a/deepmd/dpmodel/fitting/make_base_fitting.py b/deepmd/dpmodel/fitting/make_base_fitting.py index d206f8e39e..041076ba89 100644 --- a/deepmd/dpmodel/fitting/make_base_fitting.py +++ b/deepmd/dpmodel/fitting/make_base_fitting.py @@ -4,10 +4,8 @@ abstractmethod, ) from typing import ( - Callable, Dict, Optional, - Type, ) from deepmd.common import ( @@ -17,7 +15,8 @@ FittingOutputDef, ) from deepmd.utils.plugin import ( - Plugin, + PluginVariant, + make_plugin_registry, ) @@ -37,45 +36,14 @@ def make_base_fitting( """ - class BF(ABC): + class BF(ABC, PluginVariant, make_plugin_registry("fitting")): """Base fitting provides the interfaces of fitting net.""" - __plugins = Plugin() - - @staticmethod - def register(key: str) -> Callable[[object], object]: - """Register a descriptor plugin. - - Parameters - ---------- - key : str - the key of a descriptor - - Returns - ------- - callable[[object], object] - the registered descriptor - - Examples - -------- - >>> @Fitting.register("some_fitting") - class SomeFitting(Fitting): - pass - """ - return BF.__plugins.register(key) - def __new__(cls, *args, **kwargs): if cls is BF: cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) return super().__new__(cls) - @classmethod - def get_class_by_type(cls, fitting_type: str) -> Type["BF"]: - if fitting_type in BF.__plugins.plugins: - return BF.__plugins.plugins[fitting_type] - else: - raise RuntimeError("Unknown fitting type: " + fitting_type) - @abstractmethod def output_def(self) -> FittingOutputDef: """Returns the output def of the fitting net.""" diff --git a/deepmd/dpmodel/model/base_model.py b/deepmd/dpmodel/model/base_model.py index df9c926d6c..faf3e7cfff 100644 --- a/deepmd/dpmodel/model/base_model.py +++ b/deepmd/dpmodel/model/base_model.py @@ -11,12 +11,13 @@ ) from deepmd.utils.plugin import ( + PluginVariant, make_plugin_registry, ) def make_base_model() -> Type[object]: - class BaseBaseModel(ABC, make_plugin_registry("model")): + class BaseBaseModel(ABC, PluginVariant, make_plugin_registry("model")): """Base class for final exported model that will be directly used for inference. The class defines some abstractmethods that will be directly called by the diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 91e0a2527a..964cdb01eb 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -5,7 +5,6 @@ abstractmethod, ) from typing import ( - Callable, Dict, List, Optional, @@ -22,60 +21,34 @@ from deepmd.pt.utils.env_mat_stat import ( EnvMatStatSe, ) -from deepmd.pt.utils.plugin import ( - Plugin, -) from deepmd.utils.env_mat_stat import ( StatItem, ) from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.plugin import ( + make_plugin_registry, +) log = logging.getLogger(__name__) -class DescriptorBlock(torch.nn.Module, ABC): +class DescriptorBlock(torch.nn.Module, ABC, make_plugin_registry("DescriptorBlock")): """The building block of descriptor. Given the input descriptor, provide with the atomic coordinates, atomic types and neighbor list, calculate the new descriptor. """ - __plugins = Plugin() local_cluster = False - @staticmethod - def register(key: str) -> Callable: - """Register a DescriptorBlock plugin. - - Parameters - ---------- - key : str - the key of a DescriptorBlock - - Returns - ------- - DescriptorBlock - the registered DescriptorBlock - - Examples - -------- - >>> @DescriptorBlock.register("some_descrpt") - class SomeDescript(DescriptorBlock): - pass - """ - return DescriptorBlock.__plugins.register(key) - def __new__(cls, *args, **kwargs): if cls is DescriptorBlock: try: descrpt_type = kwargs["type"] except KeyError: raise KeyError("the type of DescriptorBlock should be set by `type`") - if descrpt_type in DescriptorBlock.__plugins.plugins: - cls = DescriptorBlock.__plugins.plugins[descrpt_type] - else: - raise RuntimeError("Unknown DescriptorBlock type: " + descrpt_type) + cls = cls.get_class_by_type(descrpt_type) return super().__new__(cls) @abstractmethod diff --git a/deepmd/tf/descriptor/descriptor.py b/deepmd/tf/descriptor/descriptor.py index 48329ceb48..dbf260bfe8 100644 --- a/deepmd/tf/descriptor/descriptor.py +++ b/deepmd/tf/descriptor/descriptor.py @@ -4,7 +4,6 @@ ) from typing import ( Any, - Callable, Dict, List, Optional, @@ -21,12 +20,14 @@ tf, ) from deepmd.tf.utils import ( - Plugin, PluginVariant, ) +from deepmd.utils.plugin import ( + make_plugin_registry, +) -class Descriptor(PluginVariant): +class Descriptor(PluginVariant, make_plugin_registry("descriptor")): r"""The abstract class for descriptors. All specific descriptors should be based on this class. @@ -45,37 +46,6 @@ class Descriptor(PluginVariant): that can be called by other classes. """ - __plugins = Plugin() - - @staticmethod - def register(key: str) -> Callable: - """Register a descriptor plugin. - - Parameters - ---------- - key : str - the key of a descriptor - - Returns - ------- - Descriptor - the registered descriptor - - Examples - -------- - >>> @Descriptor.register("some_descrpt") - class SomeDescript(Descriptor): - pass - """ - return Descriptor.__plugins.register(key) - - @classmethod - def get_class_by_type(cls, descrpt_type: str): - if descrpt_type in Descriptor.__plugins.plugins: - return Descriptor.__plugins.plugins[descrpt_type] - else: - raise RuntimeError("Unknown descriptor type: " + descrpt_type) - def __new__(cls, *args, **kwargs): if cls is Descriptor: cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) diff --git a/deepmd/tf/fit/fitting.py b/deepmd/tf/fit/fitting.py index a24efcfdcd..6a7398daac 100644 --- a/deepmd/tf/fit/fitting.py +++ b/deepmd/tf/fit/fitting.py @@ -4,10 +4,8 @@ abstractmethod, ) from typing import ( - Callable, List, Optional, - Type, ) from deepmd.common import ( @@ -25,56 +23,14 @@ Loss, ) from deepmd.tf.utils import ( - Plugin, PluginVariant, ) +from deepmd.utils.plugin import ( + make_plugin_registry, +) -class Fitting(PluginVariant): - __plugins = Plugin() - - @staticmethod - def register(key: str) -> Callable: - """Register a Fitting plugin. - - Parameters - ---------- - key : str - the key of a Fitting - - Returns - ------- - Fitting - the registered Fitting - - Examples - -------- - >>> @Fitting.register("some_fitting") - class SomeFitting(Fitting): - pass - """ - return Fitting.__plugins.register(key) - - @classmethod - def get_class_by_type(cls, fitting_type: str) -> Type["Fitting"]: - """Get the fitting class by the input type. - - Parameters - ---------- - fitting_type : str - The input type - - Returns - ------- - Fitting - The fitting class - """ - if fitting_type in Fitting.__plugins.plugins: - cls = Fitting.__plugins.plugins[fitting_type] - else: - raise RuntimeError("Unknown descriptor type: " + fitting_type) - return cls - +class Fitting(PluginVariant, make_plugin_registry("fitting")): def __new__(cls, *args, **kwargs): if cls is Fitting: cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) diff --git a/deepmd/tf/model/__init__.py b/deepmd/tf/model/__init__.py index d366ca1441..1d100f2b09 100644 --- a/deepmd/tf/model/__init__.py +++ b/deepmd/tf/model/__init__.py @@ -1,4 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.tf.model.frozen import ( + FrozenModel, +) +from deepmd.tf.model.linear import ( + LinearEnergyModel, +) +from deepmd.tf.model.pairtab import ( + PairTabModel, +) +from deepmd.tf.model.pairwise_dprc import ( + PairwiseDPRc, +) + from .dos import ( DOSModel, ) @@ -23,4 +36,8 @@ "GlobalPolarModel", "PolarModel", "WFCModel", + "FrozenModel", + "LinearEnergyModel", + "PairTabModel", + "PairwiseDPRc", ] diff --git a/deepmd/tf/model/frozen.py b/deepmd/tf/model/frozen.py index f06ae954d1..1933690ca7 100644 --- a/deepmd/tf/model/frozen.py +++ b/deepmd/tf/model/frozen.py @@ -30,6 +30,7 @@ ) +@Model.register("frozen") class FrozenModel(Model): """Load model from a frozen model, which cannot be trained. diff --git a/deepmd/tf/model/linear.py b/deepmd/tf/model/linear.py index 7563e36b3f..da866ccc5f 100644 --- a/deepmd/tf/model/linear.py +++ b/deepmd/tf/model/linear.py @@ -147,6 +147,7 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): return local_jdata_cpy +@Model.register("linear_ener") class LinearEnergyModel(LinearModel): """Linear energy model make linear combinations of several existing energy models.""" diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 76310834a7..2ae2879226 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -60,9 +60,12 @@ from deepmd.tf.utils.type_embed import ( TypeEmbedNet, ) +from deepmd.utils.plugin import ( + make_plugin_registry, +) -class Model(ABC): +class Model(ABC, make_plugin_registry("model")): """Abstract base model. Parameters @@ -94,47 +97,6 @@ class Model(ABC): Compression information for internal use """ - @classmethod - def get_class_by_type(cls, model_type: str): - """Get the class by input type. - - Parameters - ---------- - model_type : str - The input type - """ - # infer model type by fitting_type - from deepmd.tf.model.frozen import ( - FrozenModel, - ) - from deepmd.tf.model.linear import ( - LinearEnergyModel, - ) - from deepmd.tf.model.multi import ( - MultiModel, - ) - from deepmd.tf.model.pairtab import ( - PairTabModel, - ) - from deepmd.tf.model.pairwise_dprc import ( - PairwiseDPRc, - ) - - if model_type == "standard": - return StandardModel - elif model_type == "multi": - return MultiModel - elif model_type == "pairwise_dprc": - return PairwiseDPRc - elif model_type == "frozen": - return FrozenModel - elif model_type == "linear_ener": - return LinearEnergyModel - elif model_type == "pairtab": - return PairTabModel - else: - raise ValueError(f"unknown model type: {model_type}") - def __new__(cls, *args, **kwargs): if cls is Model: # init model @@ -621,6 +583,7 @@ def serialize(self, suffix: str = "") -> dict: raise NotImplementedError("Not implemented in class %s" % self.__name__) +@Model.register("standard") class StandardModel(Model): """Standard model, which must contain a descriptor and a fitting. diff --git a/deepmd/tf/model/multi.py b/deepmd/tf/model/multi.py index 2acf00fd52..6280fcd2f6 100644 --- a/deepmd/tf/model/multi.py +++ b/deepmd/tf/model/multi.py @@ -55,6 +55,7 @@ ) +@Model.register("multi") class MultiModel(Model): """Multi-task model. diff --git a/deepmd/tf/model/pairtab.py b/deepmd/tf/model/pairtab.py index fe94c43f64..2cb0dc6e52 100644 --- a/deepmd/tf/model/pairtab.py +++ b/deepmd/tf/model/pairtab.py @@ -31,6 +31,7 @@ ) +@Model.register("pairtab") class PairTabModel(Model): """Pairwise tabulation energy model. diff --git a/deepmd/tf/model/pairwise_dprc.py b/deepmd/tf/model/pairwise_dprc.py index 51296a0df9..5a377cdfa4 100644 --- a/deepmd/tf/model/pairwise_dprc.py +++ b/deepmd/tf/model/pairwise_dprc.py @@ -33,6 +33,7 @@ ) +@Model.register("pairwise_dprc") class PairwiseDPRc(Model): """Pairwise Deep Potential - Range Correction.""" diff --git a/deepmd/utils/plugin.py b/deepmd/utils/plugin.py index e6433ee681..22f315f63d 100644 --- a/deepmd/utils/plugin.py +++ b/deepmd/utils/plugin.py @@ -8,6 +8,7 @@ ) from typing import ( Callable, + Dict, Optional, Type, ) @@ -152,4 +153,9 @@ def get_class_by_type(cls, class_type: str) -> Type[object]: dym_message = f"Did you mean: {matches[0]}?" if matches else "" raise RuntimeError(f"Unknown {name} type: {class_type}. {dym_message}") + @classmethod + def get_plugins(cls) -> Dict[str, Type[object]]: + """Get all the registered plugins.""" + return PR.__plugins.plugins + return PR