Skip to content

Commit

Permalink
Merge pull request #269 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Enabling to tune hyperparameters for outside models implemented with PyPOTS framework
  • Loading branch information
WenjieDu authored Dec 18, 2023
2 parents d457629 + 6a4714d commit a796dc2
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 13 deletions.
54 changes: 41 additions & 13 deletions pypots/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from argparse import ArgumentParser, Namespace

from .base import BaseCommand
from .utils import load_package_from_path
from ..classification import BRITS as BRITS_classification
from ..classification import Raindrop, GRUD
from ..clustering import CRLI, VaDER
from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS
from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS, TimesNet
from ..optim import Adam
from ..utils.logging import logger

Expand All @@ -25,10 +26,12 @@
"but is missing in the current environment."
)


NN_MODELS = {
# imputation models
"pypots.imputation.SAITS": SAITS,
"pypots.imputation.Transformer": Transformer,
"pypots.imputation.TimesNet": TimesNet,
"pypots.imputation.CSDI": CSDI,
"pypots.imputation.US_GAN": USGAN,
"pypots.imputation.GP_VAE": GPVAE,
Expand All @@ -47,6 +50,7 @@
def env_command_factory(args: Namespace):
return TuningCommand(
args.model,
args.model_package_path,
args.train_set,
args.val_set,
)
Expand All @@ -73,23 +77,20 @@ def register_subcommand(parser: ArgumentParser):
help="CLI tools helping run hyper-parameter tuning for specified models",
allow_abbrev=True,
)

sub_parser.add_argument(
"--model",
dest="model",
type=str,
required=True,
choices=[
"pypots.imputation.SAITS",
"pypots.imputation.Transformer",
"pypots.imputation.CSDI",
"pypots.imputation.US_GAN",
"pypots.imputation.GP_VAE",
"pypots.imputation.BRITS",
"pypots.imputation.MRNN",
],
help="Install specified dependencies in the current python environment",
)
sub_parser.add_argument(
"--model_package_path",
dest="model_package_path",
type=str,
required=False,
help="If the model is not in the pypots package, specify the path to the model package here.",
)
sub_parser.add_argument(
"--train_set",
dest="train_set",
Expand All @@ -108,11 +109,13 @@ def register_subcommand(parser: ArgumentParser):

def __init__(
self,
model: bool,
model: str,
model_package_path: str,
train_set: str,
val_set: str,
):
self._model = model
self._model_package_path = model_package_path
self._train_set = train_set
self._val_set = val_set

Expand All @@ -126,7 +129,32 @@ def run(self):
# fetch a new set of hyperparameters from NNI tuner
tuner_params = nni.get_next_parameter()
# get the specified model class
model_class = NN_MODELS[self._model]
if self._model not in NN_MODELS:
logger.info(
f"The specified model {self._model} is not in PyPOTS. "
f"Trying to fetch it from the given model package {self._model_package_path}."
)
assert self._model_package_path is not None, (
f"The given model {self._model} is not in PyPOTS. "
f"Please give the full import path of the model in PyPOTS like pypots.imputation.SAITS\n"
f"If you're trying to tune an outside model, "
f"please specify the path to the model package with argument `--model_package_path`."
)
model_package = load_package_from_path(self._model_package_path)
assert self._model in model_package.__all__, (
f"{self._model} is not in the given model package {self._model_package_path}."
f"Please ensure that the model class is in the __all__ list of the model package."
)
model_class = getattr(model_package, self._model)
else:
if self._model_package_path is not None:
logger.warning(
f"‼️ Find the specified model {self._model} in PyPOTS, "
f"but also find the argument --model_package_path is not None."
f"Note that --model_package_path is ignored."
)

model_class = NN_MODELS[self._model]
# pop out the learning rate
lr = tuner_params.pop("lr")

Expand Down
25 changes: 25 additions & 0 deletions pypots/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Adding CLI utilities here.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import os
import sys
from importlib import util
from types import ModuleType


def load_package_from_path(pkg_path: str) -> ModuleType:
"""Load a package from a given path. Please refer to https://stackoverflow.com/a/50395128"""
init_path = os.path.join(pkg_path, "__init__.py")
assert os.path.exists(init_path)

name = os.path.basename(pkg_path)
spec = util.spec_from_file_location(name, init_path)
module = util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module

0 comments on commit a796dc2

Please sign in to comment.