From 2d375ab9b5a1f65c1acd4ea0d9c950b479519eef Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Fri, 20 Dec 2024 14:10:24 -0800
Subject: [PATCH] Use model-recommended best points in
 ax/service/utils/best_point with discrete models (#3204)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/3204

Context: `GenResults` has a field `best_observation_features` that winds up
being used by Ax's best-point functions (ax/service/utils/best_point.py).
`ThompsonSampler` has an opinion about the best arm, but doesn't pass that
information through to `DiscreteModelBridge`, and `DiscreteModelBridge`
doesn't look for a best point even when one has been passed.

This unblocks benchmarking the bandits GS, which needs a best-point
recommendation for computing inference regret.

This PR:
* Has `DiscreteModelBridge._gen` check `gen_metadata` for a "best_x" and, if
  one is present, use it to construct `best_observation_features` (see the
  sketch after this list).
* Adds a type annotation to `ThompsonSampler.X`.
* Has `ThompsonSampler.gen` return a "best_x" entry in `gen_metadata`.
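Illustration (editorial addition, not part of the diff): a minimal sketch of
the resulting flow. The parameter names and `best_x` values below are
hypothetical; the `if "best_x" in gen_metadata` logic mirrors the new code in
`DiscreteModelBridge._gen`.

```python
# Illustrative sketch; values below are hypothetical.
from ax.core.observation import ObservationFeatures

# ThompsonSampler.gen now includes its recommended arm in gen_metadata:
#   top_arms, top_weights, gen_metadata = model.gen(...)
gen_metadata = {"best_x": [0.0, 2.0, 1.0]}  # as returned by ThompsonSampler.gen
parameters = ["x", "y", "z"]  # DiscreteModelBridge.parameters

# DiscreteModelBridge._gen converts "best_x" into best_observation_features,
# which GenResults then exposes to ax/service/utils/best_point.py:
if "best_x" in gen_metadata:
    best_observation_features = ObservationFeatures(
        parameters=dict(zip(parameters, gen_metadata["best_x"]))
    )
else:
    best_observation_features = None

assert best_observation_features is not None
assert best_observation_features.parameters == {"x": 0.0, "y": 2.0, "z": 1.0}
```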
Reviewed By: jelena-markovic

Differential Revision: D67532559

fbshipit-source-id: d80d78ca8d9c25d9732021096c61c466c26f1355
---
 ax/modelbridge/discrete.py             | 19 ++++++++-------
 .../tests/test_discrete_modelbridge.py | 15 ++++++++----
 ax/models/discrete/thompson.py         | 23 ++++++++++++-------
 ax/models/tests/test_thompson.py       |  1 +
 4 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/ax/modelbridge/discrete.py b/ax/modelbridge/discrete.py
index e895d8378a1..761cd54da59 100644
--- a/ax/modelbridge/discrete.py
+++ b/ax/modelbridge/discrete.py
@@ -170,19 +170,22 @@ def _gen(
             pending_observations=pending_array,
             model_gen_options=model_gen_options,
         )
-        observation_features = []
-        for x in X:
-            observation_features.append(
-                ObservationFeatures(
-                    parameters={p: x[i] for i, p in enumerate(self.parameters)}
-                )
+        observation_features = [
+            ObservationFeatures(parameters=dict(zip(self.parameters, x))) for x in X
+        ]
+
+        if "best_x" in gen_metadata:
+            best_observation_features = ObservationFeatures(
+                parameters=dict(zip(self.parameters, gen_metadata["best_x"]))
             )
-        # TODO[drfreund, bletham]: implement best_point identification and
-        # return best_point instead of None
+        else:
+            best_observation_features = None
+
         return GenResults(
             observation_features=observation_features,
             weights=w,
             gen_metadata=gen_metadata,
+            best_observation_features=best_observation_features,
         )
 
     def _cross_validate(
diff --git a/ax/modelbridge/tests/test_discrete_modelbridge.py b/ax/modelbridge/tests/test_discrete_modelbridge.py
index 0e528b1b9a4..d89db4a504d 100644
--- a/ax/modelbridge/tests/test_discrete_modelbridge.py
+++ b/ax/modelbridge/tests/test_discrete_modelbridge.py
@@ -38,8 +38,6 @@ def setUp(self) -> None:
         ]
         parameter_constraints = []
 
-        # pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
-        #  `List[Union[ChoiceParameter, FixedParameter]]`.
         self.search_space = SearchSpace(self.parameters, parameter_constraints)
 
         self.observation_features = [
@@ -149,7 +147,12 @@ def test_gen(self, mock_init: Mock) -> None:
             ma._validate_gen_inputs(n=-1)
         # Test rest of gen.
         model = mock.MagicMock(DiscreteModel, autospec=True, instance=True)
-        model.gen.return_value = ([[0.0, 2.0, 3.0], [1.0, 1.0, 3.0]], [1.0, 2.0], {})
+        best_x = [0.0, 2.0, 1.0]
+        model.gen.return_value = (
+            [[0.0, 2.0, 3.0], [1.0, 1.0, 3.0]],
+            [1.0, 2.0],
+            {"best_x": best_x},
+        )
         ma.model = model
         ma.parameters = ["x", "y", "z"]
         ma.outcomes = ["a", "b"]
@@ -190,10 +193,12 @@ def test_gen(self, mock_init: Mock) -> None:
             {"x": 1.0, "y": 1.0, "z": 3.0},
         )
         self.assertEqual(gen_results.weights, [1.0, 2.0])
+        self.assertEqual(
+            gen_results.best_observation_features,
+            ObservationFeatures(parameters=dict(zip(ma.parameters, best_x))),
+        )
 
         # Test with no constraints, no fixed feature, no pending observations
-        # pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
-        #  `List[Union[ChoiceParameter, FixedParameter]]`.
         search_space = SearchSpace(self.parameters[:2])
         optimization_config.outcome_constraints = []
         ma.parameters = ["x", "y"]
diff --git a/ax/models/discrete/thompson.py b/ax/models/discrete/thompson.py
index 27e7e044737..f267d6665bb 100644
--- a/ax/models/discrete/thompson.py
+++ b/ax/models/discrete/thompson.py
@@ -17,6 +17,7 @@
 from ax.models.discrete_base import DiscreteModel
 from ax.models.types import TConfig
 from ax.utils.common.docutils import copy_doc
+from pyre_extensions import none_throws
 
 
 class ThompsonSampler(DiscreteModel):
@@ -46,8 +47,7 @@ def __init__(
         self.min_weight = min_weight
         self.uniform_weights = uniform_weights
 
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.X = None
+        self.X: list[TParamValueList] | None = None
         # pyre-fixme[4]: Attribute must be annotated.
         self.Ys = None
         # pyre-fixme[4]: Attribute must be annotated.
@@ -69,7 +69,7 @@ def fit(
             Ys=Ys, Yvars=Yvars, outcome_names=outcome_names
         )
         self.X_to_Ys_and_Yvars = self._fit_X_to_Ys_and_Yvars(
-            X=self.X, Ys=self.Ys, Yvars=self.Yvars
+            X=none_throws(self.X), Ys=self.Ys, Yvars=self.Yvars
         )
 
     @copy_doc(DiscreteModel.gen)
@@ -86,7 +86,7 @@ def gen(
         if objective_weights is None:
             raise ValueError("ThompsonSampler requires objective weights.")
 
-        arms = self.X
+        arms = none_throws(self.X)
         k = len(arms)
 
         weights = self._generate_weights(
@@ -120,7 +120,14 @@ def gen(
         top_weights = [
             (x * len(top_weights)) / sum(top_weights) for x in top_weights
         ]
-        return top_arms, top_weights, {"arms_to_weights": list(zip(arms, weights))}
+        return (
+            top_arms,
+            top_weights,
+            {
+                "arms_to_weights": list(zip(arms, weights)),
+                "best_x": weighted_arms[0][2],
+            },
+        )
 
     @copy_doc(DiscreteModel.predict)
     def predict(self, X: list[TParamValueList]) -> tuple[npt.NDArray, npt.NDArray]:
@@ -168,14 +175,14 @@ def _generate_weights(
             num_valid_samples = samples.shape[1]
 
         winner_indices = np.argmax(samples, axis=0)  # (num_samples,)
-        winner_counts = np.zeros(len(self.X))  # (k,)
+        winner_counts = np.zeros(len(none_throws(self.X)))  # (k,)
         for index in winner_indices:
             winner_counts[index] += 1
         weights = winner_counts / winner_counts.sum()
         return weights.tolist()
 
     def _generate_samples_per_metric(self, num_samples: int) -> npt.NDArray:
-        k = len(self.X)
+        k = len(none_throws(self.X))
         samples_per_metric = np.zeros(
             (k, num_samples, len(self.Ys))
         )  # k x num_samples x m
@@ -194,7 +201,7 @@ def _produce_samples(
         objective_weights: npt.NDArray,
         outcome_constraints: tuple[npt.NDArray, npt.NDArray] | None,
     ) -> tuple[npt.NDArray, float]:
-        k = len(self.X)
+        k = len(none_throws(self.X))
         samples_per_metric = self._generate_samples_per_metric(num_samples=num_samples)
 
         any_violation = np.zeros((k, num_samples), dtype=bool)  # (k x num_samples)
diff --git a/ax/models/tests/test_thompson.py b/ax/models/tests/test_thompson.py
index c100251ee96..8a05fcb19f7 100644
--- a/ax/models/tests/test_thompson.py
+++ b/ax/models/tests/test_thompson.py
@@ -59,6 +59,7 @@ def test_ThompsonSampler(self) -> None:
         ):
             self.assertAlmostEqual(weight, expected_weight, 1)
         self.assertEqual(len(gen_metadata["arms_to_weights"]), 4)
+        self.assertEqual(gen_metadata["best_x"], arms[0])
 
     def test_ThompsonSamplerValidation(self) -> None:
         generator = ThompsonSampler(min_weight=0.01)
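
Postscript (editorial addition, not part of the patch): a sketch of the tuple
`ThompsonSampler.gen` now returns, with hypothetical arms and weights. The
final assertion is the invariant the new test relies on: "best_x" is
`weighted_arms[0][2]`, the first entry of the model's weight-sorted arm list,
i.e. the arm with the highest sampled weight.

```python
# Hypothetical return value of ThompsonSampler.gen after this patch.
top_arms, top_weights, gen_metadata = (
    [[4.0], [3.0]],  # the n arms chosen for the next trial
    [1.2, 0.8],      # weights, renormalized so they average to 1
    {
        # every arm with its Thompson-sampling weight
        "arms_to_weights": [
            ([1.0], 0.01),
            ([2.0], 0.04),
            ([3.0], 0.25),
            ([4.0], 0.70),
        ],
        # new: the arm the model itself considers best
        "best_x": [4.0],
    },
)

# "best_x" agrees with the arm that has the highest sampled weight:
best_by_weight = max(gen_metadata["arms_to_weights"], key=lambda aw: aw[1])[0]
assert best_by_weight == gen_metadata["best_x"]
```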