Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor hyperparameter search backends #24384

Merged
merged 7 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"file_utils": [],
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
"hf_argparser": ["HfArgumentParser"],
"hyperparameter_search": [],
"image_transforms": [],
"integrations": [
"is_clearml_available",
Expand Down
128 changes: 128 additions & 0 deletions src/transformers/hyperparameter_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from abc import ABC, abstractmethod
from typing import Dict

from .integrations import (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can just add a copyright here similar to all other files in the lib (potentially switching the year to 2023 if it's not)?

is_optuna_available,
is_ray_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
run_hp_search_wandb,
)
from .trainer_utils import (
HPSearchBackend,
default_hp_space_optuna,
default_hp_space_ray,
default_hp_space_sigopt,
default_hp_space_wandb,
)


class HyperParamSearchBackendBase(ABC):
@abstractmethod
def name(self) -> str:
raise NotImplementedError

@abstractmethod
def is_available(self):
raise NotImplementedError

@abstractmethod
def run(self, trainer, n_trials: int, direction: str, **kwargs):
raise NotImplementedError

@abstractmethod
def default_hp_space(self, trial):
raise NotImplementedError

def ensure_available(self):
if not self.is_available():
raise RuntimeError(
f"You picked the {self.name()} backend, but it is not installed. "
f"Use `pip install {self.pip_install}`."
)

def pip_install(self):
return self.name()


class OptunaBackend(HyperParamSearchBackendBase):
def name(self) -> str:
return "optuna"

def is_available(self):
return is_optuna_available()

def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)

def default_hp_space(self, trial):
return default_hp_space_optuna(trial)


class RayTuneBackend(HyperParamSearchBackendBase):
def name(self) -> str:
return "ray"

def pip_install(self):
return "'ray[tune]'"

def is_available(self):
return is_ray_available()

def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)

def default_hp_space(self, trial):
return default_hp_space_ray(trial)


class SigOptBackend(HyperParamSearchBackendBase):
def name(self) -> str:
return "sigopt"

def is_available(self):
return is_sigopt_available()

def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)

def default_hp_space(self, trial):
return default_hp_space_sigopt(trial)


class WandbBackend(HyperParamSearchBackendBase):
def name(self) -> str:
return "wandb"

def is_available(self):
return is_wandb_available()

def run(self, trainer, n_trials: int, direction: str, **kwargs):
return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)

def default_hp_space(self, trial):
return default_hp_space_wandb(trial)


all_backends: Dict[str, HyperParamSearchBackendBase] = {
backend.name(): backend for backend in [OptunaBackend(), RayTuneBackend(), SigOptBackend(), WandbBackend()]
}

assert list(all_backends) == list(HPSearchBackend)


def default_hp_search_backend() -> str:
available_backends = [backend for backend in all_backends.values() if backend.is_available()]
if available_backends:
# TODO warn if len(available_backends) > 1 ?
return available_backends[0].name()
raise RuntimeError(
"No hyperparameter search backend available.\n"
+ "\n".join(
f" - To install {backend.name()} run `pip install {backend.pip_install()}`"
for backend in all_backends.values()
)
)
9 changes: 0 additions & 9 deletions src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,6 @@ def hp_params(trial):
raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
if is_optuna_available():
return "optuna"
elif is_ray_tune_available():
return "ray"
elif is_sigopt_available():
return "sigopt"


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import optuna

Expand Down
40 changes: 5 additions & 35 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,9 @@
# Integrations must be imported before ML frameworks:
# isort: off
from .integrations import (
default_hp_search_backend,
get_reporting_integration_callbacks,
hp_params,
is_fairscale_available,
is_optuna_available,
is_ray_tune_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
run_hp_search_wandb,
)

# isort: on
Expand All @@ -66,6 +57,7 @@
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import all_backends, default_hp_search_backend
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
Expand Down Expand Up @@ -114,7 +106,6 @@
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
enable_full_determinism,
find_executable_batch_size,
Expand Down Expand Up @@ -2516,41 +2507,20 @@ def hyperparameter_search(
"""
if backend is None:
backend = default_hp_search_backend()
if backend is None:
raise RuntimeError(
"At least one of optuna or ray should be installed. "
"To install optuna run `pip install optuna`. "
"To install ray run `pip install ray[tune]`. "
"To install sigopt run `pip install sigopt`."
)
backend = HPSearchBackend(backend)
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
raise RuntimeError(
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
if backend == HPSearchBackend.WANDB and not is_wandb_available():
raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
backend_obj = all_backends[backend]
backend_obj.ensure_available()
self.hp_search_backend = backend
if self.model_init is None:
raise RuntimeError(
"To use hyperparameter search, you need to pass your model through a model_init function."
)

self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
self.hp_name = hp_name
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

backend_dict = {
HPSearchBackend.OPTUNA: run_hp_search_optuna,
HPSearchBackend.RAY: run_hp_search_ray,
HPSearchBackend.SIGOPT: run_hp_search_sigopt,
HPSearchBackend.WANDB: run_hp_search_wandb,
}
best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
best_run = backend_obj.run(self, n_trials, direction, **kwargs)

self.hp_search_backend = None
return best_run
Expand Down
8 changes: 0 additions & 8 deletions src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,6 @@ class HPSearchBackend(ExplicitEnum):
WANDB = "wandb"


default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray,
HPSearchBackend.SIGOPT: default_hp_space_sigopt,
HPSearchBackend.WANDB: default_hp_space_wandb,
}


def is_main_process(local_rank):
"""
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
Expand Down