Skip to content

Commit

Permalink
Check that choices is non-empty in optimize_acqf_discrete (#1228)
Browse files Browse the repository at this point in the history
Summary:
Otherwise this produces confusing tensor dimension errors if no chocies are given.
Also applies ufmt style throughout.

Pull Request resolved: #1228

Reviewed By: bernardbeckerman

Differential Revision: D36480928

Pulled By: Balandat

fbshipit-source-id: 39d8b55ef4666265c1657525326528b8f2c1f4b2
  • Loading branch information
Balandat authored and facebook-github-bot committed May 18, 2022
1 parent 1b92ee8 commit 0afcf35
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 3 additions & 1 deletion botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
OneShotAcquisitionFunction,
)
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.exceptions import UnsupportedError
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.generation.gen import gen_candidates_scipy
from botorch.logging import logger
from botorch.optim.initializers import (
Expand Down Expand Up @@ -600,6 +600,8 @@ def optimize_acqf_discrete(
"Discrete optimization is not supported for"
"one-shot acquisition functions."
)
if choices.numel() == 0:
raise InputDataError("`choices` must be non-emtpy.")
choices_batched = choices.unsqueeze(-2)
if q > 1:
candidate_list, acq_value_list = [], []
Expand Down
10 changes: 9 additions & 1 deletion test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
AcquisitionFunction,
OneShotAcquisitionFunction,
)
from botorch.exceptions import UnsupportedError
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.optim.optimize import (
_filter_infeasible,
_filter_invalid,
Expand Down Expand Up @@ -880,6 +880,14 @@ def test_optimize_acqf_discrete(self):
mock_acq_function = SquaredAcquisitionFunction()
mock_acq_function.set_X_pending(None)

# ensure proper raising of errors if no choices
with self.assertRaisesRegex(InputDataError, "`choices` must be non-emtpy."):
optimize_acqf_discrete(
acq_function=mock_acq_function,
q=q,
choices=torch.empty(0, 2),
)

choices = torch.rand(5, 2, **tkwargs)
exp_acq_vals = mock_acq_function(choices)

Expand Down

0 comments on commit 0afcf35

Please sign in to comment.