add mixed optimization for list optimization #1342

Closed · wants to merge 5 commits

Changes from 4 commits
50 changes: 36 additions & 14 deletions botorch/optim/optimize.py
@@ -470,6 +470,7 @@ def optimize_acqf_list(
inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
fixed_features: Optional[Dict[int, float]] = None,
fixed_features_list: Optional[List[Dict[int, float]]] = None,
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
) -> Tuple[Tensor, Tensor]:
r"""Generate a list of candidates from a list of acquisition functions.
@@ -495,6 +496,9 @@ def optimize_acqf_list(
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
fixed_features_list: A list of maps `{feature_index: value}`. The i-th
item represents the fixed features for the i-th optimization. If
`fixed_features_list` is provided, `optimize_acqf_mixed` is invoked.
post_processing_func: A function that post-processes an optimization
result appropriately (i.e., according to `round-trip`
transformations).
@@ -507,6 +511,10 @@ def optimize_acqf_list(
index `i` is the acquisition value conditional on having observed
all candidates except candidate `i`.
"""
if fixed_features and fixed_features_list:
raise ValueError(
"Èither `fixed_feature` or `fixed_features_list` can be provided, not both."
)
if not acq_function_list:
raise ValueError("acq_function_list must be non-empty.")
candidate_list, acq_value_list = [], []
@@ -519,20 +527,34 @@
if base_X_pending is not None
else candidates
)
candidate, acq_value = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=1,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options or {},
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
post_processing_func=post_processing_func,
return_best_only=True,
sequential=False,
)
if fixed_features_list:
candidate, acq_value = optimize_acqf_mixed(
acq_function=acq_function,
bounds=bounds,
q=1,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options or {},
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features_list=fixed_features_list,
post_processing_func=post_processing_func,
)
else:
candidate, acq_value = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=1,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options or {},
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
post_processing_func=post_processing_func,
return_best_only=True,
sequential=False,
)
candidate_list.append(candidate)
acq_value_list.append(acq_value)
candidates = torch.cat(candidate_list, dim=-2)
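To illustrate the new behavior, here is a minimal usage sketch. The toy data, model, and acquisition setup below are illustrative only; just the `optimize_acqf_list` call mirrors the signature added in this diff. With `fixed_features_list` set, each candidate is generated by `optimize_acqf_mixed` rather than `optimize_acqf`, so the discrete feature is handled by enumeration while the continuous features are optimized numerically.

import torch
from botorch.acquisition import ExpectedImprovement
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.optim.optimize import optimize_acqf_list
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy mixed search space: features 0 and 1 are continuous, feature 2 is binary.
train_X = torch.rand(10, 3, dtype=torch.double)
train_X[:, 2] = train_X[:, 2].round()
train_Y = train_X.sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = ExpectedImprovement(model, best_f=train_Y.max())

# One candidate per acquisition function in the list; feature 2 is fixed
# to each configuration in fixed_features_list and the best result is kept.
candidates, acq_values = optimize_acqf_list(
    acq_function_list=[acqf],
    bounds=torch.stack([torch.zeros(3), torch.ones(3)]).double(),
    num_restarts=2,
    raw_samples=16,
    fixed_features_list=[{2: 0.0}, {2: 1.0}],
)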
221 changes: 131 additions & 90 deletions test/optim/test_optimize.py
@@ -908,7 +908,8 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):

class TestOptimizeAcqfList(BotorchTestCase):
@mock.patch("botorch.optim.optimize.optimize_acqf") # noqa: C901
def test_optimize_acqf_list(self, mock_optimize_acqf):
@mock.patch("botorch.optim.optimize.optimize_acqf_mixed")
def test_optimize_acqf_list(self, mock_optimize_acqf_mixed, mock_optimize_acqf):
num_restarts = 2
raw_samples = 10
options = {}
@@ -921,97 +922,123 @@ def test_optimize_acqf_list(self, mock_optimize_acqf):
mock_acq_function_1 = MockAcquisitionFunction()
mock_acq_function_2 = MockAcquisitionFunction()
mock_acq_function_list = [mock_acq_function_1, mock_acq_function_2]
for num_acqf, dtype in itertools.product([1, 2], (torch.float, torch.double)):
for m in mock_acq_function_list:
# clear previous X_pending
m.set_X_pending(None)
tkwargs["dtype"] = dtype
inequality_constraints[0] = [
t.to(**tkwargs) for t in inequality_constraints[0]
]
mock_optimize_acqf.reset_mock()
bounds = bounds.to(**tkwargs)
candidate_rvs = []
acq_val_rvs = []
gcs_return_vals = [
(torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
for _ in range(num_acqf)
]
for rv in gcs_return_vals:
candidate_rvs.append(rv[0])
acq_val_rvs.append(rv[1])
side_effect = list(zip(candidate_rvs, acq_val_rvs))
mock_optimize_acqf.side_effect = side_effect
orig_candidates = candidate_rvs[0].clone()
# Wrap the set_X_pending method to check the call arguments
with mock.patch.object(
MockAcquisitionFunction,
"set_X_pending",
wraps=mock_acq_function_1.set_X_pending,
) as mock_set_X_pending_1, mock.patch.object(
MockAcquisitionFunction,
"set_X_pending",
wraps=mock_acq_function_2.set_X_pending,
) as mock_set_X_pending_2:
candidates, acq_values = optimize_acqf_list(
acq_function_list=mock_acq_function_list[:num_acqf],
bounds=bounds,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options,
inequality_constraints=inequality_constraints,
post_processing_func=rounding_func,
)
# check that X_pending is set correctly in sequential optimization
if num_acqf > 1:
x_pending_call_args_list = mock_set_X_pending_2.call_args_list
idxr = torch.ones(num_acqf, dtype=torch.bool, device=self.device)
for i in range(len(x_pending_call_args_list) - 1):
idxr[i] = 0
self.assertTrue(
torch.equal(
x_pending_call_args_list[i][0][0], orig_candidates[idxr]
)
)
idxr[i] = 1
orig_candidates[i] = candidate_rvs[i + 1]
else:
mock_set_X_pending_1.assert_not_called()
# check final candidates
expected_candidates = (
torch.cat(candidate_rvs[-num_acqf:], dim=0)
if num_acqf > 1
else candidate_rvs[0]
)
self.assertTrue(torch.equal(candidates, expected_candidates))
# check call arguments for optimize_acqf
call_args_list = mock_optimize_acqf.call_args_list
expected_call_args = {
"acq_function": None,
"bounds": bounds,
"q": 1,
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"options": options,
"inequality_constraints": inequality_constraints,
"equality_constraints": None,
"fixed_features": None,
"post_processing_func": rounding_func,
"batch_initial_conditions": None,
"return_best_only": True,
"sequential": False,
}
for i in range(len(call_args_list)):
expected_call_args["acq_function"] = mock_acq_function_list[i]
for k, v in call_args_list[i][1].items():
if torch.is_tensor(v):
self.assertTrue(torch.equal(expected_call_args[k], v))
elif k == "acq_function":
self.assertIsInstance(
mock_acq_function_list[i], MockAcquisitionFunction
fixed_features_list = [None, [{0: 0.5}]]
for ffl in fixed_features_list:
for num_acqf, dtype in itertools.product(
[1, 2], (torch.float, torch.double)
):
for m in mock_acq_function_list:
# clear previous X_pending
m.set_X_pending(None)
tkwargs["dtype"] = dtype
inequality_constraints[0] = [
t.to(**tkwargs) for t in inequality_constraints[0]
]
mock_optimize_acqf.reset_mock()
mock_optimize_acqf_mixed.reset_mock()
bounds = bounds.to(**tkwargs)
candidate_rvs = []
acq_val_rvs = []
gcs_return_vals = [
(torch.rand(1, 3, **tkwargs), torch.rand(1, **tkwargs))
for _ in range(num_acqf)
]
for rv in gcs_return_vals:
candidate_rvs.append(rv[0])
acq_val_rvs.append(rv[1])
side_effect = list(zip(candidate_rvs, acq_val_rvs))
mock_optimize_acqf.side_effect = side_effect
mock_optimize_acqf_mixed.side_effect = side_effect
orig_candidates = candidate_rvs[0].clone()
# Wrap the set_X_pending method to check the call arguments
with mock.patch.object(
MockAcquisitionFunction,
"set_X_pending",
wraps=mock_acq_function_1.set_X_pending,
) as mock_set_X_pending_1, mock.patch.object(
MockAcquisitionFunction,
"set_X_pending",
wraps=mock_acq_function_2.set_X_pending,
) as mock_set_X_pending_2:
candidates, _ = optimize_acqf_list(
acq_function_list=mock_acq_function_list[:num_acqf],
bounds=bounds,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options,
inequality_constraints=inequality_constraints,
post_processing_func=rounding_func,
fixed_features_list=ffl,
)
# check that X_pending is set correctly in sequential optimization
if num_acqf > 1:
x_pending_call_args_list = mock_set_X_pending_2.call_args_list
idxr = torch.ones(
num_acqf, dtype=torch.bool, device=self.device
)
for i in range(len(x_pending_call_args_list) - 1):
idxr[i] = 0
self.assertTrue(
torch.equal(
x_pending_call_args_list[i][0][0],
orig_candidates[idxr],
)
)
idxr[i] = 1
orig_candidates[i] = candidate_rvs[i + 1]
else:
    mock_set_X_pending_1.assert_not_called()
# check final candidates
expected_candidates = (
torch.cat(candidate_rvs[-num_acqf:], dim=0)
if num_acqf > 1
else candidate_rvs[0]
)
self.assertTrue(torch.equal(candidates, expected_candidates))
# check call arguments for optimize_acqf
if ffl is None:
call_args_list = mock_optimize_acqf.call_args_list
expected_call_args = {
"acq_function": None,
"bounds": bounds,
"q": 1,
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"options": options,
"inequality_constraints": inequality_constraints,
"equality_constraints": None,
"fixed_features": None,
"post_processing_func": rounding_func,
"batch_initial_conditions": None,
"return_best_only": True,
"sequential": False,
}
else:
call_args_list = mock_optimize_acqf_mixed.call_args_list
expected_call_args = {
"acq_function": None,
"bounds": bounds,
"q": 1,
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"options": options,
"inequality_constraints": inequality_constraints,
"equality_constraints": None,
"post_processing_func": rounding_func,
"batch_initial_conditions": None,
"fixed_features_list": ffl,
}
for i in range(len(call_args_list)):
expected_call_args["acq_function"] = mock_acq_function_list[i]
for k, v in call_args_list[i][1].items():
if torch.is_tensor(v):
self.assertTrue(torch.equal(expected_call_args[k], v))
elif k == "acq_function":
self.assertIsInstance(
mock_acq_function_list[i], MockAcquisitionFunction
)
else:
self.assertEqual(expected_call_args[k], v)

def test_optimize_acqf_list_empty_list(self):
with self.assertRaises(ValueError):
@@ -1022,6 +1049,20 @@ def test_optimize_acqf_list_empty_list(self):
raw_samples=10,
)

def test_optimize_acqf_list_fixed_features(self):
with self.assertRaises(ValueError):
optimize_acqf_list(
acq_function_list=[
MockAcquisitionFunction(),
MockAcquisitionFunction(),
],
bounds=torch.stack([torch.zeros(3), 4 * torch.ones(3)]),
num_restarts=2,
raw_samples=10,
fixed_features_list=[{0: 0.5}],
fixed_features={0: 0.5},
)


class TestOptimizeAcqfMixed(BotorchTestCase):
@mock.patch("botorch.optim.optimize.optimize_acqf") # noqa: C901
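As background on the dispatch introduced above: a rough sketch of what `optimize_acqf_mixed` does in the `q=1` case that `optimize_acqf_list` uses. This is a simplification, not the actual BoTorch implementation (which also plumbs through options, constraints, and the `q > 1` sequential case), and the helper name below is hypothetical.

from typing import Dict, List, Tuple

from botorch.optim.optimize import optimize_acqf
from torch import Tensor

def optimize_acqf_mixed_sketch(
    acq_function,
    bounds: Tensor,
    num_restarts: int,
    raw_samples: int,
    fixed_features_list: List[Dict[int, float]],
) -> Tuple[Tensor, Tensor]:
    # Run one continuous optimization per fixed-feature configuration
    # and keep the candidate with the best acquisition value.
    best_candidate, best_value = None, None
    for fixed_features in fixed_features_list:
        candidate, acq_value = optimize_acqf(
            acq_function=acq_function,
            bounds=bounds,
            q=1,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            fixed_features=fixed_features,
        )
        if best_value is None or acq_value > best_value:
            best_candidate, best_value = candidate, acq_value
    return best_candidate, best_value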