Skip to content

Commit

Permalink
Allow for passing model_gen_kwargs in benchmarks
Browse files Browse the repository at this point in the history
Summary: This makes it easier to test changing defaults.

Reviewed By: sdaulton

Differential Revision: D55844014
  • Loading branch information
esantorella authored and facebook-github-bot committed Apr 8, 2024
1 parent 900066f commit a449ab6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
46 changes: 44 additions & 2 deletions ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

# pyre-strict

from typing import Dict, Optional, Type, Union
from typing import Any, Dict, Optional, Type, Union

from ax.benchmark.benchmark_method import (
BenchmarkMethod,
get_benchmark_scheduler_options,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.model import SurrogateSpec
from ax.service.scheduler import SchedulerOptions
Expand Down Expand Up @@ -46,7 +47,47 @@ def get_sobol_botorch_modular_acquisition(
scheduler_options: Optional[SchedulerOptions] = None,
name: Optional[str] = None,
num_sobol_trials: int = 5,
model_gen_kwargs: Optional[Dict[str, Any]] = None,
) -> BenchmarkMethod:
"""Get a `BenchmarkMethod` that uses Sobol followed by MBM.
Args:
model_cls: BoTorch model class, e.g. SingleTaskGP
acquisition_cls: Acquisition function class, e.g.
`qLogNoisyExpectedImprovement`.
distribute_replications: Whether to use multiple machines
scheduler_options: Passed as-is to scheduler. Default:
`get_benchmark_scheduler_options()`.
name: Name that will be attached to the `GenerationStrategy`.
num_sobol_trials: Number of Sobol trials; if the scheduler_options
specify to use `BatchTrial`s, then this refers to the number of
`BatchTrial`s.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
to the BoTorch `Model`.
Example:
>>> # A simple example
>>> from ax.benchmark.methods.sobol_botorch_modular import (
... get_sobol_botorch_modular_acquisition
... )
>>> from ax.benchmark.benchmark_method import get_benchmark_scheduler_options
>>>
>>> method = get_sobol_botorch_modular_acquisition(
... model_cls=SingleTaskGP,
... acquisition_cls=qLogNoisyExpectedImprovement,
... distribute_replications=False,
... )
>>> # Run trials in batches of 5
>>> batch_method = get_sobol_botorch_modular_acquisition(
... model_cls=SingleTaskGP,
... acquisition_cls=qLogNoisyExpectedImprovement,
... distribute_replications=False,
... scheduler_options=get_benchmark_scheduler_options(
... sequential=False, batch_size=5,
... ),
... num_sobol_trials=1,
... )
"""
model_kwargs: Dict[
str, Union[Type[AcquisitionFunction], Dict[str, SurrogateSpec], bool]
] = {
Expand Down Expand Up @@ -82,6 +123,7 @@ def get_sobol_botorch_modular_acquisition(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs=model_kwargs,
model_gen_kwargs=model_gen_kwargs,
),
],
)
Expand Down
36 changes: 36 additions & 0 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pyre-strict

import tempfile
from unittest.mock import patch

import numpy as np
from ax.benchmark.benchmark import (
Expand Down Expand Up @@ -51,10 +52,45 @@
)
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
from botorch.optim.optimize import optimize_acqf
from botorch.test_functions.synthetic import Branin


class TestBenchmark(TestCase):
@fast_botorch_optimize
def test_batch(self) -> None:
batch_size = 5

problem = get_problem("ackley4", num_trials=2)
batch_options = get_benchmark_scheduler_options(batch_size=batch_size)
for sequential in [False, True]:
with self.subTest(sequential=sequential):
batch_method_joint = get_sobol_botorch_modular_acquisition(
model_cls=SingleTaskGP,
acquisition_cls=qLogNoisyExpectedImprovement,
scheduler_options=batch_options,
distribute_replications=False,
model_gen_kwargs={
"model_gen_options": {
"optimizer_kwargs": {"sequential": sequential}
}
},
num_sobol_trials=1,
)
# this is generating more calls to optimize_acqf than expected
with patch(
"ax.models.torch.botorch_modular.acquisition.optimize_acqf",
wraps=optimize_acqf,
) as mock_optimize_acqf:
benchmark_one_method_problem(
problem=problem, method=batch_method_joint, seeds=[0]
)
mock_optimize_acqf.assert_called_once()
self.assertEqual(
mock_optimize_acqf.call_args.kwargs["sequential"], sequential
)
self.assertEqual(mock_optimize_acqf.call_args.kwargs["q"], batch_size)

def test_storage(self) -> None:
problem = get_single_objective_benchmark_problem()
res = benchmark_replication(
Expand Down

0 comments on commit a449ab6

Please sign in to comment.