Skip to content

Commit

Permalink
added progress bar for optuna
Browse files Browse the repository at this point in the history
  • Loading branch information
screengreen committed Aug 20, 2024
1 parent 38fdfa9 commit 6103edb
Showing 1 changed file with 39 additions and 4 deletions.
43 changes: 39 additions & 4 deletions lightautoml/ml_algo/tuning/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Union

import optuna
from tqdm import tqdm

from ...dataset.base import LAMLDataset
from ..base import MLAlgo
Expand All @@ -19,6 +20,7 @@
from .base import Uniform
from ...validation.base import HoldoutIterator
from ...validation.base import TrainValidIterator
from ...utils.logging import get_stdout_level


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -172,6 +174,15 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial
)

try:
# Custom progress bar
def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
best_trial = study.best_trial
progress_bar.set_postfix(best_trial=best_trial.number, best_value=best_trial.value)
progress_bar.update(1)

# Initialize progress bar
if get_stdout_level() in [logging.INFO, logging.INFO2]:
progress_bar = tqdm(total=self.n_trials, desc="Optimization Progress")

sampler = optuna.samplers.TPESampler(seed=self.random_state)
self.study = optuna.create_study(direction=self.direction, sampler=sampler)
Expand All @@ -184,10 +195,17 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial
),
n_trials=self.n_trials,
timeout=self.timeout,
callbacks=[update_trial_time],
# show_progress_bar=True,
callbacks=(
[update_trial_time, custom_progress_bar]
if get_stdout_level() in [logging.INFO, logging.INFO2]
else [update_trial_time]
),
)

# Close the progress bar if it was initialized
if get_stdout_level() in [logging.INFO, logging.INFO2]:
progress_bar.close()

# need to update best params here
self._best_params = self.study.best_params
ml_algo.params = self._best_params
Expand Down Expand Up @@ -384,6 +402,16 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial
)

try:
# Custom progress bar
def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
best_trial = study.best_trial
progress_bar.set_postfix(best_trial=best_trial.number, best_value=best_trial.value)
progress_bar.update(1)

# Initialize progress bar
if get_stdout_level() in [logging.INFO, logging.INFO2]:
progress_bar = tqdm(total=self.n_trials, desc="Optimization Progress")

sampler = optuna.samplers.TPESampler(seed=self.random_state)
self.study = optuna.create_study(direction=self.direction, sampler=sampler)

Expand All @@ -395,10 +423,17 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial
),
n_trials=self.n_trials,
timeout=self.timeout,
callbacks=[update_trial_time],
# show_progress_bar=True,
callbacks=(
[update_trial_time, custom_progress_bar]
if get_stdout_level() in [logging.INFO, logging.INFO2]
else [update_trial_time]
),
)

# Close the progress bar if it was initialized
if get_stdout_level() in [logging.INFO, logging.INFO2]:
progress_bar.close()

# need to update best params here
if self.direction == "maximize":
self._best_params = max(self._params_scores, key=lambda x: x[1])[0]
Expand Down

0 comments on commit 6103edb

Please sign in to comment.