Skip to content

Commit

Permalink
Fixed bug where optimize_acqf didn't work with different batch sizes (p…
Browse files Browse the repository at this point in the history
…ytorch#1414)

Summary: Pull Request resolved: facebook/Ax#1414

Differential Revision: D43178298

fbshipit-source-id: 05441f9140d7971439ed931d66e78a9cb0d9aebf
  • Loading branch information
esantorella authored and facebook-github-bot committed Feb 10, 2023
1 parent 89f923d commit c8c4ae1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
6 changes: 5 additions & 1 deletion botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,11 @@ def _optimize_batch_candidates(
logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")

batch_candidates = torch.cat(batch_candidates_list)
batch_acq_values = torch.stack(batch_acq_values_list).flatten()
has_scalars = batch_acq_values_list[0].ndim == 0
if has_scalars:
batch_acq_values = torch.stack(batch_acq_values_list).flatten()
else:
batch_acq_values = torch.cat(batch_acq_values_list).flatten()
return batch_candidates, batch_acq_values, opt_warnings

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)
Expand Down
20 changes: 20 additions & 0 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,26 @@ def test_optimize_acqf_sequential_notimplemented(self):
sequential=True,
)

def test_optimize_acqf_batch_limit(self) -> None:
num_restarts = 3
raw_samples = 5
dim = 4
q = 4
batch_limit = 2

options = {"batch_limit": batch_limit}

_, acq_value_list = optimize_acqf(
acq_function=SinOneOverXAcqusitionFunction(),
bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
q=q,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options,
return_best_only=False,
)
self.assertEqual(acq_value_list.shape, (num_restarts,))

def test_optimize_acqf_runs_given_batch_initial_conditions(self):
num_restarts, raw_samples, dim = 1, 2, 3

Expand Down

0 comments on commit c8c4ae1

Please sign in to comment.