diff --git a/lightautoml/ml_algo/tuning/optuna.py b/lightautoml/ml_algo/tuning/optuna.py index eade5d12..62e565df 100644 --- a/lightautoml/ml_algo/tuning/optuna.py +++ b/lightautoml/ml_algo/tuning/optuna.py @@ -11,6 +11,7 @@ from typing import Union import optuna +from tqdm import tqdm from ...dataset.base import LAMLDataset from ..base import MLAlgo @@ -19,6 +20,7 @@ from .base import Uniform from ...validation.base import HoldoutIterator from ...validation.base import TrainValidIterator +from ...utils.logging import get_stdout_level logger = logging.getLogger(__name__) @@ -172,6 +174,15 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial ) try: + # Custom progress bar + def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTrial): + best_trial = study.best_trial + progress_bar.set_postfix(best_trial=best_trial.number, best_value=best_trial.value) + progress_bar.update(1) + + # Initialize progress bar + if get_stdout_level() in [logging.INFO, logging.INFO2]: + progress_bar = tqdm(total=self.n_trials, desc="Optimization Progress") sampler = optuna.samplers.TPESampler(seed=self.random_state) self.study = optuna.create_study(direction=self.direction, sampler=sampler) @@ -184,10 +195,17 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial ), n_trials=self.n_trials, timeout=self.timeout, - callbacks=[update_trial_time], - # show_progress_bar=True, + callbacks=( + [update_trial_time, custom_progress_bar] + if get_stdout_level() in [logging.INFO, logging.INFO2] + else [update_trial_time] + ), ) + # Close the progress bar if it was initialized + if get_stdout_level() in [logging.INFO, logging.INFO2]: + progress_bar.close() + # need to update best params here self._best_params = self.study.best_params ml_algo.params = self._best_params @@ -384,6 +402,16 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial ) try: + # Custom progress bar + def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTrial): + best_trial = study.best_trial + progress_bar.set_postfix(best_trial=best_trial.number, best_value=best_trial.value) + progress_bar.update(1) + + # Initialize progress bar + if get_stdout_level() in [logging.INFO, logging.INFO2]: + progress_bar = tqdm(total=self.n_trials, desc="Optimization Progress") + sampler = optuna.samplers.TPESampler(seed=self.random_state) self.study = optuna.create_study(direction=self.direction, sampler=sampler) @@ -395,10 +423,17 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial ), n_trials=self.n_trials, timeout=self.timeout, - callbacks=[update_trial_time], - # show_progress_bar=True, + callbacks=( + [update_trial_time, custom_progress_bar] + if get_stdout_level() in [logging.INFO, logging.INFO2] + else [update_trial_time] + ), ) + # Close the progress bar if it was initialized + if get_stdout_level() in [logging.INFO, logging.INFO2]: + progress_bar.close() + # need to update best params here if self.direction == "maximize": self._best_params = max(self._params_scores, key=lambda x: x[1])[0]