From a16381736c0ed7c4399beebac8e1c557a1a88c79 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Thu, 22 Aug 2024 10:22:18 -0700 Subject: [PATCH] allow more candidate trials than max_trials in Scheduler (#2689) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2689 see title. This is useful for queueing trials manually Reviewed By: bernardbeckerman Differential Revision: D61508135 fbshipit-source-id: 76b0c8f0671e7b367706edadf22314b781074458 --- ax/service/scheduler.py | 19 ++++++---- ax/service/tests/scheduler_test_utils.py | 46 +++++++++++++++++++++--- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 115e4c2e8bb..4faf63b3f21 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -920,13 +920,6 @@ def run_trials_and_yield_results( n_initial_candidate_trials = len(self.candidate_trials) if n_initial_candidate_trials == 0 and max_trials < 0: raise UserInputError(f"Expected `max_trials` >= 0, got {max_trials}.") - elif max_trials < n_initial_candidate_trials: - raise UserInputError( - "The number of pre-attached candidate trials " - f"({n_initial_candidate_trials}) is greater than `max_trials = " - f"{max_trials}`. Increase `max_trials` or reduce the number of " - "pre-attached candidate trials." - ) # trials are pre-existing only if they do not still require running n_existing = len(self.experiment.trials) - n_initial_candidate_trials @@ -1570,6 +1563,7 @@ def _complete_optimization( num_preexisting_trials=num_preexisting_trials, status=RunTrialsStatus.SUCCESS, ) + self.warn_if_non_terminal_trials() return res def _validate_options(self, options: SchedulerOptions) -> None: @@ -2151,6 +2145,17 @@ def _get_failure_rate_exceeded_error( ) ) + def warn_if_non_terminal_trials(self) -> None: + """Warns if there are any non-terminal trials on the experiment.""" + non_terminal_trials = [ + t.index for t in self.experiment.trials.values() if not t.status.is_terminal + ] + if len(non_terminal_trials) > 0: + self.logger.warning( + f"Found {len(non_terminal_trials)} non-terminal trials on " + f"{self.experiment.name}: {non_terminal_trials}." + ) + def get_fitted_model_bridge( scheduler: Scheduler, force_refit: bool = False diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 49548b1046e..0f723046d93 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -798,10 +798,11 @@ def test_run_preattached_trials_only(self) -> None: # pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float, # int, str]]` but got `Dict[str, int]`. trial.add_arm(Arm(parameters=parameter_dict)) - with self.assertRaisesRegex( - UserInputError, "number of pre-attached candidate trials .* is greater than" - ): - scheduler.run_n_trials(max_trials=0) + + # check no new trials are run, when max_trials = 0 + scheduler.run_n_trials(max_trials=0) + self.assertEqual(trial.status, TrialStatus.CANDIDATE) + # check that candidate trial is run, when max_trials = 1 scheduler.run_n_trials(max_trials=1) self.assertEqual(len(scheduler.experiment.trials), 1) self.assertDictEqual( @@ -813,6 +814,43 @@ def test_run_preattached_trials_only(self) -> None: all(t.completed_successfully for t in scheduler.experiment.trials.values()) ) + def test_run_multiple_preattached_trials_only(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + # assert that pre-attached trials run when max_trials = number of + # pre-attached trials + scheduler = Scheduler( + experiment=self.branin_experiment, # Has runner and metrics. + generation_strategy=gs, + options=SchedulerOptions( + init_seconds_between_polls=0, # Short between polls so test is fast. + trial_type=TrialType.BATCH_TRIAL, + ), + db_settings=self.db_settings_if_always_needed, + ) + trial1 = scheduler.experiment.new_trial() + trial1.add_arm(Arm(parameters={"x1": 5, "x2": 5})) + trial2 = scheduler.experiment.new_trial() + trial2.add_arm(Arm(parameters={"x1": 6, "x2": 3})) + + # check that first candidate trial is run when called with max_trials = 1 + with self.assertLogs(logger="ax.service.scheduler") as lg: + scheduler.run_n_trials(max_trials=1) + self.assertIn( + "Found 1 non-terminal trials on branin_test_experiment: [1]", + lg.output[-1], + ) + self.assertIn(trial1.status, [TrialStatus.RUNNING, TrialStatus.COMPLETED]) + self.assertEqual(trial2.status, TrialStatus.CANDIDATE) + # check that next candidate trial is run, when max_trials = 1 + scheduler.run_n_trials(max_trials=1) + self.assertEqual(len(scheduler.experiment.trials), 2) + self.assertTrue( # Make sure all trials got to complete. + all(t.completed_successfully for t in scheduler.experiment.trials.values()) + ) + def test_global_stopping(self) -> None: gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment,