diff --git a/pypots/cli/tuning.py b/pypots/cli/tuning.py index baf9d200..40419a8f 100644 --- a/pypots/cli/tuning.py +++ b/pypots/cli/tuning.py @@ -10,10 +10,11 @@ from argparse import ArgumentParser, Namespace from .base import BaseCommand +from .utils import load_package_from_path from ..classification import BRITS as BRITS_classification from ..classification import Raindrop, GRUD from ..clustering import CRLI, VaDER -from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS +from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS, TimesNet from ..optim import Adam from ..utils.logging import logger @@ -25,10 +26,12 @@ "but is missing in the current environment." ) + NN_MODELS = { # imputation models "pypots.imputation.SAITS": SAITS, "pypots.imputation.Transformer": Transformer, + "pypots.imputation.TimesNet": TimesNet, "pypots.imputation.CSDI": CSDI, "pypots.imputation.US_GAN": USGAN, "pypots.imputation.GP_VAE": GPVAE, @@ -47,6 +50,7 @@ def env_command_factory(args: Namespace): return TuningCommand( args.model, + args.model_package_path, args.train_set, args.val_set, ) @@ -73,23 +77,20 @@ def register_subcommand(parser: ArgumentParser): help="CLI tools helping run hyper-parameter tuning for specified models", allow_abbrev=True, ) - sub_parser.add_argument( "--model", dest="model", type=str, required=True, - choices=[ - "pypots.imputation.SAITS", - "pypots.imputation.Transformer", - "pypots.imputation.CSDI", - "pypots.imputation.US_GAN", - "pypots.imputation.GP_VAE", - "pypots.imputation.BRITS", - "pypots.imputation.MRNN", - ], help="Install specified dependencies in the current python environment", ) + sub_parser.add_argument( + "--model_package_path", + dest="model_package_path", + type=str, + required=False, + help="If the model is not in the pypots package, specify the path to the model package here.", + ) sub_parser.add_argument( "--train_set", dest="train_set", @@ -108,11 +109,13 @@ def register_subcommand(parser: ArgumentParser): def __init__( self, - model: bool, + model: str, + model_package_path: str, train_set: str, val_set: str, ): self._model = model + self._model_package_path = model_package_path self._train_set = train_set self._val_set = val_set @@ -126,7 +129,32 @@ def run(self): # fetch a new set of hyperparameters from NNI tuner tuner_params = nni.get_next_parameter() # get the specified model class - model_class = NN_MODELS[self._model] + if self._model not in NN_MODELS: + logger.info( + f"The specified model {self._model} is not in PyPOTS. " + f"Trying to fetch it from the given model package {self._model_package_path}." + ) + assert self._model_package_path is not None, ( + f"The given model {self._model} is not in PyPOTS. " + f"Please give the full import path of the model in PyPOTS like pypots.imputation.SAITS\n" + f"If you're trying to tune an outside model, " + f"please specify the path to the model package with argument `--model_package_path`." + ) + model_package = load_package_from_path(self._model_package_path) + assert self._model in model_package.__all__, ( + f"{self._model} is not in the given model package {self._model_package_path}." + f"Please ensure that the model class is in the __all__ list of the model package." + ) + model_class = getattr(model_package, self._model) + else: + if self._model_package_path is not None: + logger.warning( + f"‼️ Find the specified model {self._model} in PyPOTS, " + f"but also find the argument --model_package_path is not None." + f"Note that --model_package_path is ignored." + ) + + model_class = NN_MODELS[self._model] # pop out the learning rate lr = tuner_params.pop("lr") diff --git a/pypots/cli/utils.py b/pypots/cli/utils.py new file mode 100644 index 00000000..a1ae14ed --- /dev/null +++ b/pypots/cli/utils.py @@ -0,0 +1,25 @@ +""" +Adding CLI utilities here. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os +import sys +from importlib import util +from types import ModuleType + + +def load_package_from_path(pkg_path: str) -> ModuleType: + """Load a package from a given path. Please refer to https://stackoverflow.com/a/50395128""" + init_path = os.path.join(pkg_path, "__init__.py") + assert os.path.exists(init_path) + + name = os.path.basename(pkg_path) + spec = util.spec_from_file_location(name, init_path) + module = util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module