Skip to content

Commit

Permalink
Fix shape error in optimize_acqf_cyclic (#1648)
Browse files Browse the repository at this point in the history
Summary:
## Motivation

Fixes #873

In the past, `optimize_acqf` implicitly needed 3d inputs when there are equality constraints or inequality constraints and fixed_features don't provide the trivial solution, even though it worked with 2d inputs (no b-batches) in other cases. `optimize_acqf_cyclic` passed it 2d inputs, which would not generally work. I initially considered changing `optimize_acqf_cyclic` to pass 3d inputs, but since I found another place where 2d inputs were used, I decided to change `optimize_acqf` so it works with 2d inputs instead.

This was not caught because the only usage of `optimize_acqf_cyclic` was in a test that mocked `optimize_acqf`, so `optimize_acqf_cyclic` was never actually run end-to-end. I changed the test for `optimize_acqf_cyclic` to be more end-to-end, at the cost of worse testing of some intermediate properties. We could keep both versions though.

[x] Better docstring documentation on input shapes
[x] Add a singleton leading b-dimension where initial conditions are 2d

Pull Request resolved: #1648

Test Plan:
[x] More end-to-end test of `optimize_acqf_cyclic` that doesn't stub in `optimize_acqf` (see above)
[x] more input validation and  unit tests for input validation
[x] Ran cases that now raise errors without the new error handling, to make sure they were erroring before

Differential Revision: D42875942

Pulled By: esantorella

fbshipit-source-id: f19c01ca7ed8fa759b63000189ffa9974248dadb
  • Loading branch information
esantorella authored and facebook-github-bot committed Jan 31, 2023
1 parent 150d673 commit 72e8a83
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 90 deletions.
5 changes: 3 additions & 2 deletions botorch/generation/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def gen_candidates_scipy(
using `scipy.optimize.minimize` via a numpy converter.
Args:
initial_conditions: Starting points for optimization.
initial_conditions: Starting points for optimization, with shape
(b) x q x d.
acquisition_function: Acquisition function to be used.
lower_bounds: Minimum values for each column of initial_conditions.
upper_bounds: Maximum values for each column of initial_conditions.
Expand Down Expand Up @@ -162,7 +163,7 @@ def gen_candidates_scipy(
X=initial_conditions, lower_bounds=lower_bounds, upper_bounds=upper_bounds
)
constraints = make_scipy_linear_constraints(
shapeX=clamped_candidates.shape,
shapeX=shapeX,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)
Expand Down
27 changes: 24 additions & 3 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def optimize_acqf(
Returns:
A two-element tuple containing
- a `(num_restarts) x q x d`-dim tensor of generated candidates.
- A tensor of generated candidates. The shape is
-- `q x d` if `return_best_only` is True (default)
-- `num_restarts x q x d` if `return_best_only` is False
- a tensor of associated acquisition values. If `sequential=False`,
this is a `(num_restarts)`-dim tensor of joint acquisition values
(with explicit restart dimension if `return_best_only=False`). If
Expand Down Expand Up @@ -149,7 +151,6 @@ def optimize_acqf(
full_tree = True

initial_conditions_provided = batch_initial_conditions is not None

if initial_conditions_provided and sequential:
raise UnsupportedError(
"`batch_initial_conditions` is not supported for sequential optimization. "
Expand All @@ -158,6 +159,18 @@ def optimize_acqf(
"initial conditions for the case of nonlinear inequality constraints."
)

d = bounds.shape[1]
if initial_conditions_provided and (batch_initial_conditions.ndim not in (2, 3)):
raise ValueError(
"batch_initial_conditions must be 2-dimensional or 3-dimensional. "
f"Its shape is {batch_initial_conditions.shape}."
)
if initial_conditions_provided and (batch_initial_conditions.shape[-1] != d):
raise ValueError(
f"batch_initial_conditions.shape[-1] must be {d}. The "
f"shape is {batch_initial_conditions.shape}."
)

# Sets initial condition generator ic_gen if initial conditions not provided
if not initial_conditions_provided:
ic_gen = kwargs.pop("ic_generator", None)
Expand Down Expand Up @@ -268,7 +281,15 @@ def _optimize_batch_candidates(
) -> Tuple[Tensor, Tensor, List[Warning]]:
batch_candidates_list: List[Tensor] = []
batch_acq_values_list: List[Tensor] = []
batched_ics = batch_initial_conditions.split(batch_limit)
# need 3d ICs for compatibility with
# optim.parameter_constraints.make_scipy_linear_constraints
if (
inequality_constraints or equality_constraints
) and batch_initial_conditions.ndim == 2:
batched_ics = batch_initial_conditions.unsqueeze(0).split(batch_limit)
else:
batched_ics = batch_initial_conditions.split(batch_limit)

opt_warnings = []
if timeout_sec is not None:
timeout_sec = (timeout_sec - start_time) / len(batched_ics)
Expand Down
2 changes: 1 addition & 1 deletion botorch/optim/parameter_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def _make_linear_constraints(
version of the input tensor `X`, returning a numpy array.
"""
if len(shapeX) != 3:
raise UnsupportedError("`shapeX` must be `b x q x d`")
raise UnsupportedError(f"`shapeX` must be `b x q x d`. It is {shapeX}.")
q, d = shapeX[-2:]
n = shapeX.numel()
constraints: List[ScipyConstraintDict] = []
Expand Down
164 changes: 80 additions & 84 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,24 +334,68 @@ def test_optimize_acqf_sequential_notimplemented(self):
)

def test_optimize_acqf_runs_given_batch_initial_conditions(self):
num_restarts, raw_samples, dim = 1, 1, 1
num_restarts, raw_samples, dim = 1, 2, 3

opt_x = 2 / np.pi
# start near one (of many) optima
initial_conditions = (opt_x * 1.01) * torch.ones(
(num_restarts, raw_samples, dim)
)
# -x[i] * 1 >= -opt_x * 1.01 => x[i] <= opt_x * 1.01
inequality_constraints = [
(torch.tensor([i]), -torch.tensor([1]), -opt_x * 1.01) for i in range(dim)
] + [
# x[i] * 1 >= opt_x * .99
(torch.tensor([i]), torch.tensor([1]), opt_x * 0.99)
for i in range(dim)
]
q = 1

ic_shapes = [(1, 2, dim), (2, 1, dim), (1, dim)]

torch.manual_seed(0)
batch_candidates, acq_value_list = optimize_acqf(
acq_function=SinOneOverXAcqusitionFunction(),
bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
q=1,
num_restarts=num_restarts,
raw_samples=raw_samples,
batch_initial_conditions=initial_conditions,
)
self.assertAlmostEqual(batch_candidates.item(), opt_x, delta=1e-5)
self.assertAlmostEqual(acq_value_list.item(), 1)
for shape in ic_shapes:
with self.subTest(shape=shape):
# start near one (of many) optima
initial_conditions = (opt_x * 1.01) * torch.ones(shape)
batch_candidates, acq_value_list = optimize_acqf(
acq_function=SinOneOverXAcqusitionFunction(),
bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
q=q,
num_restarts=num_restarts,
raw_samples=raw_samples,
batch_initial_conditions=initial_conditions,
inequality_constraints=inequality_constraints,
)
self.assertAllClose(
batch_candidates,
opt_x * torch.ones_like(batch_candidates),
# must be at least 50% closer to the optimum than it started
atol=0.004,
rtol=0.005,
)
self.assertAlmostEqual(acq_value_list.item(), 1, places=3)

def test_optimize_acqf_wrong_ic_shape_inequality_constraints(self) -> None:
"""
Each of these gave an error even before adding the error message tested here,
due to incompatibility with `columnwise_clamp`.
"""
dim = 3
ic_shapes = [(1, 2, dim + 1), (1, 2, dim, 1), (1, dim + 1), (1, 1), (dim,)]

for shape in ic_shapes:
with self.subTest(shape=shape):
initial_conditions = torch.ones(shape)
expected_error = (
rf"batch_initial_conditions.shape\[-1\] must be {dim}\."
if len(shape) in (2, 3)
else r"batch_initial_conditions must be 2\-dimensional or "
)
with self.assertRaisesRegex(ValueError, expected_error):
optimize_acqf(
acq_function=MockAcquisitionFunction(),
bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
q=4,
batch_initial_conditions=initial_conditions,
num_restarts=1,
)

def test_optimize_acqf_warns_on_opt_failure(self):
"""
Expand Down Expand Up @@ -799,24 +843,26 @@ def __call__(self, x, f):


class TestOptimizeAcqfCyclic(BotorchTestCase):
@mock.patch("botorch.optim.optimize.optimize_acqf") # noqa: C901
def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
def test_optimize_acqf_cyclic(self):
num_restarts = 2
raw_samples = 10
num_cycles = 2
options = {}
tkwargs = {"device": self.device}
bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
inequality_constraints = [
[torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
[torch.tensor([2], dtype=int), torch.tensor([4.0]), torch.tensor(5.0)]
]
mock_acq_function = MockAcquisitionFunction()
for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
inequality_constraints[0] = [
t.to(**tkwargs) for t in inequality_constraints[0]
]
mock_optimize_acqf.reset_mock()
tkwargs["dtype"] = dtype
inequality_constraints = [
(
# indices can't be floats or doubles
inequality_constraints[0][0],
inequality_constraints[0][1].to(**tkwargs),
inequality_constraints[0][2].to(**tkwargs),
)
]
bounds = bounds.to(**tkwargs)
candidate_rvs = []
acq_val_rvs = []
Expand All @@ -836,8 +882,6 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
for rv in gcs_return_vals:
candidate_rvs.append(rv[0])
acq_val_rvs.append(rv[1])
mock_optimize_acqf.side_effect = list(zip(candidate_rvs, acq_val_rvs))
orig_candidates = candidate_rvs[0].clone()
# wrap the set_X_pending method for checking that call arguments
with mock.patch.object(
MockAcquisitionFunction,
Expand All @@ -850,69 +894,21 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
q=q,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options,
inequality_constraints=inequality_constraints,
post_processing_func=rounding_func,
cyclic_options={"maxiter": num_cycles},
)
# check that X_pending is set correctly in cyclic optimization
if q > 1:
x_pending_call_args_list = mock_set_X_pending.call_args_list
idxr = torch.ones(q, dtype=torch.bool, device=self.device)
for i in range(len(x_pending_call_args_list) - 1):
idxr[i] = 0
self.assertTrue(
torch.equal(
x_pending_call_args_list[i][0][0], orig_candidates[idxr]
)
)
idxr[i] = 1
orig_candidates[i] = candidate_rvs[i + 1]
# check reset to base_X_pendingg
self.assertIsNone(x_pending_call_args_list[-1][0][0])
else:
mock_set_X_pending.assert_not_called()
# check final candidates
expected_candidates = (
torch.cat(candidate_rvs[-q:], dim=0) if q > 1 else candidate_rvs[0]
)
self.assertTrue(torch.equal(candidates, expected_candidates))
# check call arguments for optimize_acqf
call_args_list = mock_optimize_acqf.call_args_list
expected_call_args = {
"acq_function": mock_acq_function,
"bounds": bounds,
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"options": options,
"inequality_constraints": inequality_constraints,
"equality_constraints": None,
"fixed_features": None,
"post_processing_func": rounding_func,
"return_best_only": True,
"sequential": True,
}
orig_candidates = candidate_rvs[0].clone()
for i in range(len(call_args_list)):
if i == 0:
# first cycle
expected_call_args.update(
{"batch_initial_conditions": None, "q": q}
)
else:
expected_call_args.update(
{"batch_initial_conditions": orig_candidates[i - 1 : i], "q": 1}
)
orig_candidates[i - 1] = candidate_rvs[i]
for k, v in call_args_list[i][1].items():
if torch.is_tensor(v):
self.assertTrue(torch.equal(expected_call_args[k], v))
elif k == "acq_function":
self.assertIsInstance(
mock_acq_function, MockAcquisitionFunction
)
else:
self.assertEqual(expected_call_args[k], v)
# check that X_pending is set correctly in cyclic optimization
if q > 1:
x_pending_call_args_list = mock_set_X_pending.call_args_list
# check reset to base_X_pendingg
self.assertIsNone(x_pending_call_args_list[-1][0][0])
self.assertEqual(acq_value.shape, (q,))
else:
mock_set_X_pending.assert_not_called()
self.assertEqual(acq_value.shape, torch.Size([]))
self.assertTrue((acq_value >= 0).all())
self.assertEqual(candidates.shape, (q, bounds.shape[1]))


class TestOptimizeAcqfList(BotorchTestCase):
Expand Down

0 comments on commit 72e8a83

Please sign in to comment.