Skip to content

Commit

Permalink
Rename fast_optimize to mock_optimize (#2599)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2599

Updates the name to clarify that this is just a mock and does not magically speed up optimization without consequences.

Reviewed By: sdaulton, Balandat

Differential Revision: D65146794

fbshipit-source-id: 8e0ed887bafe8191f305ddbadae3fcdaf0a7e969
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 29, 2024
1 parent ac1e8e2 commit c671077
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 22 deletions.
4 changes: 2 additions & 2 deletions botorch/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
'botorch/'.
"""

from botorch.test_utils.mock import fast_optimize
from botorch.test_utils.mock import mock_optimize

__all__ = ["fast_optimize"]
__all__ = ["mock_optimize"]
16 changes: 8 additions & 8 deletions botorch/test_utils/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@


@contextmanager
def fast_optimize_context_manager(
def mock_optimize_context_manager(
force: bool = False,
) -> Generator[None, None, None]:
"""A context manager to force botorch to speed up optimization. Currently, the
primary tactic is to force the underlying scipy methods to stop after just one
iteration.
"""A context manager that uses mocks to speed up optimization for testing.
Currently, the primary tactic is to force the underlying scipy methods to stop
after just one iteration.
force: If True will not raise an AssertionError if no mocks are called.
USE RESPONSIBLY.
Expand Down Expand Up @@ -116,17 +116,17 @@ def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
):
raise AssertionError(
"No mocks were called in the context manager. Please remove unused "
"fast_optimize_context_manager()."
"mock_optimize_context_manager()."
)


def fast_optimize(f: Callable) -> Callable:
"""Wraps f in the fast_botorch_optimize_context_manager for use as a decorator."""
def mock_optimize(f: Callable) -> Callable:
"""Wraps `f` in `mock_optimize_context_manager` for use as a decorator."""

@wraps(f)
# pyre-fixme[3]: Return type must be annotated.
def inner(*args: Any, **kwargs: Any):
with fast_optimize_context_manager():
with mock_optimize_context_manager():
return f(*args, **kwargs)

return inner
4 changes: 2 additions & 2 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.test_utils.mock import fast_optimize
from botorch.test_utils.mock import mock_optimize
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
Expand Down Expand Up @@ -1841,7 +1841,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
},
)

@fast_optimize
@mock_optimize
def test_constructors_can_instantiate(self) -> None:
for key, (classes, input_constructor_kwargs) in self.cases.items():
with self.subTest(
Expand Down
20 changes: 10 additions & 10 deletions test/test_utils/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
get_nearest_neighbors,
optimize_acqf_mixed_alternating,
)
from botorch.test_utils.mock import fast_optimize, fast_optimize_context_manager
from botorch.test_utils.mock import mock_optimize, mock_optimize_context_manager
from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction


Expand All @@ -37,14 +37,14 @@ def __call__(self, X):


class TestMock(BotorchTestCase):
def test_fast_optimize_context_manager(self) -> None:
def test_mock_optimize_context_manager(self) -> None:
with self.subTest("gen_candidates_scipy"):
with fast_optimize_context_manager():
with mock_optimize_context_manager():
cand, value = gen_candidates_scipy(
initial_conditions=torch.tensor([[0.0]]),
acquisition_function=SinAcqusitionFunction(),
)
# When not using `fast_optimize`, the value is 1.0. With it, the value is
# When not using `mock_optimize`, the value is 1.0. With it, the value is
# around 0.84
self.assertLess(value.item(), 0.99)

Expand All @@ -54,14 +54,14 @@ def test_fast_optimize_context_manager(self) -> None:
def closure():
return torch.sin(x), [torch.cos(x)]

with fast_optimize_context_manager():
with mock_optimize_context_manager():
result = scipy_minimize(closure=closure, parameters={"x": x})
self.assertEqual(
result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT"
)

with self.subTest("optimize_acqf"):
with fast_optimize_context_manager():
with mock_optimize_context_manager():
cand, value = optimize_acqf(
acq_function=SinAcqusitionFunction(),
bounds=torch.tensor([[-2.0], [2.0]]),
Expand All @@ -72,7 +72,7 @@ def closure():
self.assertLess(value.item(), 0.99)

with self.subTest("gen_batch_initial_conditions"):
with fast_optimize_context_manager(), patch(
with mock_optimize_context_manager(), patch(
"botorch.optim.initializers.initialize_q_batch",
wraps=initialize_q_batch,
) as mock_init_q_batch:
Expand All @@ -85,7 +85,7 @@ def closure():
)
self.assertEqual(mock_init_q_batch.call_args[1]["n"], 2)

def test_fast_optimize_mixed_alternating(self) -> None:
def test_mock_optimize_mixed_alternating(self) -> None:
with patch(
"botorch.optim.optimize_mixed.discrete_step",
wraps=discrete_step,
Expand All @@ -110,7 +110,7 @@ def test_fast_optimize_mixed_alternating(self) -> None:
# `mock_discrete`, which should total to 1.
mock_neighbors.assert_called_once()

@fast_optimize
@mock_optimize
def test_decorator(self) -> None:
model = SingleTaskGP(
train_X=torch.tensor([[0.0]], dtype=torch.double),
Expand All @@ -137,5 +137,5 @@ def test_decorator(self) -> None:

def test_raises_when_unused(self) -> None:
with self.assertRaisesRegex(AssertionError, "No mocks were called"):
with fast_optimize_context_manager():
with mock_optimize_context_manager():
pass

0 comments on commit c671077

Please sign in to comment.