Use model-recommended best points in ax/service/utils/best_point with discrete models (#3204)

Summary:
Pull Request resolved: #3204

Context: `GenResults` has a field `best_observation_features` that ends up being used by Ax's best-point functions (`ax/service/utils/best_point.py`). `ThompsonSampler` has an opinion about the best arm, but doesn't pass that information through to `DiscreteteModelBridge`, and `DiscreteModelBridge` doesn't look for a model-recommended best point even when one has been passed.

This unblocks benchmarking the bandits generation strategy (GS), which needs a best-point recommendation in order to compute inference regret.

This PR:
* Has `DiscreteModelBridge._gen` check `gen_metadata` for a "best_x" and use it to construct `best_observation_features` when present.
* Adds a type annotation to `ThompsonSampler.X`, with `none_throws` guards where it is used.
* Has `ThompsonSampler.gen` return a "best_x" in the `gen_metadata` field.
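
Taken together, the data flow looks roughly like this (a simplified, self-contained sketch of the new behavior; the parameter names and values are made up for illustration):

```python
from ax.core.observation import ObservationFeatures

# Illustrative values: the arm ThompsonSampler.gen now reports as best,
# expressed as a list of parameter values aligned with `parameters` below.
gen_metadata = {"best_x": [0.0, 2.0, 1.0]}
parameters = ["x", "y", "z"]

# DiscreteModelBridge._gen zips the parameter names with "best_x" to build
# the best-point recommendation; it stays None when the model has no opinion.
best_observation_features = (
    ObservationFeatures(parameters=dict(zip(parameters, gen_metadata["best_x"])))
    if "best_x" in gen_metadata
    else None
)
print(best_observation_features)  # parameters {'x': 0.0, 'y': 2.0, 'z': 1.0}
```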

Reviewed By: jelena-markovic

Differential Revision: D67532559

fbshipit-source-id: d80d78ca8d9c25d9732021096c61c466c26f1355
esantorella authored and facebook-github-bot committed Dec 20, 2024
1 parent 1e16d48 commit 2d375ab
Showing 4 changed files with 37 additions and 21 deletions.
ax/modelbridge/discrete.py (11 additions, 8 deletions)
```diff
@@ -170,19 +170,22 @@ def _gen(
             pending_observations=pending_array,
             model_gen_options=model_gen_options,
         )
-        observation_features = []
-        for x in X:
-            observation_features.append(
-                ObservationFeatures(
-                    parameters={p: x[i] for i, p in enumerate(self.parameters)}
-                )
+        observation_features = [
+            ObservationFeatures(parameters=dict(zip(self.parameters, x))) for x in X
+        ]
 
+        if "best_x" in gen_metadata:
+            best_observation_features = ObservationFeatures(
+                parameters=dict(zip(self.parameters, gen_metadata["best_x"]))
+            )
-        # TODO[drfreund, bletham]: implement best_point identification and
-        # return best_point instead of None
+        else:
+            best_observation_features = None
 
         return GenResults(
             observation_features=observation_features,
             weights=w,
             gen_metadata=gen_metadata,
+            best_observation_features=best_observation_features,
         )
 
     def _cross_validate(
```
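Downstream, a best-point utility can then prefer the model's own recommendation whenever the bridge supplies one. A hypothetical sketch of that consumer side (this is not the actual `best_point.py` code; the helper name is made up, and `GenResults` is assumed to live in `ax.modelbridge.base` as in this revision):

```python
from ax.core.types import TParameterization
from ax.modelbridge.base import GenResults

def recommended_parameters(gen_results: GenResults) -> TParameterization | None:
    # Hypothetical helper: prefer the model-recommended best point when the
    # bridge populated it; callers fall back to their own heuristic
    # (e.g. the highest-weighted generated arm) when it is None.
    best = gen_results.best_observation_features
    return best.parameters if best is not None else None
```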
ax/modelbridge/tests/test_discrete_modelbridge.py (10 additions, 5 deletions)
```diff
@@ -38,8 +38,6 @@ def setUp(self) -> None:
         ]
         parameter_constraints = []
 
-        # pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
-        # `List[Union[ChoiceParameter, FixedParameter]]`.
         self.search_space = SearchSpace(self.parameters, parameter_constraints)
 
         self.observation_features = [
@@ -149,7 +147,12 @@ def test_gen(self, mock_init: Mock) -> None:
             ma._validate_gen_inputs(n=-1)
         # Test rest of gen.
         model = mock.MagicMock(DiscreteModel, autospec=True, instance=True)
-        model.gen.return_value = ([[0.0, 2.0, 3.0], [1.0, 1.0, 3.0]], [1.0, 2.0], {})
+        best_x = [0.0, 2.0, 1.0]
+        model.gen.return_value = (
+            [[0.0, 2.0, 3.0], [1.0, 1.0, 3.0]],
+            [1.0, 2.0],
+            {"best_x": best_x},
+        )
         ma.model = model
         ma.parameters = ["x", "y", "z"]
         ma.outcomes = ["a", "b"]
@@ -190,10 +193,12 @@
             {"x": 1.0, "y": 1.0, "z": 3.0},
         )
         self.assertEqual(gen_results.weights, [1.0, 2.0])
+        self.assertEqual(
+            gen_results.best_observation_features,
+            ObservationFeatures(parameters=dict(zip(ma.parameters, best_x))),
+        )
 
         # Test with no constraints, no fixed feature, no pending observations
-        # pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
-        # `List[Union[ChoiceParameter, FixedParameter]]`.
         search_space = SearchSpace(self.parameters[:2])
         optimization_config.outcome_constraints = []
         ma.parameters = ["x", "y"]
```
ax/models/discrete/thompson.py (15 additions, 8 deletions)
```diff
@@ -17,6 +17,7 @@
 from ax.models.discrete_base import DiscreteModel
 from ax.models.types import TConfig
 from ax.utils.common.docutils import copy_doc
+from pyre_extensions import none_throws
 
 
 class ThompsonSampler(DiscreteModel):
@@ -46,8 +47,7 @@ def __init__(
         self.min_weight = min_weight
         self.uniform_weights = uniform_weights
 
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.X = None
+        self.X: list[TParamValueList] | None = None
         # pyre-fixme[4]: Attribute must be annotated.
         self.Ys = None
         # pyre-fixme[4]: Attribute must be annotated.
@@ -69,7 +69,7 @@ def fit(
             Ys=Ys, Yvars=Yvars, outcome_names=outcome_names
         )
         self.X_to_Ys_and_Yvars = self._fit_X_to_Ys_and_Yvars(
-            X=self.X, Ys=self.Ys, Yvars=self.Yvars
+            X=none_throws(self.X), Ys=self.Ys, Yvars=self.Yvars
         )
 
     @copy_doc(DiscreteModel.gen)
@@ -86,7 +86,7 @@ def gen(
         if objective_weights is None:
             raise ValueError("ThompsonSampler requires objective weights.")
 
-        arms = self.X
+        arms = none_throws(self.X)
         k = len(arms)
 
         weights = self._generate_weights(
@@ -120,7 +120,14 @@ def gen(
             top_weights = [
                 (x * len(top_weights)) / sum(top_weights) for x in top_weights
             ]
-        return top_arms, top_weights, {"arms_to_weights": list(zip(arms, weights))}
+        return (
+            top_arms,
+            top_weights,
+            {
+                "arms_to_weights": list(zip(arms, weights)),
+                "best_x": weighted_arms[0][2],
+            },
+        )
 
     @copy_doc(DiscreteModel.predict)
     def predict(self, X: list[TParamValueList]) -> tuple[npt.NDArray, npt.NDArray]:
@@ -168,14 +175,14 @@ def _generate_weights(
         num_valid_samples = samples.shape[1]
 
         winner_indices = np.argmax(samples, axis=0)  # (num_samples,)
-        winner_counts = np.zeros(len(self.X))  # (k,)
+        winner_counts = np.zeros(len(none_throws(self.X)))  # (k,)
         for index in winner_indices:
             winner_counts[index] += 1
         weights = winner_counts / winner_counts.sum()
         return weights.tolist()
 
     def _generate_samples_per_metric(self, num_samples: int) -> npt.NDArray:
-        k = len(self.X)
+        k = len(none_throws(self.X))
         samples_per_metric = np.zeros(
             (k, num_samples, len(self.Ys))
         )  # k x num_samples x m
@@ -194,7 +201,7 @@ def _produce_samples(
         objective_weights: npt.NDArray,
         outcome_constraints: tuple[npt.NDArray, npt.NDArray] | None,
     ) -> tuple[npt.NDArray, float]:
-        k = len(self.X)
+        k = len(none_throws(self.X))
         samples_per_metric = self._generate_samples_per_metric(num_samples=num_samples)
 
         any_violation = np.zeros((k, num_samples), dtype=bool)  # (k x num_samples)
```
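`none_throws` (from the `pyre_extensions` package) is a runtime assertion that narrows an `Optional[T]` to `T`, raising if the value is actually `None`; a minimal illustration:

```python
from pyre_extensions import none_throws

x: list[float] | None = [0.0, 2.0, 1.0]
# Type-checks as list[float] past this point; none_throws(None) would raise.
print(len(none_throws(x)))  # 3
```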
ax/models/tests/test_thompson.py (1 addition, 0 deletions)
```diff
@@ -59,6 +59,7 @@ def test_ThompsonSampler(self) -> None:
         ):
             self.assertAlmostEqual(weight, expected_weight, 1)
         self.assertEqual(len(gen_metadata["arms_to_weights"]), 4)
+        self.assertEqual(gen_metadata["best_x"], arms[0])
 
     def test_ThompsonSamplerValidation(self) -> None:
         generator = ThompsonSampler(min_weight=0.01)
```
