Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow more candidate trials than max_trials in Scheduler #2689

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,13 +920,6 @@ def run_trials_and_yield_results(
n_initial_candidate_trials = len(self.candidate_trials)
if n_initial_candidate_trials == 0 and max_trials < 0:
raise UserInputError(f"Expected `max_trials` >= 0, got {max_trials}.")
elif max_trials < n_initial_candidate_trials:
raise UserInputError(
"The number of pre-attached candidate trials "
f"({n_initial_candidate_trials}) is greater than `max_trials = "
f"{max_trials}`. Increase `max_trials` or reduce the number of "
"pre-attached candidate trials."
)

# trials are pre-existing only if they do not still require running
n_existing = len(self.experiment.trials) - n_initial_candidate_trials
Expand Down Expand Up @@ -1570,6 +1563,7 @@ def _complete_optimization(
num_preexisting_trials=num_preexisting_trials,
status=RunTrialsStatus.SUCCESS,
)
self.warn_if_non_terminal_trials()
return res

def _validate_options(self, options: SchedulerOptions) -> None:
Expand Down Expand Up @@ -2151,6 +2145,17 @@ def _get_failure_rate_exceeded_error(
)
)

def warn_if_non_terminal_trials(self) -> None:
"""Warns if there are any non-terminal trials on the experiment."""
non_terminal_trials = [
t.index for t in self.experiment.trials.values() if not t.status.is_terminal
]
if len(non_terminal_trials) > 0:
self.logger.warning(
f"Found {len(non_terminal_trials)} non-terminal trials on "
f"{self.experiment.name}: {non_terminal_trials}."
)


def get_fitted_model_bridge(
scheduler: Scheduler, force_refit: bool = False
Expand Down
46 changes: 42 additions & 4 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,10 +798,11 @@ def test_run_preattached_trials_only(self) -> None:
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, int]`.
trial.add_arm(Arm(parameters=parameter_dict))
with self.assertRaisesRegex(
UserInputError, "number of pre-attached candidate trials .* is greater than"
):
scheduler.run_n_trials(max_trials=0)

# check no new trials are run, when max_trials = 0
scheduler.run_n_trials(max_trials=0)
self.assertEqual(trial.status, TrialStatus.CANDIDATE)
# check that candidate trial is run, when max_trials = 1
scheduler.run_n_trials(max_trials=1)
self.assertEqual(len(scheduler.experiment.trials), 1)
self.assertDictEqual(
Expand All @@ -813,6 +814,43 @@ def test_run_preattached_trials_only(self) -> None:
all(t.completed_successfully for t in scheduler.experiment.trials.values())
)

def test_run_multiple_preattached_trials_only(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=self.two_sobol_steps_GS,
)
# assert that pre-attached trials run when max_trials = number of
# pre-attached trials
scheduler = Scheduler(
experiment=self.branin_experiment, # Has runner and metrics.
generation_strategy=gs,
options=SchedulerOptions(
init_seconds_between_polls=0, # Short between polls so test is fast.
trial_type=TrialType.BATCH_TRIAL,
),
db_settings=self.db_settings_if_always_needed,
)
trial1 = scheduler.experiment.new_trial()
trial1.add_arm(Arm(parameters={"x1": 5, "x2": 5}))
trial2 = scheduler.experiment.new_trial()
trial2.add_arm(Arm(parameters={"x1": 6, "x2": 3}))

# check that first candidate trial is run when called with max_trials = 1
with self.assertLogs(logger="ax.service.scheduler") as lg:
scheduler.run_n_trials(max_trials=1)
self.assertIn(
"Found 1 non-terminal trials on branin_test_experiment: [1]",
lg.output[-1],
)
self.assertIn(trial1.status, [TrialStatus.RUNNING, TrialStatus.COMPLETED])
self.assertEqual(trial2.status, TrialStatus.CANDIDATE)
# check that next candidate trial is run, when max_trials = 1
scheduler.run_n_trials(max_trials=1)
self.assertEqual(len(scheduler.experiment.trials), 2)
self.assertTrue( # Make sure all trials got to complete.
all(t.completed_successfully for t in scheduler.experiment.trials.values())
)

def test_global_stopping(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
Expand Down