Remove mocks from TorchModelBridgeTest.test_best_point (#3228)
Summary:
Pull Request resolved: #3228

Context: I'm removing mocks (that do not use `wraps`) from Ax unit tests. There were many in this function. However, I left a mock of `gen` in place, since this is a ModelBridge test that checks whether the correct arguments are passed to and from `Model`, where the generation happens.
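
The distinction, as a minimal sketch using a toy class rather than Ax's `Model` (all names here are illustrative, not from this diff): a plain `return_value` mock replaces the real method entirely, while a `wraps=` mock still runs the real implementation and only records calls.

    from unittest import mock

    class Model:
        def gen(self, n: int) -> list[int]:
            return list(range(n))

    model = Model()

    # Plain stub (the kind being removed): the real gen() never runs, so the
    # canned return value can silently drift from real behavior.
    with mock.patch.object(model, "gen", return_value=[0]):
        assert model.gen(5) == [0]  # not what the real method returns

    # wraps= keeps the real implementation; the mock only records the call.
    with mock.patch.object(model, "gen", wraps=model.gen) as spy:
        assert model.gen(5) == [0, 1, 2, 3, 4]
        spy.assert_called_once_with(5)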

Reviewed By: saitcakmak

Differential Revision: D68041762

fbshipit-source-id: 44fc70552f01c1ddc2a50798e68abacae24cb9d9
esantorella authored and facebook-github-bot committed Jan 13, 2025
1 parent bdceb7f commit 9c0e511
Showing 1 changed file with 32 additions and 48 deletions.
80 changes: 32 additions & 48 deletions ax/modelbridge/tests/test_torch_modelbridge.py
@@ -53,7 +53,7 @@
     get_search_space_for_range_values,
 )
 from ax.utils.testing.mock import mock_botorch_optimize
-from ax.utils.testing.modeling_stubs import get_observation1, transform_1, transform_2
+from ax.utils.testing.modeling_stubs import transform_1, transform_2
 from botorch.utils.datasets import (
     ContextualDataset,
     MultiTaskDataset,
@@ -428,44 +428,13 @@ def test_evaluate_acquisition_function(self) -> None:
             )
         )
 
-    @mock.patch(
-        f"{ModelBridge.__module__}.unwrap_observation_data",
-        autospec=True,
-        return_value=(2, 2),
-    )
-    @mock.patch(
-        f"{ModelBridge.__module__}.gen_arms",
-        autospec=True,
-        return_value=([Arm(parameters={})], {}),
-    )
-    @mock.patch(
-        f"{ModelBridge.__module__}.ModelBridge.predict",
-        autospec=True,
-        return_value=({"m": [1.0]}, {"m": {"m": [2.0]}}),
-    )
-    @mock.patch(f"{TorchModelBridge.__module__}.TorchModelBridge._fit", autospec=True)
-    @mock.patch(
-        f"{TorchModel.__module__}.TorchModel.gen",
-        return_value=TorchGenResults(
-            points=torch.tensor([[1]]),
-            weights=torch.tensor([1.0]),
-        ),
-        autospec=True,
-    )
-    def test_best_point(
-        self,
-        _mock_gen,
-        _mock_fit,
-        _mock_predict,
-        _mock_gen_arms,
-        _mock_unwrap,
-    ) -> None:
-        exp = Experiment(search_space=get_search_space_for_range_value(), name="test")
+    def test_best_point(self) -> None:
+        search_space = get_search_space_for_range_value()
+        exp = Experiment(search_space=search_space, name="test")
         oc = OptimizationConfig(
             objective=Objective(metric=Metric("a"), minimize=False),
             outcome_constraints=[],
         )
-        search_space = get_search_space_for_range_value()
         modelbridge = TorchModelBridge(
             search_space=search_space,
             model=TorchModel(),
@@ -484,25 +453,40 @@ def test_best_point(
         modelbridge.parameters = list(search_space.parameters.keys())
         modelbridge.outcomes = ["a"]
 
+        mean = 1.0
+        cov = 2.0
+        predict_return_value = ({"m": [mean]}, {"m": {"m": [cov]}})
+        best_point_value = 25
+        gen_return_value = TorchGenResults(
+            points=torch.tensor([[1.0]]), weights=torch.tensor([1.0])
+        )
         with mock.patch(
             f"{TorchModel.__module__}.TorchModel.best_point",
-            return_value=torch.tensor([1.0]),
+            return_value=torch.tensor([best_point_value]),
             autospec=True,
-        ):
-            run = modelbridge.gen(n=1, optimization_config=oc)
-            arm, predictions = none_throws(run.best_arm_predictions)
-            model_arm, model_predictions = none_throws(modelbridge.model_best_point())
-            predictions = none_throws(predictions)
-            model_predictions = none_throws(model_predictions)
-            self.assertEqual(arm.parameters, {})
-            self.assertEqual(predictions[0], {"m": 1.0})
-            self.assertEqual(predictions[1], {"m": {"m": 2.0}})
-            self.assertEqual(model_predictions[0], {"m": 1.0})
-            self.assertEqual(model_predictions[1], {"m": {"m": 2.0}})
+        ), mock.patch.object(modelbridge, "predict", return_value=predict_return_value):
+            with mock.patch.object(
+                modelbridge.model, "gen", return_value=gen_return_value
+            ):
+                run = modelbridge.gen(n=1, optimization_config=oc)
+
+            _, model_predictions = none_throws(modelbridge.model_best_point())
+
+            arm, predictions = none_throws(run.best_arm_predictions)
+            predictions = none_throws(predictions)
+            model_predictions = none_throws(model_predictions)
+            # The transforms add one and square, and need to be reversed
+            self.assertEqual(arm.parameters, {"x": (best_point_value**0.5) - 1})
+            # Gets clamped to the search space
+            self.assertEqual(run.arms[0].parameters, {"x": 3.0})
+            self.assertEqual(predictions[0], {"m": mean})
+            self.assertEqual(predictions[1], {"m": {"m": cov}})
+            self.assertEqual(model_predictions[0], {"m": mean})
+            self.assertEqual(model_predictions[1], {"m": {"m": cov}})
 
         # test optimization config validation - raise error when
         # ScalarizedOutcomeConstraint contains a metric that is not in the outcomes
-        with self.assertRaises(ValueError):
+        with self.assertRaisesRegex(ValueError, "is a relative constraint."):
             modelbridge.gen(
                 n=1,
                 optimization_config=OptimizationConfig(
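
For readers checking the arithmetic in the new assertions: the stub transforms add one and then square, so reversing them means taking a square root and subtracting one. A minimal sketch of that round trip in plain Python (the lower bound of 3 on "x" is inferred from the clamping assertion, not shown in this diff):

    def untransform(y: float) -> float:
        # Reverse "square", then reverse "add one".
        return y**0.5 - 1

    assert untransform(25) == 4.0  # best_point_value -> best arm's "x"
    # The generated point 1.0 reverses to 0.0, which falls below the
    # (inferred) lower bound of 3 on "x", so the arm is clamped to x = 3.0.
    assert max(untransform(1.0), 3.0) == 3.0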
