Commit
Refactor hyperparameter search backends (#24384)
* Refactor hyperparameter search backends

* Simpler refactoring without abstract base class

* black

* review comments:
specify name in class
use methods instead of callable class attributes
name constant better

* review comments: safer bool checking, log multiple available backends

* test ALL_HYPERPARAMETER_SEARCH_BACKENDS vs HPSearchBackend in unit test, not module. format with black.

* copyright
alexmojaki authored Jun 22, 2023
1 parent a1c4b63 commit b6295b2
Showing 6 changed files with 152 additions and 53 deletions.
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -98,6 +98,7 @@
"file_utils": [],
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
"hf_argparser": ["HfArgumentParser"],
"hyperparameter_search": [],
"image_transforms": [],
"integrations": [
"is_clearml_available",
136 changes: 136 additions & 0 deletions src/transformers/hyperparameter_search.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .integrations import (
+    is_optuna_available,
+    is_ray_available,
+    is_sigopt_available,
+    is_wandb_available,
+    run_hp_search_optuna,
+    run_hp_search_ray,
+    run_hp_search_sigopt,
+    run_hp_search_wandb,
+)
+from .trainer_utils import (
+    HPSearchBackend,
+    default_hp_space_optuna,
+    default_hp_space_ray,
+    default_hp_space_sigopt,
+    default_hp_space_wandb,
+)
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HyperParamSearchBackendBase:
+    name: str
+    pip_package: str = None
+
+    def is_available(self):
+        raise NotImplementedError
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        raise NotImplementedError
+
+    def default_hp_space(self, trial):
+        raise NotImplementedError
+
+    def ensure_available(self):
+        if not self.is_available():
+            raise RuntimeError(
+                f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
+            )
+
+    @classmethod
+    def pip_install(cls):
+        return f"`pip install {cls.pip_package or cls.name}`"
+
+
+class OptunaBackend(HyperParamSearchBackendBase):
+    name = "optuna"
+
+    def is_available(self):
+        return is_optuna_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_optuna(trial)
+
+
+class RayTuneBackend(HyperParamSearchBackendBase):
+    name = "ray"
+    pip_package = "'ray[tune]'"
+
+    def is_available(self):
+        return is_ray_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_ray(trial)
+
+
+class SigOptBackend(HyperParamSearchBackendBase):
+    name = "sigopt"
+
+    def is_available(self):
+        return is_sigopt_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_sigopt(trial)
+
+
+class WandbBackend(HyperParamSearchBackendBase):
+    name = "wandb"
+
+    def is_available(self):
+        return is_wandb_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_wandb(trial)
+
+
+ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
+    HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
+}
+
+
+def default_hp_search_backend() -> str:
+    available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend().is_available()]
+    if len(available_backends) > 0:
+        name = available_backends[0].name
+        if len(available_backends) > 1:
+            logger.info(
+                f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
+            )
+        return name
+    raise RuntimeError(
+        "No hyperparameter search backend available.\n"
+        + "\n".join(
+            f" - To install {backend.name} run {backend.pip_install()}"
+            for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
+        )
+    )
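
For context, here is a minimal sketch of how this registry is consumed (mirroring the trainer.py change further down). It is illustrative only and assumes an environment with this commit of transformers installed:

    from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
    from transformers.trainer_utils import HPSearchBackend

    # Look up the backend class by its enum key and instantiate it, as trainer.py now does.
    backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[HPSearchBackend.OPTUNA]()
    backend_obj.ensure_available()  # raises RuntimeError with a pip hint if optuna is missing
    print(backend_obj.name)  # "optuna"
    print(backend_obj.pip_install())  # "`pip install optuna`"
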
9 changes: 0 additions & 9 deletions src/transformers/integrations.py
@@ -177,15 +177,6 @@ def hp_params(trial):
raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
if is_optuna_available():
return "optuna"
elif is_ray_tune_available():
return "ray"
elif is_sigopt_available():
return "sigopt"


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import optuna

40 changes: 5 additions & 35 deletions src/transformers/trainer.py
@@ -36,18 +36,9 @@
 # Integrations must be imported before ML frameworks:
 # isort: off
 from .integrations import (
-    default_hp_search_backend,
     get_reporting_integration_callbacks,
     hp_params,
     is_fairscale_available,
-    is_optuna_available,
-    is_ray_tune_available,
-    is_sigopt_available,
-    is_wandb_available,
-    run_hp_search_optuna,
-    run_hp_search_ray,
-    run_hp_search_sigopt,
-    run_hp_search_wandb,
 )
 
 # isort: on
@@ -66,6 +57,7 @@
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
+from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
@@ -114,7 +106,6 @@
     TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,
-    default_hp_space,
     denumpify_detensorize,
     enable_full_determinism,
     find_executable_batch_size,
@@ -2517,41 +2508,20 @@ def hyperparameter_search(
"""
if backend is None:
backend = default_hp_search_backend()
if backend is None:
raise RuntimeError(
"At least one of optuna or ray should be installed. "
"To install optuna run `pip install optuna`. "
"To install ray run `pip install ray[tune]`. "
"To install sigopt run `pip install sigopt`."
)
backend = HPSearchBackend(backend)
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
raise RuntimeError(
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
if backend == HPSearchBackend.WANDB and not is_wandb_available():
raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
backend_obj.ensure_available()
self.hp_search_backend = backend
if self.model_init is None:
raise RuntimeError(
"To use hyperparameter search, you need to pass your model through a model_init function."
)

self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
self.hp_name = hp_name
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

backend_dict = {
HPSearchBackend.OPTUNA: run_hp_search_optuna,
HPSearchBackend.RAY: run_hp_search_ray,
HPSearchBackend.SIGOPT: run_hp_search_sigopt,
HPSearchBackend.WANDB: run_hp_search_wandb,
}
best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
best_run = backend_obj.run(self, n_trials, direction, **kwargs)

self.hp_search_backend = None
return best_run
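
End to end, the refactored dispatch is driven the same way as before — a minimal sketch, assuming optuna is installed; build_model and train_dataset are hypothetical placeholders:

    from transformers import Trainer, TrainingArguments

    def my_hp_space(trial):
        # Hypothetical optuna search space, for illustration only.
        return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}

    trainer = Trainer(
        model_init=lambda: build_model(),  # hyperparameter search requires model_init
        args=TrainingArguments(output_dir="hp_search_out"),
        train_dataset=train_dataset,
    )
    best_run = trainer.hyperparameter_search(
        direction="minimize",
        backend="optuna",  # coerced to HPSearchBackend, then resolved via the new registry
        hp_space=my_hp_space,
        n_trials=4,
    )
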
8 changes: 0 additions & 8 deletions src/transformers/trainer_utils.py
@@ -301,14 +301,6 @@ class HPSearchBackend(ExplicitEnum):
WANDB = "wandb"


default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray,
HPSearchBackend.SIGOPT: default_hp_space_sigopt,
HPSearchBackend.WANDB: default_hp_space_wandb,
}


def is_main_process(local_rank):
"""
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
11 changes: 10 additions & 1 deletion tests/trainer/test_trainer.py
@@ -42,6 +42,7 @@
     is_torch_available,
     logging,
 )
+from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
 from transformers.testing_utils import (
     ENDPOINT_STAGING,
     TOKEN,
@@ -72,7 +73,7 @@
     require_wandb,
     slow,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -2803,3 +2804,11 @@ def hp_name(params):
         trainer.hyperparameter_search(
             direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
         )
+
+
+class HyperParameterSearchBackendsTest(unittest.TestCase):
+    def test_hyperparameter_search_backends(self):
+        self.assertEqual(
+            list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
+            list(HPSearchBackend),
+        )
