diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index 0503c009e0b..2412f98d19f 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -32,6 +32,7 @@ from ax.core.experiment import Experiment from ax.core.utils import get_model_times from ax.service.scheduler import Scheduler +from ax.service.utils.best_point_mixin import BestPointMixin from ax.utils.common.logger import get_logger from ax.utils.common.random import with_rng_seed @@ -116,7 +117,15 @@ def benchmark_replication( with with_rng_seed(seed=seed): scheduler.run_n_trials(max_trials=problem.num_trials) - optimization_trace = problem.get_opt_trace(experiment=experiment) + oracle_experiment = problem.get_oracle_experiment_from_experiment( + experiment=experiment + ) + optimization_trace = np.array( + BestPointMixin._get_trace( + experiment=oracle_experiment, + optimization_config=problem.optimization_config, + ) + ) try: # Catch any errors that may occur during score computation, such as errors diff --git a/ax/benchmark/benchmark_problem.py b/ax/benchmark/benchmark_problem.py index 7ef84519c43..ea8e21313bb 100644 --- a/ax/benchmark/benchmark_problem.py +++ b/ax/benchmark/benchmark_problem.py @@ -5,10 +5,10 @@ # pyre-strict +from collections.abc import Mapping from dataclasses import dataclass, field from typing import Any, Optional, Union -import numpy as np import pandas as pd from ax.benchmark.benchmark_metric import BenchmarkMetric @@ -25,9 +25,8 @@ from ax.core.outcome_constraint import OutcomeConstraint from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace -from ax.core.types import ComparisonOp +from ax.core.types import ComparisonOp, TParamValue from ax.modelbridge.modelbridge_utils import extract_search_space_digest -from ax.service.utils.best_point_mixin import BestPointMixin from ax.utils.common.base import Base from botorch.test_functions.base import ( BaseTestProblem, @@ -86,55 +85,85 @@ class BenchmarkProblem(Base): search_space: SearchSpace = field(repr=False) runner: BenchmarkRunner = field(repr=False) - def get_oracle_experiment(self, experiment: Experiment) -> Experiment: + def get_oracle_experiment_from_params( + self, + dict_of_dict_of_params: Mapping[int, Mapping[str, [Mapping[str, TParamValue]]]], + ) -> Experiment: + """ + Get a new experiment with the same search space and optimization config + as those belonging to this problem, but with parameterizations evaluated + at oracle values. + + Args: + dict_of_dict_of_params: Keys are trial indices, values are Mappings + (e.g. dicts) that map arm names to parameterizations. + + Example: + >>> problem.get_oracle_experiment_from_params( + ... { + ... 0: { + ... "0_0": {"x0": 0.0, "x1": 0.0}, + ... "0_1": {"x0": 0.3, "x1": 0.4}, + ... }, + ... 1: {"1_0": {"x0": 0.0, "x1": 0.0}}, + ... } + ... ) + """ records = [] - new_experiment = Experiment( + experiment = Experiment( search_space=self.search_space, optimization_config=self.optimization_config ) - for trial_index, trial in experiment.trials.items(): - for arm in trial.arms: + if len(dict_of_dict_of_params) == 0: + return experiment + + for trial_index, dict_of_params in dict_of_dict_of_params.items(): + if len(dict_of_params) == 0: + raise ValueError( + "Can't create a trial with no arms. Each sublist in " + "list_of_list_of_params must have at least one element." + ) + for arm_name, params in dict_of_params.items(): for metric_name, metric_value in zip( self.runner.outcome_names, - self.runner.evaluate_oracle(parameters=arm.parameters), + self.runner.evaluate_oracle(parameters=params), ): records.append( { - "arm_name": arm.name, + "arm_name": arm_name, "metric_name": metric_name, - "mean": metric_value.item(), + "mean": metric_value, "sem": 0.0, "trial_index": trial_index, } ) - new_experiment.attach_trial( - parameterizations=[arm.parameters for arm in trial.arms], - arm_names=[arm.name for arm in trial.arms], + experiment.attach_trial( + parameterizations=list(dict_of_params.values()), + arm_names=list(dict_of_params.keys()), ) - for trial in new_experiment.trials.values(): + for trial in experiment.trials.values(): trial.mark_completed() data = Data(df=pd.DataFrame.from_records(records)) - new_experiment.attach_data(data=data, overwrite_existing_data=True) - return new_experiment + experiment.attach_data(data=data, overwrite_existing_data=True) + return experiment + + def get_oracle_experiment_from_experiment( + self, experiment: Experiment + ) -> Experiment: + return self.get_oracle_experiment_from_params( + dict_of_dict_of_params={ + trial.index: {arm.name: arm.parameters for arm in trial.arms} + for trial in experiment.trials.values() + } + ) @property def is_moo(self) -> bool: """Whether the problem is multi-objective.""" return isinstance(self.optimization_config, MultiObjectiveOptimizationConfig) - def get_opt_trace(self, experiment: Experiment) -> np.ndarray: - """Evaluate the optimization trace of a list of Trials.""" - oracle_experiment = self.get_oracle_experiment(experiment=experiment) - - return np.array( - BestPointMixin._get_trace( - experiment=oracle_experiment, - optimization_config=self.optimization_config, - ) - ) - def _get_constraints( num_constraints: int, observe_noise_sd: bool diff --git a/ax/benchmark/tests/test_benchmark_problem.py b/ax/benchmark/tests/test_benchmark_problem.py index 8d0bf0d94ab..2ebf123812e 100644 --- a/ax/benchmark/tests/test_benchmark_problem.py +++ b/ax/benchmark/tests/test_benchmark_problem.py @@ -6,6 +6,7 @@ # pyre-strict import math +from math import pi from typing import Optional, Union import torch @@ -22,6 +23,7 @@ from ax.core.types import ComparisonOp from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast +from ax.utils.testing.core_stubs import get_branin_experiment from botorch.test_functions.base import ConstrainedBaseTestProblem from botorch.test_functions.multi_fidelity import AugmentedBranin from botorch.test_functions.multi_objective import BraninCurrin, ConstrainedBraninCurrin @@ -322,3 +324,72 @@ def test_maximization_problem(self) -> None: test_problem_kwargs={}, ) self.assertFalse(test_problem.optimization_config.objective.minimize) + + def test_get_oracle_experiment_from_params(self) -> None: + problem = create_problem_from_botorch( + test_problem_class=Branin, + test_problem_kwargs={}, + num_trials=5, + ) + # first is near optimum + near_opt_params = {"x0": -pi, "x1": 12.275} + other_params = {"x0": 0.5, "x1": 0.5} + unbatched_experiment = problem.get_oracle_experiment_from_params( + {0: {"0": near_opt_params}, 1: {"1": other_params}} + ) + self.assertEqual(len(unbatched_experiment.trials), 2) + self.assertTrue( + all(t.status.is_completed for t in unbatched_experiment.trials.values()) + ) + self.assertTrue( + all(len(t.arms) == 1 for t in unbatched_experiment.trials.values()) + ) + df = unbatched_experiment.fetch_data().df + self.assertAlmostEqual(df["mean"].iloc[0], Branin._optimal_value, places=5) + + batched_experiment = problem.get_oracle_experiment_from_params( + {0: {"0_0": near_opt_params, "0_1": other_params}} + ) + self.assertEqual(len(batched_experiment.trials), 1) + self.assertEqual(len(batched_experiment.trials[0].arms), 2) + df = batched_experiment.fetch_data().df + self.assertAlmostEqual(df["mean"].iloc[0], Branin._optimal_value, places=5) + + # Test empty inputs + experiment = problem.get_oracle_experiment_from_params({}) + self.assertEqual(len(experiment.trials), 0) + + with self.assertRaisesRegex(ValueError, "trial with no arms"): + problem.get_oracle_experiment_from_params({0: {}}) + + def test_get_oracle_experiment_from_experiment(self) -> None: + problem = create_problem_from_botorch( + test_problem_class=Branin, + test_problem_kwargs={"negate": True}, + num_trials=5, + ) + + # empty experiment + empty_experiment = get_branin_experiment(with_trial=False) + oracle_experiment = problem.get_oracle_experiment_from_experiment( + empty_experiment + ) + self.assertEqual(oracle_experiment.search_space, problem.search_space) + self.assertEqual( + oracle_experiment.optimization_config, problem.optimization_config + ) + self.assertEqual(oracle_experiment.trials.keys(), set()) + + experiment = get_branin_experiment( + with_trial=True, + search_space=problem.search_space, + with_status_quo=False, + ) + oracle_experiment = problem.get_oracle_experiment_from_experiment( + experiment=experiment + ) + self.assertEqual(oracle_experiment.search_space, problem.search_space) + self.assertEqual( + oracle_experiment.optimization_config, problem.optimization_config + ) + self.assertEqual(oracle_experiment.trials.keys(), experiment.trials.keys())