Skip to content

Commit

Permalink
Merge branch 'AUTOML-63' into 'master'
Browse files Browse the repository at this point in the history
Add fail_tolerance to OptunaTuner

See merge request ai-lab-pmo/mltools/automl/LightAutoML!27
  • Loading branch information
dev-rinchin committed Dec 6, 2024
2 parents f21a34c + 882a9e5 commit a744d13
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions lightautoml/ml_algo/tuning/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ def __init__(
direction: Optional[str] = "maximize",
fit_on_holdout: bool = True,
random_state: int = 42,
fail_tolerance: float = 0.5,
):
self.timeout = timeout
self.n_trials = n_trials
self.estimated_n_trials = n_trials
self.direction = direction
self._fit_on_holdout = fit_on_holdout
self.random_state = random_state
self.fail_tolerance = fail_tolerance

def _upd_timeout(self, timeout):
self.timeout = min(self.timeout, timeout)
Expand Down Expand Up @@ -173,6 +175,12 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial
f"\x1b[1mTrial {len(study.trials)}\x1b[0m with hyperparameters {trial.params} scored {trial.value} in {trial.duration}"
)

def check_fail_tolerance(study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
df = study.trials_dataframe()

if df[df["state"] == "FAIL"].shape[0] / self.estimated_n_trials > self.fail_tolerance:
raise Exception(f"Too much trials was failed ({df[df["state"] == "FAIL"].shape[0]} of {df.shape[0]}). Check the model or search space for it.")

try:
# Custom progress bar
def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
Expand All @@ -196,10 +204,11 @@ def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTri
n_trials=self.n_trials,
timeout=self.timeout,
callbacks=(
[update_trial_time, custom_progress_bar]
[update_trial_time, check_fail_tolerance, custom_progress_bar]
if get_stdout_level() in [logging.INFO, logging.INFO2]
else [update_trial_time]
else [update_trial_time, check_fail_tolerance]
),
catch=[Exception],
)

# Close the progress bar if it was initialized
Expand Down Expand Up @@ -333,13 +342,15 @@ def __init__(
direction: Optional[str] = "maximize",
fit_on_holdout: bool = True,
random_state: int = 42,
fail_tolerance = 0.5,
):
self.timeout = timeout
self.n_trials = n_trials
self.estimated_n_trials = n_trials
self.direction = direction
self._fit_on_holdout = fit_on_holdout
self.random_state = random_state
self.fail_tolerance = fail_tolerance

def _upd_timeout(self, timeout):
self.timeout = min(self.timeout, timeout)
Expand Down Expand Up @@ -401,6 +412,12 @@ def update_trial_time(study: optuna.study.Study, trial: optuna.trial.FrozenTrial
f"\x1b[1mTrial {len(study.trials)}\x1b[0m with hyperparameters {trial.params} scored {trial.value} in {trial.duration}"
)

def check_fail_tolerance(study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
df = study.trials_dataframe()

if df[df["state"] == "FAIL"].shape[0] / self.estimated_n_trials > self.fail_tolerance:
raise Exception(f"Too much trials was failed ({df[df["state"] == "FAIL"].shape[0]} of {df.shape[0]}). Check the model or search space for it.")

try:
# Custom progress bar
def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
Expand All @@ -424,10 +441,11 @@ def custom_progress_bar(study: optuna.study.Study, trial: optuna.trial.FrozenTri
n_trials=self.n_trials,
timeout=self.timeout,
callbacks=(
[update_trial_time, custom_progress_bar]
[update_trial_time, check_fail_tolerance, custom_progress_bar]
if get_stdout_level() in [logging.INFO, logging.INFO2]
else [update_trial_time]
else [update_trial_time, check_fail_tolerance]
),
catch=[Exception],
)

# Close the progress bar if it was initialized
Expand Down

0 comments on commit a744d13

Please sign in to comment.