Skip to content

Commit

Permalink
Add cache_root option for qNEI in get_acquisition_function (#1608)
Browse files Browse the repository at this point in the history
Summary:
<!--
Thank you for sending the PR! We appreciate you spending the time to make BoTorch better.

Help us understand your motivation by explaining why you decided to make this change.

You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md
-->

## Motivation

As discussed in #1604, this PR adds the possibility to setup qNEI with `cache_root=False` via the `get_acquistion` method, as it is already possible for qNEHVI.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #1608

Test Plan: Unit tests.

Reviewed By: Balandat

Differential Revision: D42346514

Pulled By: saitcakmak

fbshipit-source-id: 63d010c17cdca4147b7efe2cce5dc5cb62da4caa
  • Loading branch information
jduerholt authored and facebook-github-bot committed Jan 4, 2023
1 parent 056e657 commit 9cd4dea
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def get_acquisition_function(
X_pending=X_pending,
prune_baseline=kwargs.get("prune_baseline", False),
marginalize_dim=kwargs.get("marginalize_dim"),
cache_root=kwargs.get("cache_root", True),
)
elif acquisition_function_name == "qSR":
return monte_carlo.qSimpleRegret(
Expand Down
17 changes: 17 additions & 0 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,23 @@ def test_GetQNEI(self, mock_acqf):
self.assertEqual(sampler.sample_shape, torch.Size([self.mc_samples]))
self.assertEqual(sampler.seed, 1)
self.assertEqual(kwargs["marginalize_dim"], 0)
self.assertEqual(kwargs["cache_root"], True)
# test with cache_root = False
acqf = get_acquisition_function(
acquisition_function_name="qNEI",
model=self.model,
objective=self.objective,
X_observed=self.X_observed,
X_pending=self.X_pending,
mc_samples=self.mc_samples,
seed=self.seed,
marginalize_dim=0,
cache_root=False,
)
self.assertTrue(acqf == mock_acqf.return_value)
self.assertTrue(mock_acqf.call_count, 1)
args, kwargs = mock_acqf.call_args
self.assertEqual(kwargs["cache_root"], False)
# test with non-qmc, no X_pending
acqf = get_acquisition_function(
acquisition_function_name="qNEI",
Expand Down

0 comments on commit 9cd4dea

Please sign in to comment.