Skip to content

Commit

Permalink
Add option to disable retrying on optimization warning
Browse files Browse the repository at this point in the history
Summary:
When `gen_candidates` exists with an optimization warning, we retry with a new set of initial conditions. In certain settings, `gen_candidates_scipy` is expected to exit with an optimization warning. This allows turning off this behavior in such settings.

Example use case: When using straight through estimators for optimizing in mixed discrete search spaces, we often get `ABNORMAL_TERMINATION_IN_LNSRCH` since these gradients are a bit ill-behaved (due to function evaluations happening after rounding).

Reviewed By: esantorella

Differential Revision: D43478712

fbshipit-source-id: a1bfcbaaf387a46f38ded0b9802bb52cb1c8511f
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 21, 2023
1 parent 6854751 commit 300dd38
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
11 changes: 10 additions & 1 deletion botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class OptimizeAcqfInputs:
ic_generator: Optional[TGenInitialConditions] = None
timeout_sec: Optional[float] = None
return_full_tree: bool = False
retry_on_optimization_warning: bool = True
ic_gen_kwargs: Dict = dataclasses.field(default_factory=dict)

@property
Expand Down Expand Up @@ -333,7 +334,7 @@ def _optimize_batch_candidates(
optimization_warning_raised = any(
(issubclass(w.category, OptimizationWarning) for w in ws)
)
if optimization_warning_raised:
if optimization_warning_raised and opt_inputs.retry_on_optimization_warning:
first_warn_msg = (
"Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
Expand Down Expand Up @@ -412,6 +413,7 @@ def optimize_acqf(
ic_generator: Optional[TGenInitialConditions] = None,
timeout_sec: Optional[float] = None,
return_full_tree: bool = False,
retry_on_optimization_warning: bool = True,
**ic_gen_kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""Generate a set of candidates via multi-start optimization.
Expand Down Expand Up @@ -465,6 +467,8 @@ def optimize_acqf(
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
return_full_tree:
retry_on_optimization_warning: Whether to retry candidate generation with a new
set of initial conditions when it fails with an `OptimizationWarning`.
ic_gen_kwargs: Additional keyword arguments passed to function specified by
`ic_generator`
Expand Down Expand Up @@ -515,6 +519,7 @@ def optimize_acqf(
ic_generator=ic_generator,
timeout_sec=timeout_sec,
return_full_tree=return_full_tree,
retry_on_optimization_warning=retry_on_optimization_warning,
ic_gen_kwargs=ic_gen_kwargs,
)
return _optimize_acqf(opt_acqf_inputs)
Expand Down Expand Up @@ -568,6 +573,7 @@ def optimize_acqf_cyclic(
ic_generator: Optional[TGenInitialConditions] = None,
timeout_sec: Optional[float] = None,
return_full_tree: bool = False,
retry_on_optimization_warning: bool = True,
**ic_gen_kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""Generate a set of `q` candidates via cyclic optimization.
Expand Down Expand Up @@ -605,6 +611,8 @@ def optimize_acqf_cyclic(
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
return_full_tree:
retry_on_optimization_warning: Whether to retry candidate generation with a new
set of initial conditions when it fails with an `OptimizationWarning`.
ic_gen_kwargs: Additional keyword arguments passed to function specified by
`ic_generator`
Expand Down Expand Up @@ -645,6 +653,7 @@ def optimize_acqf_cyclic(
ic_generator=ic_generator,
timeout_sec=timeout_sec,
return_full_tree=return_full_tree,
retry_on_optimization_warning=retry_on_optimization_warning,
ic_gen_kwargs=ic_gen_kwargs,
)

Expand Down
25 changes: 25 additions & 0 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
condition that causes failure in the first run of
`gen_candidates_scipy`, then re-tries with a new starting point and
succeed.
Also tests that this can be turned off by setting
`retry_on_optimization_warning = False`.
"""
num_restarts, raw_samples, dim = 1, 1, 1

Expand Down Expand Up @@ -558,6 +561,28 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
# check if it succeeded on restart -- the maximum value of sin(1/x) is 1
self.assertAlmostEqual(acq_value_list.item(), 1.0)

# Test with retry_on_optimization_warning = False.
torch.manual_seed(5)
with warnings.catch_warnings(record=True) as ws:
batch_candidates, acq_value_list = optimize_acqf(
acq_function=SinOneOverXAcqusitionFunction(),
bounds=bounds,
q=1,
num_restarts=num_restarts,
raw_samples=raw_samples,
# shorten the line search to make it faster and make failure
# more likely
options={"maxls": 2},
retry_on_optimization_warning=False,
)
expected_warning_raised = any(
(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
for w in ws
)
)
self.assertFalse(expected_warning_raised)

def test_optimize_acqf_warns_on_second_opt_failure(self):
"""
Test that `optimize_acqf` warns if it fails on a second optimization try.
Expand Down

0 comments on commit 300dd38

Please sign in to comment.